https://github.com/halide/Halide
Raw File
Tip revision: 1006c4e5eb7eb391155e847e1d9839fe9df1568f authored by Andrew Adams on 31 May 2023, 21:40:28 UTC
Fix operator/ on ModulusRemainder
Tip revision: 1006c4e
HexagonOptimize.cpp
#include "HexagonOptimize.h"
#include "Bounds.h"
#include "CSE.h"
#include "CodeGen_Internal.h"
#include "ConciseCasts.h"
#include "ExprUsesVar.h"
#include "FindIntrinsics.h"
#include "HexagonAlignment.h"
#include "IREquality.h"
#include "IRMatch.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "Lerp.h"
#include "OptimizeShuffles.h"
#include "Scope.h"
#include "Simplify.h"
#include "Substitute.h"
#include <unordered_map>
#include <utility>

namespace Halide {
namespace Internal {

using std::pair;
using std::set;
using std::string;
using std::vector;

using namespace Halide::ConciseCasts;

// Wrap x in a call to the Hexagon intrinsic that interleaves the native
// vectors making up x. The intrinsic is chosen by the element width of x;
// only 8, 16 and 32 bit elements are supported, anything else is an
// internal error.
Expr native_interleave(const Expr &x) {
    const int bits = x.type().bits();
    string intrinsic;
    if (bits == 8) {
        intrinsic = "halide.hexagon.interleave.vb";
    } else if (bits == 16) {
        intrinsic = "halide.hexagon.interleave.vh";
    } else if (bits == 32) {
        intrinsic = "halide.hexagon.interleave.vw";
    } else {
        internal_error << "Cannot interleave native vectors of type " << x.type() << "\n";
    }
    return Call::make(x.type(), intrinsic, {x}, Call::PureExtern);
}

// Wrap x in a call to the Hexagon intrinsic that deinterleaves the native
// vectors making up x (the inverse of native_interleave). Only 8, 16 and
// 32 bit elements are supported; anything else is an internal error.
Expr native_deinterleave(const Expr &x) {
    const char *intrinsic = "";
    const int bits = x.type().bits();
    if (bits == 8) {
        intrinsic = "halide.hexagon.deinterleave.vb";
    } else if (bits == 16) {
        intrinsic = "halide.hexagon.deinterleave.vh";
    } else if (bits == 32) {
        intrinsic = "halide.hexagon.deinterleave.vw";
    } else {
        internal_error << "Cannot deinterleave native vectors of type " << x.type() << "\n";
    }
    return Call::make(x.type(), string(intrinsic), {x}, Call::PureExtern);
}

bool is_native_interleave_op(const Expr &x, const char *name) {
    const Call *c = x.as<Call>();
    if (!c || c->args.size() != 1) {
        return false;
    }
    return starts_with(c->name, name);
}

// True if x is a single-argument call to a halide.hexagon.interleave* intrinsic.
bool is_native_interleave(const Expr &x) {
    const char *prefix = "halide.hexagon.interleave";
    return is_native_interleave_op(x, prefix);
}

// True if x is a single-argument call to a halide.hexagon.deinterleave* intrinsic.
bool is_native_deinterleave(const Expr &x) {
    const char *prefix = "halide.hexagon.deinterleave";
    return is_native_interleave_op(x, prefix);
}

// Build the HVX intrinsic-name suffix for `type`: ".v" for vectors or "."
// for scalars, followed by b/h/w for 8/16/32-bit elements. When
// signed_variants is true, unsigned integer types get a "u" before the width
// letter. Any other type (other widths, or a non-int/uint type when
// signed_variants is true) is an internal error.
string type_suffix(Type type, bool signed_variants) {
    const string prefix = type.is_vector() ? ".v" : ".";
    // The plain (signed) letter is used for signed types, and for every type
    // when the caller does not want distinct unsigned variants.
    const bool use_plain_name = type.is_int() || !signed_variants;
    if (use_plain_name || type.is_uint()) {
        const char *width = nullptr;
        switch (type.bits()) {
        case 8:
            width = "b";
            break;
        case 16:
            width = "h";
            break;
        case 32:
            width = "w";
            break;
        default:
            break;
        }
        if (width != nullptr) {
            return prefix + (use_plain_name ? "" : "u") + width;
        }
    }
    internal_error << "Unsupported HVX type: " << type << "\n";
    return "";
}

// Suffix for the type of a single expression; see type_suffix(Type, bool).
string type_suffix(const Expr &a, bool signed_variants) {
    string suffix = type_suffix(a.type(), signed_variants);
    return suffix;
}

// Concatenated suffixes for the types of a and then b.
string type_suffix(const Expr &a, const Expr &b, bool signed_variants) {
    string suffix = type_suffix(a, signed_variants);
    suffix += type_suffix(b, signed_variants);
    return suffix;
}

// Concatenated type suffixes for every expression in ops, in order.
// An empty list yields an empty suffix.
string type_suffix(const vector<Expr> &ops, bool signed_variants) {
    string suffix;
    for (const Expr &op : ops) {
        // Append in place rather than rebuilding the accumulated string on
        // every iteration.
        suffix += type_suffix(op, signed_variants);
    }
    return suffix;
}

namespace {

// Helper to handle various forms of multiplication. Returns a itself when it
// is a Mul. For a widening_mul, a shift_left/widening_shift_left by a
// constant unsigned amount, or a widen_right_mul, returns an equivalent
// product built with Mul::make. Otherwise returns an undefined Expr.
// The non-Mul cases pass the product through simplify(), so the result is
// not guaranteed to still be a Mul node.
Expr as_mul(const Expr &a) {
    if (a.as<Mul>()) {
        return a;
    }

    if (const Call *widening = Call::as_intrinsic(a, {Call::widening_mul})) {
        Expr lhs = cast(widening->type, widening->args[0]);
        Expr rhs = cast(widening->type, widening->args[1]);
        return simplify(Mul::make(lhs, rhs));
    }

    if (const Call *shift = Call::as_intrinsic(a, {Call::shift_left, Call::widening_shift_left})) {
        const uint64_t *amount = as_const_uint(shift->args[1]);
        if (amount == nullptr) {
            // Only shifts by a known constant can be rewritten as a multiply.
            return Expr();
        }
        // x << k is rewritten as x * (1 << k).
        Expr multiplier = make_one(shift->type) << cast(UInt(shift->type.bits()), (int)*amount);
        return simplify(Mul::make(cast(shift->type, shift->args[0]), multiplier));
    }

    if (const Call *right_mul = Call::as_intrinsic(a, {Call::widen_right_mul})) {
        return simplify(Mul::make(right_mul->args[0], cast(right_mul->type, right_mul->args[1])));
    }

    return Expr();
}

// Helpers to generate horizontally reducing multiply operations.
// This overload takes the two vector operands (v0, v1) and two coefficients
// (c0, c1) separately, calls halide.hexagon.add_2mpy<suffix>, and
// interleaves the native vectors of the result.
Expr halide_hexagon_add_2mpy(Type result_type, const string &suffix, Expr v0, Expr v1, Expr c0, Expr c1) {
    const string intrinsic = "halide.hexagon.add_2mpy" + suffix;
    Expr product = Call::make(result_type, intrinsic,
                              {std::move(v0), std::move(v1), std::move(c0), std::move(c1)}, Call::PureExtern);
    return native_interleave(product);
}

// Call halide.hexagon.add_2mpy<suffix>(v01, c01). Unlike the four-operand
// overload, the result is returned without interleaving.
Expr halide_hexagon_add_2mpy(Type result_type, const string &suffix, Expr v01, Expr c01) {
    const string intrinsic = "halide.hexagon.add_2mpy" + suffix;
    return Call::make(result_type, intrinsic, {std::move(v01), std::move(c01)}, Call::PureExtern);
}

// Call halide.hexagon.add_3mpy<suffix>(v01, c01); the result is not interleaved.
Expr halide_hexagon_add_3mpy(Type result_type, const string &suffix, Expr v01, Expr c01) {
    const string intrinsic = "halide.hexagon.add_3mpy" + suffix;
    return Call::make(result_type, intrinsic, {std::move(v01), std::move(c01)}, Call::PureExtern);
}

// Call halide.hexagon.add_4mpy<suffix>(v01, c01); the result is not interleaved.
Expr halide_hexagon_add_4mpy(Type result_type, const string &suffix, Expr v01, Expr c01) {
    const string intrinsic = "halide.hexagon.add_4mpy" + suffix;
    return Call::make(result_type, intrinsic, {std::move(v01), std::move(c01)}, Call::PureExtern);
}

struct Pattern {
    enum Flags {
        InterleaveResult = 1 << 0,  // After evaluating the pattern, interleave native vectors of the result.
        SwapOps01 = 1 << 1,         // Swap operands 0 and 1 prior to substitution.
        SwapOps12 = 1 << 2,         // Swap operands 1 and 2 prior to substitution.

        DeinterleaveOp0 = 1 << 5,  // Prior to evaluating the pattern, deinterleave native vectors of operand 0.
        DeinterleaveOp1 = 1 << 6,  // Same as above, but for operand 1.
        DeinterleaveOp2 = 1 << 7,
        DeinterleaveOps = DeinterleaveOp0 | DeinterleaveOp1 | DeinterleaveOp2,

        BeginDeinterleaveOp = 0,  // BeginDeinterleaveOp and EndDeinterleaveOp ensure that we check only three
        EndDeinterleaveOp = 3,    // deinterleave Op0, 1 and 2.
        // Many patterns are instructions that widen only
        // operand 0, which need to both deinterleave operand 0, and then
        // re-interleave the result.
        ReinterleaveOp0 = InterleaveResult | DeinterleaveOp0,

        v65orLater = 1 << 10,  // Pattern should be matched only for v65 target or later
        v66orLater = 1 << 11,  // Pattern should be matched only for v66 target or later
    };

    string intrin;  // Name of the intrinsic
    Expr pattern;   // The pattern to match against
    int flags;

    Pattern() = default;
    Pattern(const string &intrin, Expr p, int flags = 0)
        : intrin(intrin), pattern(std::move(p)), flags(flags) {
    }
};

// Wildcards used to build the pattern tables in this file. Each one is a
// Variable named "*".
// NOTE(review): presumably expr_match (IRMatch.h) treats a variable named "*"
// as a wildcard that matches any subexpression of the same type and appends
// it to the matches vector. Confirm against IRMatch.
//
// Scalar wildcards.
Expr wild_u8 = Variable::make(UInt(8), "*");
Expr wild_u16 = Variable::make(UInt(16), "*");
Expr wild_u32 = Variable::make(UInt(32), "*");
Expr wild_u64 = Variable::make(UInt(64), "*");
Expr wild_i8 = Variable::make(Int(8), "*");
Expr wild_i16 = Variable::make(Int(16), "*");
Expr wild_i32 = Variable::make(Int(32), "*");
Expr wild_i64 = Variable::make(Int(64), "*");

// Vector wildcards, built with a lane count of 0.
// NOTE(review): lanes == 0 appears to mean "a vector with any number of
// lanes" to expr_match. Confirm.
Expr wild_u8x = Variable::make(Type(Type::UInt, 8, 0), "*");
Expr wild_u16x = Variable::make(Type(Type::UInt, 16, 0), "*");
Expr wild_u32x = Variable::make(Type(Type::UInt, 32, 0), "*");
Expr wild_u64x = Variable::make(Type(Type::UInt, 64, 0), "*");
Expr wild_i8x = Variable::make(Type(Type::Int, 8, 0), "*");
Expr wild_i16x = Variable::make(Type(Type::Int, 16, 0), "*");
Expr wild_i32x = Variable::make(Type(Type::Int, 32, 0), "*");
Expr wild_i64x = Variable::make(Type(Type::Int, 64, 0), "*");

// Check if a pattern with flags 'flags' is supported on the target.
// v65orLater patterns require the HVX_v65 or HVX_v66 feature, and
// v66orLater patterns require HVX_v66. Patterns with neither flag are
// always allowed.
// NOTE(review): only the HVX_v65/HVX_v66 features are consulted here. If
// newer HVX feature flags exist, confirm whether they should also enable
// these patterns.
bool check_pattern_target(int flags, const Target &target) {
    const bool needs_v65 = (flags & Pattern::v65orLater) != 0;
    const bool needs_v66 = (flags & Pattern::v66orLater) != 0;
    const bool v65_ok = !needs_v65 || target.features_any_of({Target::HVX_v65, Target::HVX_v66});
    const bool v66_ok = !needs_v66 || target.features_any_of({Target::HVX_v66});
    return v65_ok && v66_ok;
}

// Check that the matches are usable for a pattern with the given flags, and
// rewrite them in place as the flags request:
//   1. every match must be defined, otherwise return false (matches untouched);
//   2. operands selected by Pattern::DeinterleaveOp0..2 are wrapped in
//      native_deinterleave (indices refer to the original operand order);
//   3. then SwapOps01 / SwapOps12 swap operands.
// Returns true when the matches are ready for substitution.
bool process_match_flags(vector<Expr> &matches, int flags) {
    for (const auto &match : matches) {
        if (!match.defined()) {
            return false;
        }
    }

    // The Pattern::DeinterleaveOp* flags are ordered such that the operand
    // corresponds to the bit (with operand 0 corresponding to the least
    // significant bit), so we can check for them all in a loop.
    for (size_t i = Pattern::BeginDeinterleaveOp; i < Pattern::EndDeinterleaveOp; i++) {
        if (flags & (Pattern::DeinterleaveOp0 << (i - Pattern::BeginDeinterleaveOp))) {
            // Guard the index: a pattern that asks to deinterleave an operand
            // it never matched is a bug in the pattern table.
            internal_assert(i < matches.size())
                << "Pattern requests deinterleaving operand " << i
                << " but only " << matches.size() << " operands were matched\n";
            internal_assert(matches[i].type().is_vector());
            matches[i] = native_deinterleave(matches[i]);
        }
    }
    if (flags & Pattern::SwapOps01) {
        internal_assert(matches.size() >= 2);
        std::swap(matches[0], matches[1]);
    }
    if (flags & Pattern::SwapOps12) {
        internal_assert(matches.size() >= 3);
        std::swap(matches[1], matches[2]);
    }
    return true;
}

// Replace an expression with the one specified by a pattern: a PureExtern
// call to p.intrin with x's type and the given (already processed) matches
// as arguments, interleaved afterwards if the pattern asks for it.
Expr replace_pattern(Expr x, const vector<Expr> &matches, const Pattern &p) {
    Expr call = Call::make(x.type(), p.intrin, matches, Call::PureExtern);
    const bool interleave_result = (p.flags & Pattern::InterleaveResult) != 0;
    return interleave_result ? native_interleave(call) : call;
}

// True if x's lane count is a whole multiple of two native vectors of its
// type on the given target.
bool is_double_vector(const Expr &x, const Target &target) {
    const int double_native_lanes = 2 * target.natural_vector_size(x.type());
    const int lanes = x.type().lanes();
    return lanes % double_native_lanes == 0;
}

// Attempt to apply one of the patterns to x, trying them in order. The first
// pattern that is supported on the target, matches x, passes
// process_match_flags and (if it interleaves its result) applies to a
// double-vector x wins. Its matched operands are each mutated with
// op_mutator, and x is replaced by the pattern's intrinsic call.
// Returns x unchanged when no pattern applies.
Expr apply_patterns(Expr x, const vector<Pattern> &patterns, const Target &target, IRMutator *op_mutator) {
    constexpr int debug_level = 3;
    debug(debug_level) << "apply_patterns " << x << "\n";
    vector<Expr> matches;
    for (const Pattern &p : patterns) {
        const bool matched = check_pattern_target(p.flags, target) &&
                             expr_match(p.pattern, x, matches);
        if (!matched) {
            continue;
        }

        debug(debug_level) << "matched " << p.pattern << "\n";
        debug(debug_level) << "matches:\n";
        for (const Expr &m : matches) {
            debug(debug_level) << m << "\n";
        }

        if (!process_match_flags(matches, p.flags)) {
            continue;
        }

        // Interleaving patterns are only used on a multiple of two vectors.
        // See https://github.com/halide/Halide/issues/1582
        const bool interleaves = (p.flags & Pattern::InterleaveResult) != 0;
        if (interleaves && !is_double_vector(x, target)) {
            continue;
        }

        for (Expr &operand : matches) {
            operand = op_mutator->mutate(operand);
        }

        Expr rewritten = replace_pattern(x, matches, p);
        debug(debug_level) << "rewrote to: " << rewritten << "\n";
        return rewritten;
    }
    return x;
}

// apply_patterns for a commutative binary node: try the patterns on op as
// written, then on the commuted node T::make(op->b, op->a). Returns the first
// rewrite produced, or op itself if neither form matched.
template<typename T>
Expr apply_commutative_patterns(const T *op, const vector<Pattern> &patterns, const Target &target, IRMutator *mutator) {
    const Expr as_written = op;
    Expr result = apply_patterns(as_written, patterns, target, mutator);
    if (!result.same_as(as_written)) {
        return result;
    }

    // Nothing matched; retry with the operands swapped.
    const Expr swapped = T::make(op->b, op->a);
    result = apply_patterns(swapped, patterns, target, mutator);
    return result.same_as(swapped) ? as_written : result;
}

// One multiply term: (first operand, second operand).
using MulExpr = pair<Expr, Expr>;

// If ty is scalar or a vector with different lanes count,
// and x is a vector, try to remove a broadcast or adjust
// the number of lanes in Broadcast or indices in a Shuffle
// to match the ty lanes before using lossless_cast on it.
// Returns an undefined Expr if the lane counts still differ afterwards, or
// whatever lossless_cast returns otherwise (undefined when the cast would
// lose information).
Expr unbroadcast_lossless_cast(Type ty, Expr x) {
    if (x.type().is_vector()) {
        if (const Broadcast *bc = x.as<Broadcast>()) {
            if (ty.is_scalar()) {
                x = bc->value;
            } else {
                x = Broadcast::make(bc->value, ty.lanes());
            }
        }
        // Check if shuffle can be treated as a broadcast.
        if (const Shuffle *shuff = x.as<Shuffle>()) {
            // factor is how many ty-sized groups of lanes fit in x. It is 0
            // when ty has more lanes than x. In that case the modulus below
            // would divide by zero and the index slice would run past the
            // end of shuff->indices, so the rewrite is skipped.
            int factor = x.type().lanes() / ty.lanes();
            if (factor > 0 && shuff->is_broadcast() && shuff->broadcast_factor() % factor == 0) {
                x = Shuffle::make(shuff->vectors, std::vector<int>(shuff->indices.begin(),
                                                                   shuff->indices.begin() + ty.lanes()));
            }
        }
    }
    if (ty.lanes() != x.type().lanes()) {
        return Expr();
    }
    return lossless_cast(ty, x);
}

// Try to extract a list of multiplies of the form a_ty*b_ty added
// together, such that op is equivalent to the sum of the
// multiplies in 'mpys', added to 'rest'.
// Difference in mpys.size() - return indicates the number of
// expressions where we pretend the op to be multiplied by 1.
// Once mpys holds max_mpy_count entries, any further terms are added to
// 'rest' unchanged. Returns the number of genuine multiplies appended.
int find_mpy_ops(const Expr &op, Type a_ty, Type b_ty, int max_mpy_count,
                 vector<MulExpr> &mpys, Expr &rest) {
    if ((int)mpys.size() >= max_mpy_count) {
        rest = rest.defined() ? Add::make(rest, op) : op;
        return 0;
    }

    // If the add is also widening, remove the cast.
    int mpy_bits = std::max(a_ty.bits(), b_ty.bits()) * 2;
    Expr maybe_mul = op;
    if (op.type().bits() == mpy_bits * 2) {
        if (const Cast *cast = op.as<Cast>()) {
            if (cast->value.type().bits() == mpy_bits) {
                maybe_mul = cast->value;
            }
        }
    }
    maybe_mul = as_mul(maybe_mul);

    // as_mul() simplifies the product it builds, and simplification can fold
    // it into something that is no longer a Mul (e.g. a multiply by one).
    // Only take the multiply path when there really is a Mul node to pick
    // apart; otherwise fall through to the cases below.
    const Mul *mul = maybe_mul.defined() ? maybe_mul.as<Mul>() : nullptr;

    if (mul) {
        Expr a = unbroadcast_lossless_cast(a_ty, mul->a);
        Expr b = unbroadcast_lossless_cast(b_ty, mul->b);
        if (a.defined() && b.defined()) {
            mpys.emplace_back(a, b);
            return 1;
        } else {
            // Try to commute the op.
            a = unbroadcast_lossless_cast(a_ty, mul->b);
            b = unbroadcast_lossless_cast(b_ty, mul->a);
            if (a.defined() && b.defined()) {
                mpys.emplace_back(a, b);
                return 1;
            }
        }
    } else if (const Add *add = op.as<Add>()) {
        int mpy_count = 0;
        mpy_count += find_mpy_ops(add->a, a_ty, b_ty, max_mpy_count, mpys, rest);
        mpy_count += find_mpy_ops(add->b, a_ty, b_ty, max_mpy_count, mpys, rest);
        return mpy_count;
    } else if (const Call *add = Call::as_intrinsic(op, {Call::widening_add})) {
        int mpy_count = 0;
        mpy_count += find_mpy_ops(cast(op.type(), add->args[0]), a_ty, b_ty, max_mpy_count, mpys, rest);
        mpy_count += find_mpy_ops(cast(op.type(), add->args[1]), a_ty, b_ty, max_mpy_count, mpys, rest);
        return mpy_count;
    } else if (const Call *wadd = Call::as_intrinsic(op, {Call::widen_right_add})) {
        int mpy_count = 0;
        mpy_count += find_mpy_ops(wadd->args[0], a_ty, b_ty, max_mpy_count, mpys, rest);
        mpy_count += find_mpy_ops(cast(op.type(), wadd->args[1]), a_ty, b_ty, max_mpy_count, mpys, rest);
        return mpy_count;
    }

    // Attempt to pretend this op is multiplied by 1.
    Expr as_a = unbroadcast_lossless_cast(a_ty, op);
    Expr as_b = unbroadcast_lossless_cast(b_ty, op);

    if (as_a.defined()) {
        mpys.emplace_back(as_a, make_one(b_ty));
    } else if (as_b.defined()) {
        mpys.emplace_back(make_one(a_ty), as_b);
    } else {
        rest = rest.defined() ? Add::make(rest, op) : op;
    }
    return 0;
}

// Perform peephole optimizations on the IR, adding appropriate
// interleave and deinterleave calls.
class OptimizePatterns : public IRMutator {
    using IRMutator::visit;

    Scope<Interval> bounds;
    const Target &target;

    // Interesting muls are handled as widen_right_mul().

    // Reorder mpys by their first operands so the interleave later built from
    // them has a better chance of simplifying away. Sorting only happens when
    // every mpy.first is a slice Shuffle (ordered by slice start), or every
    // mpy.first is a Load with a Ramp index from the same buffer (ordered by
    // provably-smaller ramp base). Otherwise mpys is left untouched.
    // Assumes mpys is non-empty.
    // NOTE(review): ordering by can_prove(base < base) may not be a strict
    // weak ordering when provability is not transitive. Confirm this is
    // acceptable for std::stable_sort.
    static void sort_mpy_exprs(vector<MulExpr> &mpys) {
        const Expr &lead = mpys[0].first;

        if (lead.as<Shuffle>()) {
            for (const MulExpr &entry : mpys) {
                const Shuffle *s = entry.first.as<Shuffle>();
                if (s == nullptr || !s->is_slice()) {
                    return;
                }
            }
            auto by_slice_start = [](const MulExpr &lhs, const MulExpr &rhs) {
                return lhs.first.as<Shuffle>()->slice_begin() < rhs.first.as<Shuffle>()->slice_begin();
            };
            std::stable_sort(mpys.begin(), mpys.end(), by_slice_start);
            return;
        }

        const Load *lead_load = lead.as<Load>();
        if (lead_load == nullptr || lead_load->index.as<Ramp>() == nullptr) {
            return;
        }
        for (const MulExpr &entry : mpys) {
            const Load *l = entry.first.as<Load>();
            const bool sortable = l != nullptr &&
                                  l->name == lead_load->name &&
                                  l->index.as<Ramp>() != nullptr;
            if (!sortable) {
                return;
            }
        }
        auto by_ramp_base = [](const MulExpr &lhs, const MulExpr &rhs) -> bool {
            if (!lhs.first.defined() || !rhs.first.defined()) {
                return false;
            }
            const Load *lhs_load = lhs.first.as<Load>();
            const Load *rhs_load = rhs.first.as<Load>();
            internal_assert(lhs_load && rhs_load);
            const Ramp *lhs_ramp = lhs_load->index.as<Ramp>();
            const Ramp *rhs_ramp = rhs_load->index.as<Ramp>();
            internal_assert(lhs_ramp && rhs_ramp);
            return can_prove(lhs_ramp->base < rhs_ramp->base);
        };
        std::stable_sort(mpys.begin(), mpys.end(), by_ramp_base);
    }

    // Look for multiply-add opportunities in an Add expression and rewrite
    // them using the horizontally reducing multiply helpers:
    //   - halide_hexagon_add_4mpy (vrmpy): four 8-bit products summed,
    //     tried first because it consumes the most operands;
    //   - halide_hexagon_add_2mpy (vdmpy / vmpa): two products summed.
    // Only 16- and 32-bit vector adds are considered. This is factored out of
    // visit(const Add*) to enable looking in widening_adds too.
    // op_add must be an Add node (internal_assert). Returns the mutated
    // rewritten expression, or an undefined Expr if nothing was found.
    Expr find_mpyadds(const Expr &op_add) {
        const Add *op = op_add.as<Add>();
        internal_assert(op);

        // vmpa, vdmpy, and vrmpy instructions are hard to match with
        // patterns, do it manually here.
        // Try to find vrmpy opportunities first, which consume 4 operands.
        if (op->type.is_vector() && (op->type.bits() == 16 || op->type.bits() == 32)) {
            int lanes = op->type.lanes();
            vector<MulExpr> mpys;
            Expr rest;
            string suffix;
            int mpy_count = 0;

            // Try to find a vector*scalar multiply first, which will
            // match a subset of the expressions that vector*vector
            // matches.
            // The vector operands must fit in u8. The scalar coefficients
            // must fit in u8 for unsigned results and in i8 otherwise.
            if (op->type.is_uint()) {
                mpy_count = find_mpy_ops(op, UInt(8, lanes), UInt(8), 4, mpys, rest);
                suffix = ".vub.ub";
            } else {
                mpy_count = find_mpy_ops(op, UInt(8, lanes), Int(8), 4, mpys, rest);
                suffix = ".vub.b";
            }

            // Require at least one genuine multiply (not only "times one"
            // terms) and exactly four terms in total.
            if (mpy_count > 0 && mpys.size() == 4) {
                // It's possible that permuting the order of the
                // multiply operands can simplify the shuffle away.
                // So, give yourself a fighting chance by ordering the
                // mpys in the ascending order of their start lanes (if all
                // are slice_vectors) or in the ascending order of their
                // load indices if all are loads from the same buffer.
                sort_mpy_exprs(mpys);
                Expr a0123 = Shuffle::make_interleave({mpys[0].first, mpys[1].first, mpys[2].first, mpys[3].first});
                a0123 = simplify(a0123);

                // We can generate this op for 16 bits, but, it's only
                // faster to do so if the interleave simplifies away.
                if (op->type.bits() == 32 || !a0123.as<Shuffle>()) {
                    Expr b0123 = Shuffle::make_interleave({mpys[0].second, mpys[1].second, mpys[2].second, mpys[3].second});
                    b0123 = simplify(b0123);
                    // Pack the four 8-bit scalar coefficients into a single
                    // 32-bit scalar.
                    b0123 = reinterpret(Type(b0123.type().code(), 32, 1), b0123);
                    Expr new_expr = halide_hexagon_add_4mpy(op->type.with_bits(32), suffix, a0123, b0123);
                    if (op->type.bits() == 16) {
                        // It's actually safe to use this op on 16 bit
                        // results, we just need to narrow the
                        // result. Overflow can occur, but will still
                        // produce the same result thanks to 2's
                        // complement arithmetic.
                        new_expr = Call::make(op->type, "halide.hexagon.pack.vw", {new_expr}, Call::PureExtern);
                    }
                    if (rest.defined()) {
                        new_expr = Add::make(new_expr, rest);
                    }
                    // Mutate the whole result so the leftover terms in rest
                    // are optimized as well.
                    return mutate(new_expr);
                }
            }

            // Now try to match vector*vector vrmpy expressions.
            mpys.clear();
            rest = Expr();
            if (op->type.is_uint()) {
                mpy_count = find_mpy_ops(op, UInt(8, lanes), UInt(8, lanes), 4, mpys, rest);
                suffix = ".vub.vub";
            } else {
                mpy_count = find_mpy_ops(op, Int(8, lanes), Int(8, lanes), 4, mpys, rest);
                suffix = ".vb.vb";
            }

            // TODO: suffix = ".vub.vb"
            if (mpy_count > 0 && mpys.size() == 4) {
                // It's possible that permuting the order of the
                // multiply operands can simplify the shuffle away.
                // So, give yourself a fighting chance by ordering the
                // mpys in the ascending order of their start lanes (if all
                // are slice_vectors) or in the ascending order of their
                // load indices if all are loads from the same buffer.
                sort_mpy_exprs(mpys);
                Expr a0123 = Shuffle::make_interleave({mpys[0].first, mpys[1].first, mpys[2].first, mpys[3].first});
                Expr b0123 = Shuffle::make_interleave({mpys[0].second, mpys[1].second, mpys[2].second, mpys[3].second});
                a0123 = simplify(a0123);
                b0123 = simplify(b0123);
                // We can generate this op for 16 bits, but, it's only
                // faster to do so if the interleave simplifies away.
                if (op->type.bits() == 32 || (!a0123.as<Shuffle>() && !b0123.as<Shuffle>())) {
                    Expr new_expr = halide_hexagon_add_4mpy(op->type.with_bits(32), suffix, a0123, b0123);
                    if (op->type.bits() == 16) {
                        // It's actually safe to use this op on 16 bit
                        // results, we just need to narrow the
                        // result. Overflow can occur, but will still
                        // produce the same result thanks to 2's
                        // complement arithmetic.
                        new_expr = Call::make(op->type, "halide.hexagon.pack.vw", {new_expr}, Call::PureExtern);
                    }
                    if (rest.defined()) {
                        new_expr = Add::make(new_expr, rest);
                    }
                    return mutate(new_expr);
                }
            }
        }

        // Find opportunities for vdmpy or vmpa, which consume 2 operands.
        if (op->type.is_vector() && (op->type.bits() == 16 || op->type.bits() == 32)) {
            int lanes = op->type.lanes();

            vector<MulExpr> mpys;
            Expr rest;
            string vmpa_suffix;
            string vdmpy_suffix;
            int mpy_count = 0;

            // Try to find vector*scalar multiplies: u8 vectors times i8
            // scalars for 16-bit results, i16 vectors times i8 scalars for
            // 32-bit results.
            if (op->type.bits() == 16) {
                mpy_count = find_mpy_ops(op, UInt(8, lanes), Int(8), 2, mpys, rest);
                vmpa_suffix = ".vub.vub.b.b";
                vdmpy_suffix = ".vub.b";
            } else if (op->type.bits() == 32) {
                mpy_count = find_mpy_ops(op, Int(16, lanes), Int(8), 2, mpys, rest);
                vmpa_suffix = ".vh.vh.b.b";
                vdmpy_suffix = ".vh.b";
            }
            if (mpy_count > 0 && mpys.size() == 2) {
                // It's possible that permuting the order of the
                // multiply operands can simplify the shuffle away.
                // So, give yourself a fighting chance by ordering the
                // mpys in the ascending order of their start lanes (if all
                // are slice_vectors) or in the ascending order of their
                // load indices if all are loads from the same buffer.
                sort_mpy_exprs(mpys);
                Expr a01 = Shuffle::make_interleave({mpys[0].first, mpys[1].first});
                a01 = simplify(a01);
                // TODO: This requires the operands to be in a
                // particular order. It should be more robust... but
                // this is pretty tough to do, other than simply
                // trying all permutations.
                Expr new_expr;
                // If the operand interleave simplified away, use the
                // two-operand (vdmpy) form on the interleaved vector.
                // Otherwise use the four-operand (vmpa) form, which takes the
                // two vectors separately and interleaves its result.
                // NOTE(review): vmpa_suffix is always non-empty here (both
                // branches above set it), so that half of the test never fires.
                if (!a01.as<Shuffle>() || vmpa_suffix.empty()) {
                    // Repeat the two 8-bit coefficients to fill all four bytes
                    // of a 32-bit scalar.
                    Expr b01 = Shuffle::make_interleave({mpys[0].second, mpys[1].second, mpys[0].second, mpys[1].second});
                    b01 = simplify(b01);
                    b01 = reinterpret(Type(b01.type().code(), 32, 1), b01);
                    new_expr = halide_hexagon_add_2mpy(op->type, vdmpy_suffix, a01, b01);
                } else {
                    new_expr = halide_hexagon_add_2mpy(op->type, vmpa_suffix, mpys[0].first, mpys[1].first, mpys[0].second, mpys[1].second);
                }
                if (rest.defined()) {
                    new_expr = Add::make(new_expr, rest);
                }
                return mutate(new_expr);
            }
        }
        return Expr();
    }

    // Rewrite vector adds into Hexagon multiply-add / shift-add intrinsics.
    // First tries the reducing multiply-adds found by find_mpyadds. Failing
    // that, for vectors, tries the pattern table below against both a + b and
    // b + a (first match wins, so more specific patterns come first). Falls
    // back to the default mutation if nothing applies.
    Expr visit(const Add *op) override {
        Expr mpyadd = find_mpyadds(op);
        if (mpyadd.defined()) {
            return mpyadd;
        }
        static const vector<Pattern> adds = {
            // Use accumulating versions of vmpa, vdmpy, vrmpy instructions when possible.
            {"halide.hexagon.acc_add_2mpy.vh.vub.vub.b.b", wild_i16x + halide_hexagon_add_2mpy(Int(16, 0), ".vub.vub.b.b", wild_u8x, wild_u8x, wild_i8, wild_i8), Pattern::ReinterleaveOp0},
            {"halide.hexagon.acc_add_2mpy.vw.vh.vh.b.b", wild_i32x + halide_hexagon_add_2mpy(Int(32, 0), ".vh.vh.b.b", wild_i16x, wild_i16x, wild_i8, wild_i8), Pattern::ReinterleaveOp0},
            {"halide.hexagon.acc_add_2mpy.vh.vub.b", wild_i16x + halide_hexagon_add_2mpy(Int(16, 0), ".vub.b", wild_u8x, wild_i32)},
            {"halide.hexagon.acc_add_2mpy.vw.vh.b", wild_i32x + halide_hexagon_add_2mpy(Int(32, 0), ".vh.b", wild_i16x, wild_i32)},
            {"halide.hexagon.acc_add_3mpy.vh.vub.b", wild_i16x + native_interleave(halide_hexagon_add_3mpy(Int(16, 0), ".vub.b", wild_u8x, wild_i32)), Pattern::ReinterleaveOp0},
            {"halide.hexagon.acc_add_3mpy.vh.vb.b", wild_i16x + native_interleave(halide_hexagon_add_3mpy(Int(16, 0), ".vb.b", wild_i8x, wild_i32)), Pattern::ReinterleaveOp0},
            {"halide.hexagon.acc_add_3mpy.vw.vh.b", wild_i32x + native_interleave(halide_hexagon_add_3mpy(Int(32, 0), ".vh.b", wild_i16x, wild_i32)), Pattern::ReinterleaveOp0},
            {"halide.hexagon.acc_add_4mpy.vw.vub.b", wild_i32x + halide_hexagon_add_4mpy(Int(32, 0), ".vub.b", wild_u8x, wild_i32)},
            {"halide.hexagon.acc_add_4mpy.vuw.vub.ub", wild_u32x + halide_hexagon_add_4mpy(UInt(32, 0), ".vub.ub", wild_u8x, wild_u32)},
            {"halide.hexagon.acc_add_4mpy.vuw.vub.ub", wild_i32x + halide_hexagon_add_4mpy(Int(32, 0), ".vub.ub", wild_u8x, wild_u32)},
            {"halide.hexagon.acc_add_4mpy.vuw.vub.vub", wild_u32x + halide_hexagon_add_4mpy(UInt(32, 0), ".vub.vub", wild_u8x, wild_u8x)},
            {"halide.hexagon.acc_add_4mpy.vuw.vub.vub", wild_i32x + halide_hexagon_add_4mpy(Int(32, 0), ".vub.vub", wild_u8x, wild_u8x)},
            {"halide.hexagon.acc_add_4mpy.vw.vub.vb", wild_i32x + halide_hexagon_add_4mpy(Int(32, 0), ".vub.vb", wild_u8x, wild_i8x)},
            {"halide.hexagon.acc_add_4mpy.vw.vb.vb", wild_i32x + halide_hexagon_add_4mpy(Int(32, 0), ".vb.vb", wild_i8x, wild_i8x)},

            // Widening multiply-accumulates with a scalar.
            {"halide.hexagon.add_mpy.vuh.vub.ub", wild_u16x + widening_mul(wild_u8x, wild_u8), Pattern::ReinterleaveOp0},
            {"halide.hexagon.add_mpy.vh.vub.b", wild_i16x + widening_mul(wild_u8x, wild_i8), Pattern::ReinterleaveOp0},
            {"halide.hexagon.add_mpy.vuw.vuh.uh", wild_u32x + widening_mul(wild_u16x, wild_u16), Pattern::ReinterleaveOp0},

            // These patterns aren't exactly right because the instruction
            // saturates the result. However, this is really the instruction
            // that we want to use in most cases, and we can exploit the fact
            // that 32 bit signed arithmetic overflow is undefined to argue
            // that these patterns are not completely incorrect.
            {"halide.hexagon.satw_add_mpy.vw.vh.h", wild_i32x + widening_mul(wild_i16x, wild_i16), Pattern::ReinterleaveOp0},

            // Widening multiply-accumulates.
            {"halide.hexagon.add_mpy.vuh.vub.vub", wild_u16x + widening_mul(wild_u8x, wild_u8x), Pattern::ReinterleaveOp0},
            {"halide.hexagon.add_mpy.vuw.vuh.vuh", wild_u32x + widening_mul(wild_u16x, wild_u16x), Pattern::ReinterleaveOp0},
            {"halide.hexagon.add_mpy.vh.vb.vb", wild_i16x + widening_mul(wild_i8x, wild_i8x), Pattern::ReinterleaveOp0},
            {"halide.hexagon.add_mpy.vw.vh.vh", wild_i32x + widening_mul(wild_i16x, wild_i16x), Pattern::ReinterleaveOp0},

            {"halide.hexagon.add_mpy.vh.vub.vb", wild_i16x + widening_mul(wild_u8x, wild_i8x), Pattern::ReinterleaveOp0},
            {"halide.hexagon.add_mpy.vw.vh.vuh", wild_i32x + widening_mul(wild_i16x, wild_u16x), Pattern::ReinterleaveOp0},
            {"halide.hexagon.add_mpy.vh.vub.vb", wild_i16x + widening_mul(wild_i8x, wild_u8x), Pattern::ReinterleaveOp0 | Pattern::SwapOps12},
            {"halide.hexagon.add_mpy.vw.vh.vuh", wild_i32x + widening_mul(wild_u16x, wild_i16x), Pattern::ReinterleaveOp0 | Pattern::SwapOps12},

            // Shift-accumulates.
            {"halide.hexagon.add_shr.vw.vw.uw", wild_i32x + (wild_i32x >> wild_u32)},
            {"halide.hexagon.add_shl.vw.vw.uw", wild_i32x + (wild_i32x << wild_u32)},
            {"halide.hexagon.add_shl.vw.vw.uw", wild_u32x + (wild_u32x << wild_u32)},
            {"halide.hexagon.add_shl.vh.vh.uh", wild_i16x + (wild_i16x << wild_u16), Pattern::v65orLater},
            {"halide.hexagon.add_shl.vh.vh.uh", wild_u16x + (wild_u16x << wild_u16), Pattern::v65orLater},
            {"halide.hexagon.add_shr.vh.vh.uh", wild_i16x + (wild_i16x >> wild_u16), Pattern::v65orLater},
            // NOTE(review): this maps a shift by a *signed* scalar onto the
            // .uh (unsigned shift amount) intrinsic. If a signed shift amount
            // can be negative (presumably meaning a right shift in Halide's
            // shift semantics), this rewrite is only correct when the amount is
            // known to be non-negative. Confirm.
            {"halide.hexagon.add_shl.vh.vh.uh", wild_i16x + (wild_i16x << wild_i16), Pattern::v65orLater},

            // Non-widening multiply-accumulates with a scalar.
            {"halide.hexagon.add_mul.vh.vh.b", wild_i16x + widen_right_mul(wild_i16x, wild_i8)},
            {"halide.hexagon.add_mul.vw.vw.h", wild_i32x + widen_right_mul(wild_i32x, wild_i16)},
            // TODO: There's also a add_mul.vw.vw.b

            // This pattern is very general, so it must come last.
            {"halide.hexagon.add_mul.vh.vh.vh", wild_i16x + wild_i16x * wild_i16x},
        };

        if (op->type.is_vector()) {
            Expr new_expr = apply_commutative_patterns(op, adds, target, this);
            if (!new_expr.same_as(op)) {
                return new_expr;
            }
        }
        return IRMutator::visit(op);
    }

    Expr visit(const Sub *op) override {
        // For vectors, a - b is retried as a + (-b) whenever -b can be formed
        // without losing information, so the add patterns get a chance to
        // match. Otherwise (scalars, or b not losslessly negatable) fall back
        // to the default mutation.
        if (!op->type.is_vector()) {
            return IRMutator::visit(op);
        }
        Expr negated_b = lossless_negate(op->b);
        if (!negated_b.defined()) {
            return IRMutator::visit(op);
        }
        return mutate(op->a + negated_b);
    }

    Expr visit(const Max *op) override {
        // Children are mutated first; pattern matching runs on the result.
        Expr mutated = IRMutator::visit(op);
        if (!op->type.is_vector()) {
            return mutated;
        }

        // max(clz(x), clz(~x)) maps onto the Hexagon "cls" intrinsic plus one.
        // Both count_leading_zeros operands must be the same expression, which
        // expr_match alone cannot express, so it is checked with equal().
        static const pair<string, Expr> cls_patterns[] = {
            {"halide.hexagon.cls.vh", max(count_leading_zeros(wild_i16x), count_leading_zeros(~wild_i16x))},
            {"halide.hexagon.cls.vw", max(count_leading_zeros(wild_i32x), count_leading_zeros(~wild_i32x))},
        };
        vector<Expr> matches;
        for (const auto &entry : cls_patterns) {
            const string &intrinsic = entry.first;
            const Expr &pattern = entry.second;
            const bool same_operand = expr_match(pattern, mutated, matches) && equal(matches[0], matches[1]);
            if (same_operand) {
                return Call::make(op->type, intrinsic, {matches[0]}, Call::PureExtern) + 1;
            }
        }
        return mutated;
    }

    Expr visit(const Cast *op) override {
        // Vector cast patterns, tried in order by apply_patterns.
        static const vector<Pattern> casts = {
            // Halving unsigned subtract.
            {"halide.hexagon.navg.vub.vub", i8(widening_sub(wild_u8x, wild_u8x) >> 1)},

            // Narrowing casts. These may interleave later with trunclo.
            {"halide.hexagon.packhi.vh", u8(wild_u16x >> 8)},
            {"halide.hexagon.packhi.vh", u8(wild_i16x >> 8)},
            {"halide.hexagon.packhi.vh", i8(wild_u16x >> 8)},
            {"halide.hexagon.packhi.vh", i8(wild_i16x >> 8)},
            {"halide.hexagon.packhi.vw", u16(wild_u32x >> 16)},
            {"halide.hexagon.packhi.vw", u16(wild_i32x >> 16)},
            {"halide.hexagon.packhi.vw", i16(wild_u32x >> 16)},
            {"halide.hexagon.packhi.vw", i16(wild_i32x >> 16)},

            // Narrowing with shifting.
            {"halide.hexagon.trunc_shr.vw.uw", i16(wild_i32x >> wild_u32), Pattern::DeinterleaveOp0},

            // Narrowing casts. These may interleave later with trunc.
            {"halide.hexagon.pack.vh", u8(wild_u16x)},
            {"halide.hexagon.pack.vh", u8(wild_i16x)},
            {"halide.hexagon.pack.vh", i8(wild_u16x)},
            {"halide.hexagon.pack.vh", i8(wild_i16x)},
            {"halide.hexagon.pack.vw", u16(wild_u32x)},
            {"halide.hexagon.pack.vw", u16(wild_i32x)},
            {"halide.hexagon.pack.vw", i16(wild_u32x)},
            {"halide.hexagon.pack.vw", i16(wild_i32x)},

            // Widening casts
            {"halide.hexagon.zxt.vub", u16(wild_u8x), Pattern::InterleaveResult},
            {"halide.hexagon.zxt.vub", i16(wild_u8x), Pattern::InterleaveResult},
            {"halide.hexagon.zxt.vuh", u32(wild_u16x), Pattern::InterleaveResult},
            {"halide.hexagon.zxt.vuh", i32(wild_u16x), Pattern::InterleaveResult},
            {"halide.hexagon.sxt.vb", u16(wild_i8x), Pattern::InterleaveResult},
            {"halide.hexagon.sxt.vb", i16(wild_i8x), Pattern::InterleaveResult},
            {"halide.hexagon.sxt.vh", u32(wild_i16x), Pattern::InterleaveResult},
            {"halide.hexagon.sxt.vh", i32(wild_i16x), Pattern::InterleaveResult},
        };

        // Casts that skip a width step (8 <-> 32 bits) are split into two
        // single-step casts so the patterns above can match each step. This
        // also keeps such vector casts away from LLVM's unoptimized shuffles.
        static const vector<pair<Expr, Expr>> cast_rewrites = {
            // Narrowing
            {u8(wild_u32x), u8(u16(wild_u32x))},
            {u8(wild_i32x), u8(i16(wild_i32x))},
            {i8(wild_u32x), i8(u16(wild_u32x))},
            {i8(wild_i32x), i8(i16(wild_i32x))},

            // Widening
            {u32(wild_u8x), u32(u16(wild_u8x))},
            {u32(wild_i8x), u32(i16(wild_i8x))},
            {i32(wild_u8x), i32(u16(wild_u8x))},
            {i32(wild_i8x), i32(i16(wild_i8x))},
        };

        // Scalar casts get the default treatment.
        if (!op->type.is_vector()) {
            return IRMutator::visit(op);
        }

        Expr cast_expr = op;
        Expr patterned = apply_patterns(cast_expr, casts, target, this);
        if (!patterned.same_as(cast_expr)) {
            return patterned;
        }

        // No direct pattern: try splitting a double cast, then re-mutate so the
        // resulting single-step casts are matched.
        vector<Expr> matches;
        for (const auto &rewrite : cast_rewrites) {
            if (!expr_match(rewrite.first, cast_expr, matches)) {
                continue;
            }
            Expr replacement = substitute("*", matches[0], with_lanes(rewrite.second, op->type.lanes()));
            debug(3) << "rewriting cast to: " << replacement << " from " << cast_expr << "\n";
            return mutate(replacement);
        }
        return IRMutator::visit(op);
    }

    // Instruction selection for Call nodes.
    // - Vector if_then_else calls whose condition is not a broadcast of a
    //   scalar are returned unchanged; none of their arguments are mutated.
    // - widening_add is first tried as a multiply-add via find_mpyadds.
    // - widen_right_add / widen_right_sub and the intrinsics listed in
    //   default_lower are lowered here and the lowered form is re-mutated.
    // - Vector calls are matched against `calls`, then against the
    //   saturating double-cast rewrites.
    // - lerp, vector integer div/mod_round_to_zero, mul_shift_right and
    //   rounding_mul_shift_right are lowered and re-mutated; all other calls
    //   get the default IRMutator treatment.
    Expr visit(const Call *op) override {
        // Leave vector selects with a non-scalar condition (or a broadcast of
        // a vector) untouched, including their arguments.
        if (op->is_intrinsic(Call::if_then_else) && op->args[0].type().is_vector()) {
            const Broadcast *b = op->args[0].as<Broadcast>();
            if (!b || b->value.type().is_vector()) {
                return op;
            }
        }
        // A widening add might be the tail of a multiply-add sequence; give
        // find_mpyadds a chance at the equivalent non-widening Add first.
        if (op->is_intrinsic(Call::widening_add)) {
            Expr mpyadds = find_mpyadds(Add::make(cast(op->type, op->args[0]), cast(op->type, op->args[1])));
            if (mpyadds.defined()) {
                return mpyadds;
            }
        }
        // TODO: There can be better instruction selection for these.
        if (op->is_intrinsic(Call::widen_right_add)) {
            Expr lowered = Add::make(op->args[0], cast(op->type, op->args[1]));
            return mutate(lowered);
        } else if (op->is_intrinsic(Call::widen_right_sub)) {
            Expr lowered = Sub::make(op->args[0], cast(op->type, op->args[1]));
            return mutate(lowered);
        }

        // These intrinsics should get the default lowering, and we need to recursively mutate the
        // result. We don't want to let these fall through to CodeGen_Hexagon and CodeGen_LLVM,
        // because they might generate interleaves or deinterleaves we can simplify.
        static const vector<Call::IntrinsicOp> default_lower = {
            // TODO: Maybe there are widening shift instructions on Hexagon?
            Call::widening_shift_left,
        };

        for (Call::IntrinsicOp i : default_lower) {
            if (op->is_intrinsic(i)) {
                return mutate(lower_intrinsic(op));
            }
        }

        // Patterns for vector calls, tried in order by apply_patterns.
        static const vector<Pattern> calls = {
            // Non-widening scalar multiplication.
            {"halide.hexagon.mul.vh.b", widen_right_mul(wild_i16x, wild_i8)},
            {"halide.hexagon.mul.vw.h", widen_right_mul(wild_i32x, wild_i16)},
            // TODO: There's also mul.vw.b. We currently generate mul.vw.h
            // instead. I'm not sure mul.vw.b is faster, it might even be
            // slower due to the extra step in broadcasting the scalar up to
            // 32 bits.

            // One operand widening multiplication.
            {"halide.hexagon.mul.vw.vh", widen_right_mul(wild_i32x, wild_i16x), Pattern::ReinterleaveOp0},
            {"halide.hexagon.mul.vw.vuh", widen_right_mul(wild_u32x, wild_u16x), Pattern::ReinterleaveOp0},

            // Saturating narrowing casts with rounding
            {"halide.hexagon.trunc_satub_rnd.vh", u8_sat(rounding_shift_right(wild_i16x, 8)), Pattern::DeinterleaveOp0},
            {"halide.hexagon.trunc_satb_rnd.vh", i8_sat(rounding_shift_right(wild_i16x, 8)), Pattern::DeinterleaveOp0},
            {"halide.hexagon.trunc_satub_rnd.vuh", u8_sat(rounding_shift_right(wild_u16x, 8)), Pattern::DeinterleaveOp0},
            {"halide.hexagon.trunc_satuh_rnd.vw", u16_sat(rounding_shift_right(wild_i32x, 16)), Pattern::DeinterleaveOp0},
            {"halide.hexagon.trunc_sath_rnd.vw", i16_sat(rounding_shift_right(wild_i32x, 16)), Pattern::DeinterleaveOp0},
            {"halide.hexagon.trunc_satuh_rnd.vuw", u16_sat(rounding_shift_right(wild_u32x, 16)), Pattern::DeinterleaveOp0},

            // Saturating narrowing casts with rounding, by a scalar shift amount
            {"halide.hexagon.trunc_satub_shr_rnd.vh", u8_sat(rounding_shift_right(wild_i16x, wild_u16)), Pattern::DeinterleaveOp0},
            {"halide.hexagon.trunc_satb_shr_rnd.vh", i8_sat(rounding_shift_right(wild_i16x, wild_u16)), Pattern::DeinterleaveOp0},
            {"halide.hexagon.trunc_satub_shr_rnd.vuh", u8_sat(rounding_shift_right(wild_u16x, wild_u16)), Pattern::DeinterleaveOp0 | Pattern::v65orLater},
            {"halide.hexagon.trunc_satuh_shr_rnd.vw", u16_sat(rounding_shift_right(wild_i32x, wild_u32)), Pattern::DeinterleaveOp0},
            {"halide.hexagon.trunc_sath_shr_rnd.vw", i16_sat(rounding_shift_right(wild_i32x, wild_u32)), Pattern::DeinterleaveOp0},
            {"halide.hexagon.trunc_satuh_shr_rnd.vuw", u16_sat(rounding_shift_right(wild_u32x, wild_u32)), Pattern::DeinterleaveOp0},

            // Saturating narrowing casts
            {"halide.hexagon.trunc_satub_shr.vh.uh", u8_sat(wild_i16x >> wild_u16), Pattern::DeinterleaveOp0},
            {"halide.hexagon.trunc_satuh_shr.vw.uw", u16_sat(wild_i32x >> wild_u32), Pattern::DeinterleaveOp0},
            {"halide.hexagon.trunc_sath_shr.vw.uw", i16_sat(wild_i32x >> wild_u32), Pattern::DeinterleaveOp0},

            // For some of the following narrowing casts, we have the choice of
            // non-interleaving or interleaving instructions. Because we don't
            // know which one we prefer during pattern matching, we match the
            // non-interleaving versions for now and replace them with the
            // instructions that interleave later if it makes sense.

            // Saturating narrowing casts. These may interleave later with trunc_sat.
            {"halide.hexagon.pack_satub.vh", u8_sat(wild_i16x)},
            {"halide.hexagon.pack_satuh.vw", u16_sat(wild_i32x)},
            {"halide.hexagon.pack_satb.vh", i8_sat(wild_i16x)},
            {"halide.hexagon.pack_sath.vw", i16_sat(wild_i32x)},

            // We don't have a vpack equivalent to this one, so we match it directly.
            {"halide.hexagon.trunc_satuh.vuw", u16_sat(wild_u32x), Pattern::DeinterleaveOp0},

            // Multiply keep high half.
            {"halide.hexagon.trunc_mpy.vw.vw", mul_shift_right(wild_i32x, wild_i32x, 32)},

            // Scalar multiply keep high half, with multiplication by 2.
            {"halide.hexagon.trunc_satw_mpy2.vh.h", mul_shift_right(wild_i16x, wild_i16, 15)},
            {"halide.hexagon.trunc_satw_mpy2.vh.h", mul_shift_right(wild_i16, wild_i16x, 15), Pattern::SwapOps01},
            {"halide.hexagon.trunc_satdw_mpy2.vw.vw", mul_shift_right(wild_i32x, wild_i32x, 31)},

            // Scalar and vector multiply keep high half, with multiplication by 2, and rounding.
            {"halide.hexagon.trunc_satw_mpy2_rnd.vh.h", rounding_mul_shift_right(wild_i16x, wild_i16, 15)},
            {"halide.hexagon.trunc_satw_mpy2_rnd.vh.h", rounding_mul_shift_right(wild_i16, wild_i16x, 15), Pattern::SwapOps01},
            {"halide.hexagon.trunc_satw_mpy2_rnd.vh.vh", rounding_mul_shift_right(wild_i16x, wild_i16x, 15)},
            {"halide.hexagon.trunc_satdw_mpy2_rnd.vw.vw", rounding_mul_shift_right(wild_i32x, wild_i32x, 31)},

            // Vector by scalar widening multiplies. These need to happen before the ones below, to avoid
            // using vector versions when scalar versions would suffice.
            {"halide.hexagon.mpy.vub.ub", widening_mul(wild_u8x, wild_u8), Pattern::InterleaveResult},
            {"halide.hexagon.mpy.vub.b", widening_mul(wild_u8x, wild_i8), Pattern::InterleaveResult},
            {"halide.hexagon.mpy.vuh.uh", widening_mul(wild_u16x, wild_u16), Pattern::InterleaveResult},
            {"halide.hexagon.mpy.vh.h", widening_mul(wild_i16x, wild_i16), Pattern::InterleaveResult},

            // These are calls that are almost trivial, but they differ due to interleaving.
            {"halide.hexagon.add_vuh.vub.vub", widening_add(wild_u8x, wild_u8x), Pattern::InterleaveResult},
            {"halide.hexagon.add_vuw.vuh.vuh", widening_add(wild_u16x, wild_u16x), Pattern::InterleaveResult},
            {"halide.hexagon.add_vw.vh.vh", widening_add(wild_i16x, wild_i16x), Pattern::InterleaveResult},
            {"halide.hexagon.sub_vh.vub.vub", widening_sub(wild_u8x, wild_u8x), Pattern::InterleaveResult},
            {"halide.hexagon.sub_vw.vuh.vuh", widening_sub(wild_u16x, wild_u16x), Pattern::InterleaveResult},
            {"halide.hexagon.sub_vw.vh.vh", widening_sub(wild_i16x, wild_i16x), Pattern::InterleaveResult},
            {"halide.hexagon.mpy.vub.vub", widening_mul(wild_u8x, wild_u8x), Pattern::InterleaveResult},
            {"halide.hexagon.mpy.vub.vb", widening_mul(wild_u8x, wild_i8x), Pattern::InterleaveResult},
            {"halide.hexagon.mpy.vub.vb", widening_mul(wild_i8x, wild_u8x), Pattern::InterleaveResult | Pattern::SwapOps01},
            {"halide.hexagon.mpy.vb.vb", widening_mul(wild_i8x, wild_i8x), Pattern::InterleaveResult},
            {"halide.hexagon.mpy.vuh.vuh", widening_mul(wild_u16x, wild_u16x), Pattern::InterleaveResult},
            {"halide.hexagon.mpy.vh.vh", widening_mul(wild_i16x, wild_i16x), Pattern::InterleaveResult},
            {"halide.hexagon.mpy.vh.vuh", widening_mul(wild_i16x, wild_u16x), Pattern::InterleaveResult},
            {"halide.hexagon.mpy.vh.vuh", widening_mul(wild_u16x, wild_i16x), Pattern::InterleaveResult | Pattern::SwapOps01},
        };

        // To hit more of the patterns we want, rewrite "double casts"
        // as two stage casts. This also avoids letting vector casts
        // fall through to LLVM, which will generate large unoptimized
        // shuffles.
        static const vector<pair<Expr, Expr>> cast_rewrites = {
            // Saturating narrowing
            {u8_sat(wild_u32x), u8_sat(u16_sat(wild_u32x))},
            {u8_sat(wild_i32x), u8_sat(i16_sat(wild_i32x))},
            {i8_sat(wild_u32x), i8_sat(u16_sat(wild_u32x))},
            {i8_sat(wild_i32x), i8_sat(i16_sat(wild_i32x))},
        };

        if (op->type.is_vector()) {
            Expr new_expr = apply_patterns(op, calls, target, this);
            if (!new_expr.same_as(op)) {
                return new_expr;
            }

            // If we didn't find a pattern, try using one of the
            // rewrites above.
            vector<Expr> matches;
            for (const auto &i : cast_rewrites) {
                if (expr_match(i.first, op, matches)) {
                    Expr replacement = substitute("*", matches[0], with_lanes(i.second, op->type.lanes()));
                    debug(3) << "rewriting cast to: " << replacement << " from " << Expr(op) << "\n";
                    return mutate(replacement);
                }
            }
        }

        if (op->is_intrinsic(Call::lerp)) {
            // We need to lower lerps now to optimize the arithmetic
            // that they generate.
            internal_assert(op->args.size() == 3);
            return mutate(lower_lerp(op->type, op->args[0], op->args[1], op->args[2], target));
        } else if ((op->is_intrinsic(Call::div_round_to_zero) ||
                    op->is_intrinsic(Call::mod_round_to_zero)) &&
                   !op->type.is_float() && op->type.is_vector()) {
            internal_assert(op->args.size() == 2);
            Expr a = op->args[0];
            Expr b = op->args[1];
            // Run bounds analysis to estimate the range of result.
            // For signed types the magnitude |a / b| is bounded so that
            // negative quotients are covered too. `bounds` holds the let
            // bounds collected by visit_let.
            Expr abs_result = op->type.is_int() ? abs(a / b) : a / b;
            Expr extent_upper = find_constant_bound(abs_result, Direction::Upper, bounds);
            // NOTE(review): presumably null when no constant bound was found;
            // long_div_mod_round_to_zero must then handle an unknown range.
            const uint64_t *upper_bound = as_const_uint(extent_upper);
            a = mutate(a);
            b = mutate(b);
            // Both quotient and remainder are computed; return the one requested.
            std::pair<Expr, Expr> div_mod = long_div_mod_round_to_zero(a, b, upper_bound);
            if (op->is_intrinsic(Call::div_round_to_zero)) {
                return div_mod.first;
            }
            return div_mod.second;
        } else if (op->is_intrinsic(Call::mul_shift_right) ||
                   op->is_intrinsic(Call::rounding_mul_shift_right)) {
            // Lower these now, we might be able to use other patterns on the result.
            return mutate(lower_intrinsic(op));
        } else {
            return IRMutator::visit(op);
        }
    }

    // Shared implementation for Let and LetStmt. While the node is mutated,
    // `bounds` maps the let's name to the bounds of its value, so bounds
    // queries on the body (e.g. find_constant_bound for div/mod lowering)
    // can use them. The entry is popped again afterwards.
    template<typename NodeType, typename T>
    NodeType visit_let(const T *op) {
        const auto &name = op->name;
        bounds.push(name, bounds_of_expr_in_scope(op->value, bounds));
        NodeType mutated = IRMutator::visit(op);
        bounds.pop(name);
        return mutated;
    }

    Expr visit(const Let *op) override {
        // Let expressions: track the value's bounds while mutating.
        return visit_let<Expr, Let>(op);
    }

    Stmt visit(const LetStmt *op) override {
        // Let statements: track the value's bounds while mutating.
        return visit_let<Stmt, LetStmt>(op);
    }

    Expr visit(const Div *op) override {
        // Integer vector division is expanded with lower_int_uint_div, and the
        // expansion is mutated again so this pass's patterns can apply to the
        // resulting arithmetic. Float or scalar divisions are left to the
        // default mutation.
        const bool integer_vector = op->type.is_vector() && !op->type.is_float();
        if (!integer_vector) {
            return IRMutator::visit(op);
        }
        return mutate(lower_int_uint_div(op->a, op->b));
    }

    Expr visit(const Mod *op) override {
        // Integer vector modulus is expanded with lower_int_uint_mod, and the
        // expansion is mutated again so this pass's patterns can apply to the
        // resulting arithmetic. Float or scalar moduli are left to the default
        // mutation.
        const bool integer_vector = op->type.is_vector() && !op->type.is_float();
        if (!integer_vector) {
            return IRMutator::visit(op);
        }
        return mutate(lower_int_uint_mod(op->a, op->b));
    }

public:
    // The target decides which patterns are eligible (e.g. the
    // Pattern::v65orLater entries) and is passed to lower_lerp.
    OptimizePatterns(const Target &t) : target(t) {}
};

// Rewrites integer VectorReduce::Add nodes into Hexagon multiply-reduce
// intrinsics (the vrmpy, vtmpy and vdmpy families listed in the signature
// table below). A rewrite happens only when the reduction factor and the
// operand types match a known instruction signature. Otherwise the node
// gets the default IRMutator treatment.
class VectorReducePatterns : public IRMutator {
    using IRMutator::visit;

    // Check for interleaves of vectors with stride 1 like shuffle with indices:
    // 0, 1, 2,..., window_size - 1,
    // 1, 2, 3,..., window_size,
    // 2, 3, 4,..., window_size + 1,
    // .....
    // window_size != lanes
    // Returns true iff `op` is a Shuffle in which, for every row after the
    // first, element k equals element k + 1 of the previous row. The last
    // element of each row is unconstrained. Non-Shuffle expressions yield false.
    // TODO: There could be other patterns as well which we should match
    static bool is_stencil_interleave(const Expr &op, int window_size) {
        int lanes = op.type().lanes();
        internal_assert(lanes > window_size);
        if (const Shuffle *shuff = op.as<Shuffle>()) {
            for (int i = window_size; i < lanes; i++) {
                // i - window_size + 1 is the next element of the previous row.
                // The last element of a row has no such successor, so skip it.
                if ((i % window_size != window_size - 1) &&
                    (shuff->indices[i - window_size + 1] != shuff->indices[i])) {
                    return false;
                }
            }
            return true;
        }
        return false;
    }

    Expr visit(const Call *op) override {
        // Leave vector selects with a non-scalar condition (or a broadcast of
        // a vector) untouched, including their arguments.
        if (op->is_intrinsic(Call::if_then_else) && op->args[0].type().is_vector()) {
            const Broadcast *b = op->args[0].as<Broadcast>();
            if (!b || b->value.type().is_vector()) {
                return op;
            }
        }
        return IRMutator::visit(op);
    }

    // Matches op->value against the signature table. The first signature
    // whose factor, pattern and flag constraints are all satisfied is
    // emitted as a Hexagon intrinsic call. The result is cast back to op->type
    // if the intrinsic's native result width differs.
    Expr visit(const VectorReduce *op) override {
        // Only integer vector-to-vector add reductions are handled here.
        if (!op->type.is_vector() || op->type.is_float() || op->op != VectorReduce::Add) {
            return IRMutator::visit(op);
        }

        struct Signature {
            enum Flags {
                // `a` must be a stencil (sliding window) shuffle; see is_stencil_interleave.
                SlidingWindow = 1,
                // `b` must reduce to a single group of `factor` lanes (a scalar operand).
                ScalarB = 1 << 1,
                // `b` must be losslessly narrowable to half its bit width.
                NarrowB = 1 << 2,
                SwapOps = 1 << 3,  // Swapping ops is done before matching B to scalars.
            };
            // Number of input lanes summed into each output lane.
            int factor;
            // Bit width of the intrinsic's result lanes.
            int native_return_bits;
            Expr pattern;
            int flags;
        };

        int in_lanes = op->value.type().lanes();
        int out_lanes = op->type.lanes();
        int factor = in_lanes / out_lanes;

        // Map of instruction signatures
        // clang-format off
        static const vector<Signature> sigs = {
            // --------- vrmpy ---------
            // Sliding window
            {4, 32, widening_mul(wild_u8x, wild_u8x), Signature::SlidingWindow | Signature::ScalarB},
            {4, 32, widening_mul(wild_u8x, wild_i8x), Signature::SlidingWindow | Signature::ScalarB},
            {4, 32, widening_mul(wild_i8x, wild_u8x), Signature::SlidingWindow | Signature::ScalarB | Signature::SwapOps},
            // Vector * Scalar
            {4, 32, widening_mul(wild_u8x, wild_u8x), Signature::ScalarB},
            {4, 32, widening_mul(wild_i8x, wild_u8x), Signature::ScalarB},
            {4, 32, widening_mul(wild_u8x, wild_i8x), Signature::ScalarB},
            {4, 32, widening_mul(wild_i8x, wild_u8x), Signature::ScalarB | Signature::SwapOps},
            // Vector * Vector
            {4, 32, widening_mul(wild_u8x, wild_u8x)},
            {4, 32, widening_mul(wild_u8x, wild_i8x)},
            {4, 32, widening_mul(wild_i8x, wild_u8x), Signature::SwapOps},
            {4, 32, widening_mul(wild_i8x, wild_i8x)},
            // Sum
            {4, 32, wild_u8x, Signature::SlidingWindow},
            {4, 32, wild_i8x, Signature::SlidingWindow},
            {4, 32, wild_u8x},
            {4, 32, wild_i8x},

            // --------- vtmpy ---------
            // Vtmpy has additional requirement that third coefficient b[2]
            // needs to be 1.
            // Sliding window
            {3, 16, widening_mul(wild_i8x, wild_i8x), Signature::SlidingWindow | Signature::ScalarB},
            {3, 16, widening_mul(wild_u8x, wild_i8x), Signature::SlidingWindow | Signature::ScalarB},
            {3, 16, widening_mul(wild_i8x, wild_u8x), Signature::SlidingWindow | Signature::ScalarB | Signature::SwapOps},
            {3, 32, widening_mul(wild_i16x, wild_i16x), Signature::SlidingWindow | Signature::ScalarB},
            // Sum
            {3, 16, wild_i8x, Signature::SlidingWindow},
            {3, 16, wild_u8x, Signature::SlidingWindow},
            {3, 32, wild_i16x, Signature::SlidingWindow},

            // --------- vdmpy ---------
            // Sliding window
            {2, 16, widening_mul(wild_u8x, wild_i8x), Signature::SlidingWindow | Signature::ScalarB},
            {2, 16, widening_mul(wild_i8x, wild_u8x), Signature::SlidingWindow | Signature::ScalarB | Signature::SwapOps},
            {2, 32, widening_mul(wild_i16x, wild_i16x), Signature::SlidingWindow | Signature::ScalarB},
            // Vector * Scalar
            {2, 16, widening_mul(wild_u8x, wild_i8x), Signature::ScalarB},
            {2, 16, widening_mul(wild_i8x, wild_u8x), Signature::ScalarB | Signature::SwapOps},
            {2, 32, widening_mul(wild_i16x, wild_i16x), Signature::ScalarB | Signature::NarrowB},
            {2, 32, widening_mul(wild_i16x, wild_u16x), Signature::ScalarB},                       // Saturates
            {2, 32, widening_mul(wild_u16x, wild_i16x), Signature::ScalarB | Signature::SwapOps},  // Saturates
            {2, 32, widening_mul(wild_i16x, wild_i16x), Signature::ScalarB},                       // Saturates
            // Vector * Vector
            {2, 32, widening_mul(wild_i16x, wild_i16x)},  // Saturates
            // Sum
            {2, 16, wild_u8x, Signature::SlidingWindow},
            {2, 32, wild_i16x, Signature::SlidingWindow},
            {2, 16, wild_u8x},
            {2, 32, wild_i16x},
        };
        // clang-format on
        std::vector<Expr> matches;
        for (const Signature &sig : sigs) {
            if (factor != sig.factor) {
                continue;
            }
            // Try matching the pattern with any number of bits between the pattern type and the native result.
            for (int bits = sig.pattern.type().bits(); bits <= sig.native_return_bits; bits *= 2) {
                matches.clear();
                Expr pattern = sig.pattern;
                if (bits != pattern.type().bits()) {
                    // Allow the widening cast to cast to the type of the result, which may
                    // differ from the pattern.
                    pattern = Cast::make(op->type.with_bits(bits).with_lanes(0), pattern);
                }
                if (expr_match(pattern, op->value, matches)) {
                    break;
                }
            }
            if (matches.empty()) {
                continue;
            }

            // Plain sums have a single operand; they become a dot product with
            // a vector of `factor` ones.
            Expr a = matches[0];
            Expr b = matches.size() > 1 ? matches[1] : make_const(Type(op->type.code(), 8, factor), 1);
            if (sig.flags & Signature::SwapOps) {
                std::swap(a, b);
            }

            if (sig.flags & Signature::ScalarB) {
                // Reduce a broadcast operand to its `factor` distinct lanes.
                if (const Shuffle *shuff = b.as<Shuffle>()) {
                    if (shuff->is_broadcast() && shuff->broadcast_factor() % factor == 0) {
                        internal_assert(shuff->vectors.size() == 1);
                        b = Shuffle::make_slice(shuff->vectors[0], 0, 1, factor);
                    }
                } else if (const Shuffle *shuff = a.as<Shuffle>()) {
                    // If the types are equal, we can commute the ops.
                    if (a.type().element_of() == b.type().element_of() &&
                        shuff->is_broadcast() && shuff->broadcast_factor() % factor == 0) {
                        internal_assert(shuff->vectors.size() == 1);
                        a = Shuffle::make_slice(shuff->vectors[0], 0, 1, factor);
                        std::swap(a, b);
                    }
                }
                if (b.type().lanes() != factor) {
                    // This isn't a scalar, it doesn't match the pattern.
                    continue;
                }
            }

            if (sig.flags & Signature::NarrowB) {
                b = lossless_cast(b.type().narrow(), b);
                if (!b.defined()) {
                    continue;
                }
            }

            Expr a0, a1;
            if (sig.flags & Signature::SlidingWindow) {
                if (!is_stencil_interleave(a, factor)) {
                    continue;
                }
                // Split a into a0, a1 to get the correct vector args
                // for sliding window reduction instructions. Below are
                // required shuffle indices for a0 and a1:
                // For factor == 2:
                // If a  -> shuff[0, 1,...., out_lanes]
                //    a0 -> shuff[0, 1,...., out_lanes - 1]
                //    a1 -> shuff[2, 3,...., out_lanes + 1]
                //          Last index of a1 is don't care
                // For factor == 3:
                // If a  -> shuff[0, 1,...., out_lanes + 1]
                //    a0 -> shuff[0, 1,...., out_lanes - 1]
                //    a1 -> shuff[2, 3,...., out_lanes + 1]
                // For factor == 4:
                // If a  -> shuff[0, 1,...., out_lanes + 3]
                //    a0 -> shuff[0, 1,...., out_lanes - 1]
                //    a1 -> shuff[4, 5,...., out_lanes + 4]
                //          Last index of a1 is don't care
                // TODO: Why does this require a to be a shuffle? Why isn't this just:
                // a0 = Shuffle::make_slice(a, 0, factor, out_lanes);
                // a1 = Shuffle::make_slice(a, factor - 1, factor, out_lanes);
                // The current code probably also generates messier shuffles the backend
                // may not recognize.
                if (const Shuffle *shuff = a.as<Shuffle>()) {
                    vector<int> a0_indices(out_lanes), a1_indices(out_lanes);
                    for (int i = 0; i < out_lanes; i++) {
                        // First and last element of each stencil row.
                        a0_indices[i] = shuff->indices[i * factor];
                        a1_indices[i] = shuff->indices[(i + 1) * factor - 1];
                    }
                    a0 = Shuffle::make(shuff->vectors, a0_indices);
                    a1 = Shuffle::make(shuff->vectors, a1_indices);
                    if (factor == 2 || factor == 4) {
                        // We'll need to rotate the indices by one element
                        // to get the correct order.
                        Type ty = UInt(8).with_lanes(a1.type().lanes() * a1.type().bytes());
                        a1 = reinterpret(a1.type(),
                                         Call::make(ty, "halide.hexagon.vror",
                                                    {reinterpret(ty, a1), a1.type().bytes()},
                                                    Call::PureExtern));
                    } else {
                        // Vtmpy has additional requirement that third
                        // coefficient b[2] needs to be 1.
                        if (!can_prove(Shuffle::make_extract_element(b, 2) == 1)) {
                            continue;
                        }
                        b = Shuffle::make_slice(b, 0, 1, 2);
                    }
                    a = Shuffle::make_concat({a0, a1});
                } else {
                    continue;
                }
            }

            std::string suffix = type_suffix(a);
            if (b.type().lanes() <= factor) {
                // Scalar-style b: its lanes are packed into a single scalar,
                // duplicated first when they occupy 16 bits or fewer.
                suffix += type_suffix(b.type().element_of());
                if (b.type().lanes() * b.type().bits() <= 16) {
                    b = Shuffle::make({b}, {0, 1, 0, 1});
                }
                // Reinterpret scalar b arg to get correct type.
                b = simplify(reinterpret(Type(b.type().code(), b.type().lanes() * b.type().bits(), 1), b));
            } else {
                suffix += type_suffix(b);
            }

            Type result_type = op->type.with_bits(sig.native_return_bits);

            Expr result;
            if (factor == 4) {
                if (sig.flags & Signature::SlidingWindow) {
                    result = halide_hexagon_add_4mpy(result_type, suffix + ".stencil", a, b);
                } else {
                    result = halide_hexagon_add_4mpy(result_type, suffix, a, b);
                }
            } else {
                if (sig.flags & Signature::SlidingWindow) {
                    string name = "halide.hexagon.add_" + std::to_string(factor) + "mpy" + suffix;
                    result = native_interleave(Call::make(result_type, name, {a, b}, Call::PureExtern));
                } else {
                    // factor == 3 has only sliding window reductions.
                    result = halide_hexagon_add_2mpy(result_type, suffix, a, b);
                }
            }
            if (result.type() != op->type) {
                result = Cast::make(op->type, result);
            }
            return result;
        }
        return IRMutator::visit(op);
    }
};

// Attempt to cancel out redundant interleave/deinterleave pairs. The
// basic strategy is to push interleavings toward the end of the
// program, using the fact that interleaves can pass through pointwise
// IR operations. When an interleave collides with a deinterleave,
// they cancel out.
//
// The pass also decides, per allocation, whether to keep the buffer's
// storage deinterleaved (see visit(const Allocate *)): when the
// analysis finds this useful and safe, interleaves are stripped from
// the values stored to the buffer and re-applied to values loaded from
// it, so they can cancel against surrounding operations.
class EliminateInterleaves : public IRMutator {
    // Let names that currently have an alternative deinterleaved
    // binding in scope, recorded as "<name>.deinterleaved" or
    // "<name>.weak_deinterleaved" (see visit_let).
    Scope<bool> vars;

    // We need to know when loads are a multiple of 2 native vectors.
    int native_vector_bits;

    // Alignment analyzer for loads and stores
    HexagonAlignmentAnalyzer alignment_analyzer;

    // Check if x is an expression that is either an interleave, or
    // transitively is an interleave. Besides native interleave calls,
    // this recognizes lets whose body qualifies, variables with a
    // ".deinterleaved" alternative in scope, loads from tracked buffers
    // not known to be non-interleaved, and Add/Sub with a qualifying
    // operand.
    bool yields_removable_interleave(const Expr &x) {
        if (is_native_interleave(x)) {
            return true;
        }

        if (const Let *let = x.as<Let>()) {
            return yields_removable_interleave(let->body);
        }

        const Variable *var = x.as<Variable>();
        if (var && vars.contains(var->name + ".deinterleaved")) {
            return true;
        }

        if (const Load *load = x.as<Load>()) {
            if (buffers.contains(load->name)) {
                return buffers.get(load->name) != BufferState::NotInterleaved;
            }
        }

        if (const Add *op = x.as<Add>()) {
            return yields_removable_interleave(op->a) || yields_removable_interleave(op->b);
        } else if (const Sub *op = x.as<Sub>()) {
            return yields_removable_interleave(op->a) || yields_removable_interleave(op->b);
        }

        return false;
    }

    // Check if x either has a removable interleave, or it can pretend
    // to be an interleave at no cost (a scalar or a broadcast).
    bool yields_interleave(const Expr &x) {
        if (yields_removable_interleave(x)) {
            return true;
        }

        // These yield an interleave, but we shouldn't
        // deinterleave them if we want to remove an actual
        // interleave.
        if (x.type().is_scalar() || x.as<Broadcast>()) {
            return true;
        }

        if (const Let *let = x.as<Let>()) {
            return yields_interleave(let->body);
        }

        // This is different from the deinterleaved lets handled in
        // yields_removable_interleave. These are lets that can be
        // deinterleaved freely, but are not actually interleaves.
        const Variable *var = x.as<Variable>();
        if (var && vars.contains(var->name + ".weak_deinterleaved")) {
            return true;
        }

        if (const Load *load = x.as<Load>()) {
            if (buffers.contains(load->name)) {
                return buffers.get(load->name) != BufferState::NotInterleaved;
            }
        }

        if (const Add *op = x.as<Add>()) {
            return yields_interleave(op->a) || yields_interleave(op->b);
        } else if (const Sub *op = x.as<Sub>()) {
            return yields_interleave(op->a) || yields_interleave(op->b);
        }

        return false;
    }

    // Check that if we were to remove interleaves from exprs, that
    // we would remove more interleaves than we added deinterleaves.
    bool yields_removable_interleave(const vector<Expr> &exprs) {
        int removable = 0;
        int does_not_yield = 0;
        for (const Expr &i : exprs) {
            if (yields_removable_interleave(i)) {
                removable++;
            } else if (!yields_interleave(i)) {
                does_not_yield++;
            }
        }
        return removable > 0 && removable > does_not_yield;
    }

    // Asserting that x is an expression that can yield an interleave
    // operation, return the expression being interleaved. Expressions
    // that cannot shed an interleave are wrapped in a native
    // deinterleave instead. Side effect: a load from a tracked buffer
    // that is not known to be non-interleaved marks that buffer as
    // Interleaved.
    Expr remove_interleave(Expr x) {
        if (is_native_interleave(x)) {
            return x.as<Call>()->args[0];
        } else if (x.type().is_scalar() || x.as<Broadcast>()) {
            return x;
        }

        if (const Variable *var = x.as<Variable>()) {
            if (vars.contains(var->name + ".deinterleaved")) {
                return Variable::make(var->type, var->name + ".deinterleaved");
            } else if (vars.contains(var->name + ".weak_deinterleaved")) {
                return Variable::make(var->type, var->name + ".weak_deinterleaved");
            }
        }

        if (const Let *let = x.as<Let>()) {
            Expr body = remove_interleave(let->body);
            if (!body.same_as(let->body)) {
                return Let::make(let->name, let->value, body);
            } else {
                return x;
            }
        }

        if (const Load *load = x.as<Load>()) {
            if (buffers.contains(load->name)) {
                BufferState &state = buffers.ref(load->name);
                if (state != BufferState::NotInterleaved) {
                    state = BufferState::Interleaved;
                    return x;
                }
            }
        }

        if (const Add *op = x.as<Add>()) {
            return Add::make(remove_interleave(op->a), remove_interleave(op->b));
        } else if (const Sub *op = x.as<Sub>()) {
            return Sub::make(remove_interleave(op->a), remove_interleave(op->b));
        }

        // If we rewrite x as interleave(deinterleave(x)), we can remove the interleave.
        return native_deinterleave(x);
    }

    // Shared handler for pointwise binary operators: if moving the
    // interleave from the operands to the result removes more
    // interleaves than it adds deinterleaves, do so.
    template<typename T>
    Expr visit_binary(const T *op) {
        Expr expr;
        Expr a = mutate(op->a);
        Expr b = mutate(op->b);
        if (yields_removable_interleave({a, b})) {
            expr = T::make(remove_interleave(a), remove_interleave(b));
            expr = native_interleave(expr);
        } else if (!a.same_as(op->a) || !b.same_as(op->b)) {
            expr = T::make(a, b);
        } else {
            expr = op;
        }
        return expr;
    }

    Expr visit(const Add *op) override {
        return visit_binary(op);
    }
    Expr visit(const Sub *op) override {
        return visit_binary(op);
    }
    Expr visit(const Mul *op) override {
        return visit_binary(op);
    }
    Expr visit(const Div *op) override {
        return visit_binary(op);
    }
    Expr visit(const Mod *op) override {
        return visit_binary(op);
    }
    Expr visit(const Min *op) override {
        return visit_binary(op);
    }
    Expr visit(const Max *op) override {
        return visit_binary(op);
    }

    Expr visit(const Select *op) override {
        Expr true_value = mutate(op->true_value);
        Expr false_value = mutate(op->false_value);
        Expr cond = mutate(op->condition);

        // The condition isn't a vector, so we can just check if we
        // should move an interleave from the true/false values.
        if (cond.type().is_scalar() && yields_removable_interleave({true_value, false_value})) {
            true_value = remove_interleave(true_value);
            false_value = remove_interleave(false_value);
            return native_interleave(Select::make(cond, true_value, false_value));
        } else if (!cond.same_as(op->condition) ||
                   !true_value.same_as(op->true_value) ||
                   !false_value.same_as(op->false_value)) {
            return Select::make(cond, true_value, false_value);
        } else {
            return op;
        }
    }

    // Shared handler for Let and LetStmt. Makes a deinterleaved
    // alternative of the let available while mutating the body, then
    // re-emits whichever of the original and deinterleaved bindings the
    // new body actually uses.
    template<typename NodeType, typename LetType>
    NodeType visit_let(const LetType *op) {
        Expr value = mutate(op->value);
        string deinterleaved_name;
        NodeType body;
        // Other code in this mutator needs to be able to tell the
        // difference between a Let that yields a deinterleave, and a
        // let that has a removable deinterleave. Lets that can
        // pretend to be deinterleaved at no cost are given an
        // alternative let labelled "weak_deinterleaved", while lets
        // that have a removable interleave are given an alternative
        // let labelled "deinterleaved".
        if (yields_removable_interleave(value)) {
            // We can provide a deinterleaved version of this let value.
            deinterleaved_name = op->name + ".deinterleaved";
            vars.push(deinterleaved_name, true);
            body = mutate(op->body);
            vars.pop(deinterleaved_name);
        } else if (yields_interleave(value)) {
            // We have a soft deinterleaved version of this let value.
            deinterleaved_name = op->name + ".weak_deinterleaved";
            vars.push(deinterleaved_name, true);
            body = mutate(op->body);
            vars.pop(deinterleaved_name);
        } else {
            body = mutate(op->body);
        }

        if (value.same_as(op->value) && body.same_as(op->body)) {
            return op;
        } else if (body.same_as(op->body)) {
            // If the body didn't change, we must not have used the deinterleaved value.
            return LetType::make(op->name, value, body);
        } else {
            // We need to rewrap the body with new lets.
            NodeType result = body;
            bool deinterleaved_used = stmt_or_expr_uses_var(result, deinterleaved_name);
            bool interleaved_used = stmt_or_expr_uses_var(result, op->name);
            if (deinterleaved_used && interleaved_used) {
                // The body uses both the interleaved and
                // deinterleaved version of this let. Generate both
                // lets, using the deinterleaved one to generate the
                // interleaved one.
                Expr deinterleaved = remove_interleave(value);

                // If we actually removed an interleave from the
                // value, re-interleave it to get the interleaved let
                // value.
                Expr interleaved = Variable::make(deinterleaved.type(), deinterleaved_name);
                if (!deinterleaved.same_as(value)) {
                    interleaved = native_interleave(interleaved);
                }

                result = LetType::make(op->name, interleaved, result);
                return LetType::make(deinterleaved_name, deinterleaved, result);
            } else if (deinterleaved_used) {
                // Only the deinterleaved value is used, we can eliminate the interleave.
                return LetType::make(deinterleaved_name, remove_interleave(value), result);
            } else if (interleaved_used) {
                // Only the original value is used, regenerate the let.
                return LetType::make(op->name, value, result);
            } else {
                // The let must have been dead.
                internal_assert(!stmt_or_expr_uses_var(op->body, op->name))
                    << "EliminateInterleaves eliminated a non-dead let.\n";
                return op->body;
            }
        }
    }

    Expr visit(const Let *op) override {
        Expr expr = visit_let<Expr>(op);

        // Lift interleaves out of Let expression bodies.
        const Let *let = expr.as<Let>();
        if (let && yields_removable_interleave(let->body)) {
            expr = native_interleave(Let::make(let->name, let->value, remove_interleave(let->body)));
        }
        return expr;
    }

    Stmt visit(const LetStmt *op) override {
        return visit_let<Stmt>(op);
    }

    Expr visit(const Cast *op) override {
        if (op->type.bits() == op->value.type().bits()) {
            // We can only move interleaves through casts of the same size.
            Expr value = mutate(op->value);

            if (yields_removable_interleave(value)) {
                value = remove_interleave(value);
                return native_interleave(Cast::make(op->type, value));
            } else if (!value.same_as(op->value)) {
                return Cast::make(op->type, value);
            } else {
                return op;
            }
        } else {
            return IRMutator::visit(op);
        }
    }

    static bool is_interleavable(const Call *op) {
        // These calls can have interleaves moved from operands to the
        // result...
        static const set<string> interleavable = {
            Call::get_intrinsic_name(Call::bitwise_and),
            Call::get_intrinsic_name(Call::bitwise_not),
            Call::get_intrinsic_name(Call::bitwise_xor),
            Call::get_intrinsic_name(Call::bitwise_or),
            Call::get_intrinsic_name(Call::shift_left),
            Call::get_intrinsic_name(Call::shift_right),
            Call::get_intrinsic_name(Call::abs),
            Call::get_intrinsic_name(Call::absd)};
        if (interleavable.count(op->name) != 0) {
            return true;
        }

        // ...these calls cannot. Furthermore, these calls have the
        // same return type as the arguments, which means our test
        // below will be inaccurate.
        static const set<string> not_interleavable = {
            "halide.hexagon.interleave.vb",
            "halide.hexagon.interleave.vh",
            "halide.hexagon.interleave.vw",
            "halide.hexagon.deinterleave.vb",
            "halide.hexagon.deinterleave.vh",
            "halide.hexagon.deinterleave.vw",
            Call::get_intrinsic_name(Call::hvx_gather),
            Call::get_intrinsic_name(Call::hvx_scatter),
            Call::get_intrinsic_name(Call::hvx_scatter_acc),
        };
        if (not_interleavable.count(op->name) != 0) {
            return false;
        }

        if (starts_with(op->name, "halide.hexagon.")) {
            // We assume that any hexagon intrinsic is interleavable
            // as long as all of the vector operands have the same
            // number of lanes and lane width as the return type.
            for (const Expr &i : op->args) {
                if (i.type().is_scalar()) {
                    continue;
                }
                if (i.type().bits() != op->type.bits() || i.type().lanes() != op->type.lanes()) {
                    return false;
                }
            }
        }
        return true;
    }

    Expr visit(const Call *op) override {
        vector<Expr> args(op->args);

        // mutate all the args.
        bool changed = false;
        for (Expr &i : args) {
            Expr new_i = mutate(i);
            changed = changed || !new_i.same_as(i);
            i = new_i;
        }

        // For a few operations, we have a choice of several
        // instructions, an interleaving or a non-interleaving
        // variant. We handle this by generating the instruction that
        // does not deinterleave, and then opportunistically select
        // the alternative when doing so lets an interleave cancel out.
        //
        // Both tables are const: std::map::operator[] is not treated as
        // const for data-race purposes, so reading a shared mutable
        // static with it from concurrent compilations would be a race.
        static const std::map<string, string> deinterleaving_alts = {
            {"halide.hexagon.pack.vh", "halide.hexagon.trunc.vh"},
            {"halide.hexagon.pack.vw", "halide.hexagon.trunc.vw"},
            {"halide.hexagon.packhi.vh", "halide.hexagon.trunclo.vh"},
            {"halide.hexagon.packhi.vw", "halide.hexagon.trunclo.vw"},
            {"halide.hexagon.pack_satub.vh", "halide.hexagon.trunc_satub.vh"},
            {"halide.hexagon.pack_sath.vw", "halide.hexagon.trunc_sath.vw"},
            {"halide.hexagon.pack_satuh.vw", "halide.hexagon.trunc_satuh.vw"},
        };

        // The reverse mapping of the above.
        static const std::map<string, string> interleaving_alts = {
            {"halide.hexagon.trunc.vh", "halide.hexagon.pack.vh"},
            {"halide.hexagon.trunc.vw", "halide.hexagon.pack.vw"},
            {"halide.hexagon.trunclo.vh", "halide.hexagon.packhi.vh"},
            {"halide.hexagon.trunclo.vw", "halide.hexagon.packhi.vw"},
            {"halide.hexagon.trunc_satub.vh", "halide.hexagon.pack_satub.vh"},
            {"halide.hexagon.trunc_sath.vw", "halide.hexagon.pack_sath.vw"},
            {"halide.hexagon.trunc_satuh.vw", "halide.hexagon.pack_satuh.vw"},
        };

        // Look each table up once; the lookups have no side effects.
        const auto deinterleaving_alt = deinterleaving_alts.find(op->name);
        const auto interleaving_alt = interleaving_alts.find(op->name);

        if (is_native_deinterleave(op) && yields_interleave(args[0])) {
            // This is a deinterleave of an interleave! Remove them both.
            return remove_interleave(args[0]);
        } else if (is_interleavable(op) && yields_removable_interleave(args)) {
            // We can reduce the total number of interleave and deinterleave
            // operations by removing interleaves from the arguments.
            for (Expr &i : args) {
                i = remove_interleave(i);
            }
            Expr expr = Call::make(op->type, op->name, args, op->call_type,
                                   op->func, op->value_index, op->image, op->param);
            // Add the interleave back to the result of the call.
            return native_interleave(expr);
        } else if (deinterleaving_alt != deinterleaving_alts.end() &&
                   yields_removable_interleave(args)) {
            // This call has a deinterleaving alternative, and the
            // arguments are interleaved, so we should use the
            // alternative instead.
            for (Expr &i : args) {
                i = remove_interleave(i);
            }
            return Call::make(op->type, deinterleaving_alt->second, args, op->call_type);
        } else if (interleaving_alt != interleaving_alts.end() && is_native_deinterleave(args[0])) {
            // This is an interleaving alternative with a
            // deinterleave, which can be generated when we
            // deinterleave storage. Revert back to the interleaving
            // op so we can remove the deinterleave.
            Expr arg = args[0].as<Call>()->args[0];
            return Call::make(op->type, interleaving_alt->second, {arg}, op->call_type,
                              op->func, op->value_index, op->image, op->param);
        } else if (changed) {
            return Call::make(op->type, op->name, args, op->call_type,
                              op->func, op->value_index, op->image, op->param);
        } else {
            return op;
        }
    }

    // Track whether buffers are interleaved or not.
    enum class BufferState {
        Unknown,         // We don't know if this buffer is interleaved or not.
        Interleaved,     // We know the buffer is interleaved.
        NotInterleaved,  // We know the buffer is not interleaved.
    };
    Scope<BufferState> buffers;

    // False for buffers that have any loads or stores that are unaligned
    Scope<bool> aligned_buffer_access;

    // Buffers we should deinterleave the storage of.
    Scope<bool> deinterleave_buffers;

    // Two-pass handling of an allocation: the first mutation of the
    // body gathers the buffer's interleave state and access alignment;
    // if the buffer ended up Interleaved with only aligned accesses, the
    // body is mutated again from the original with the buffer's storage
    // deinterleaved.
    Stmt visit(const Allocate *op) override {
        Expr condition = mutate(op->condition);

        // First, we need to mutate the op, to pull native interleaves
        // down, and to gather information about the loads and stores.
        buffers.push(op->name, BufferState::Unknown);

        // Assume buffers are accessed by aligned loads and stores by default.
        aligned_buffer_access.push(op->name, true);

        Stmt body = mutate(op->body);
        bool deinterleave = (buffers.get(op->name) == BufferState::Interleaved) &&
                            aligned_buffer_access.get(op->name);
        buffers.pop(op->name);

        // Second, if we decided it would be useful to deinterleave
        // the storage of this buffer, do so now.
        if (deinterleave) {
            deinterleave_buffers.push(op->name, true);
            body = mutate(op->body);
            deinterleave_buffers.pop(op->name);
        }

        aligned_buffer_access.pop(op->name);

        if (!body.same_as(op->body) || !condition.same_as(op->condition)) {
            return Allocate::make(op->name, op->type, op->memory_type,
                                  op->extents, condition, body,
                                  op->new_expr, op->free_function);
        } else {
            return op;
        }
    }

    Stmt visit(const Store *op) override {
        Expr predicate = mutate(op->predicate);
        Expr value = mutate(op->value);
        Expr index = mutate(op->index);

        if (buffers.contains(op->name)) {
            // When inspecting the stores to a buffer, update the state.
            BufferState &state = buffers.ref(op->name);
            if (!is_const_one(predicate) || !op->value.type().is_vector()) {
                // TODO(psuriana): This store is predicated. Mark the buffer as
                // not interleaved for now.
                state = BufferState::NotInterleaved;
            } else if (yields_removable_interleave(value)) {
                // The value yields a removable interleave. If we aren't tracking
                // this buffer, mark it as interleaved.
                if (state == BufferState::Unknown) {
                    state = BufferState::Interleaved;
                }
            } else if (!yields_interleave(value)) {
                // The value does not yield an interleave. Mark the
                // buffer as not interleaved.
                state = BufferState::NotInterleaved;
            } else {
                // If the value yields an interleave, but is not an
                // interleave itself, we don't want to change the
                // buffer state.
            }
            internal_assert(aligned_buffer_access.contains(op->name) && "Buffer not found in scope");
            bool &aligned_accesses = aligned_buffer_access.ref(op->name);
            int64_t aligned_offset = 0;

            if (!alignment_analyzer.is_aligned(op, &aligned_offset)) {
                aligned_accesses = false;
            }
        }
        if (deinterleave_buffers.contains(op->name)) {
            // We're deinterleaving this buffer, remove the interleave
            // from the store.
            internal_assert(is_const_one(predicate)) << "The store shouldn't have been predicated.\n";
            value = remove_interleave(value);
        }

        if (predicate.same_as(op->predicate) && value.same_as(op->value) && index.same_as(op->index)) {
            return op;
        } else {
            return Store::make(op->name, value, index, op->param, predicate, op->alignment);
        }
    }

    Expr visit(const Load *op) override {
        if (buffers.contains(op->name)) {
            if ((op->type.lanes() * op->type.bits()) % (native_vector_bits * 2) == 0) {
                // This is a double vector load, we might be able to
                // deinterleave the storage of this buffer.
                // We don't want to actually do anything to the buffer
                // state here. We know we can interleave the load if
                // necessary, but we don't want to cause it to be
                // interleaved unless it is a useful improvement,
                // which is only true if any of the stores are
                // actually interleaved (and don't just yield an
                // interleave).
                internal_assert(aligned_buffer_access.contains(op->name) && "Buffer not found in scope");
                bool &aligned_accesses = aligned_buffer_access.ref(op->name);
                int64_t aligned_offset = 0;

                if (!alignment_analyzer.is_aligned(op, &aligned_offset)) {
                    aligned_accesses = false;
                }
            } else {
                // This is not a double vector load, so we can't
                // deinterleave the storage of this buffer.
                BufferState &state = buffers.ref(op->name);
                state = BufferState::NotInterleaved;
            }
        }
        Expr expr = IRMutator::visit(op);
        if (deinterleave_buffers.contains(op->name)) {
            // The storage is deinterleaved, so re-interleave the loaded value.
            expr = native_interleave(expr);
        }
        return expr;
    }

    using IRMutator::visit;

public:
    // native_vector_bytes: the size of one native HVX vector in bytes.
    EliminateInterleaves(int native_vector_bytes)
        : native_vector_bits(native_vector_bytes * 8), alignment_analyzer(native_vector_bytes) {
    }
};

// Interleaves that survive EliminateInterleaves can sometimes be folded
// into the operation that produced them: some instructions have a
// variant whose output already comes out interleaved. This pass applies
// those fusions. It must run only after every other attempt to cancel
// interleaves, or it could consume an interleave that would otherwise
// have cancelled against a deinterleave.
class FuseInterleaves : public IRMutator {
    using IRMutator::visit;

    // Rewrite interleave(f(x...)) as g(x...) for every {f, g} pair in the
    // table below, recursively mutating x. Other calls are delegated to
    // IRMutator.
    Expr visit(const Call *op) override {
        // {f, g}: interleaving the result of f is the same as calling g
        // on f's arguments.
        static const std::vector<std::pair<string, string>> non_deinterleaving_alts = {
            {"halide.hexagon.zxt.vub", "halide.hexagon.unpack.vub"},
            {"halide.hexagon.sxt.vb", "halide.hexagon.unpack.vb"},
            {"halide.hexagon.zxt.vuh", "halide.hexagon.unpack.vuh"},
            {"halide.hexagon.sxt.vh", "halide.hexagon.unpack.vh"},
        };

        // The call being interleaved, if op is an interleave of a call.
        const Call *inner = nullptr;
        if (is_native_interleave(op)) {
            inner = op->args[0].as<Call>();
        }
        if (inner == nullptr) {
            return IRMutator::visit(op);
        }

        for (const auto &alt : non_deinterleaving_alts) {
            if (inner->name != alt.first) {
                continue;
            }
            std::vector<Expr> fused_args;
            fused_args.reserve(inner->args.size());
            for (const Expr &arg : inner->args) {
                fused_args.push_back(mutate(arg));
            }
            return Call::make(op->type, alt.second, fused_args, Call::PureExtern);
        }

        return IRMutator::visit(op);
    }
};

// Distribute constant RHS widening shift lefts as multiplies.
// TODO: This is an extremely unfortunate mess. I think the better
// solution is for the simplifier to distribute constant multiplications
// instead of factoring them, and then this logic is unnecessary (find_mpy_ops
// would need to handle shifts, but that's easy).
// Another possibility would be adding a widening_mul_add intrinsic that takes
// a list of pairs of operands, and computes a widening sum of widening multiplies
// of these pairs. FindIntrinsics could aggressively rewrite shifts as
// widening_mul_add operands.
class DistributeShiftsAsMuls : public IRMutator {
private:
    // Rewrite a * b, where b is the constant multiplier built by visit()
    // below. The multiply is distributed through Add/Sub, casts, and the
    // widening_add/widening_sub/widening_mul intrinsics, so that it can
    // be expressed as widening_mul of narrower operands where possible.
    // Anything else falls back to simplify(a * b).
    static Expr distribute(const Expr &a, const Expr &b) {
        if (const Add *add = a.as<Add>()) {
            return Add::make(distribute(add->a, b), distribute(add->b, b));
        } else if (const Sub *sub = a.as<Sub>()) {
            Expr sub_a = distribute(sub->a, b);
            Expr sub_b = distribute(sub->b, b);
            Expr negative_sub_b = lossless_negate(sub_b);
            if (negative_sub_b.defined()) {
                return Add::make(sub_a, negative_sub_b);
            } else {
                return Sub::make(sub_a, sub_b);
            }
        } else if (const Cast *cast = a.as<Cast>()) {
            Expr cast_b = lossless_cast(b.type().with_bits(cast->value.type().bits()), b);
            if (cast_b.defined()) {
                Expr mul = widening_mul(cast->value, cast_b);
                if (mul.type().bits() <= cast->type.bits()) {
                    if (mul.type() != cast->type) {
                        mul = Cast::make(cast->type, mul);
                    }
                    return mul;
                }
            }
        } else if (const Call *add = Call::as_intrinsic(a, {Call::widening_add})) {
            Expr add_a = Cast::make(add->type, add->args[0]);
            Expr add_b = Cast::make(add->type, add->args[1]);
            add_a = distribute(add_a, b);
            add_b = distribute(add_b, b);
            // If add_a and add_b are the same kind of cast, we should remake a widening add.
            const Cast *add_a_cast = add_a.as<Cast>();
            const Cast *add_b_cast = add_b.as<Cast>();
            if (add_a_cast && add_b_cast &&
                add_a_cast->value.type() == add->args[0].type() &&
                add_b_cast->value.type() == add->args[1].type()) {
                return widening_add(add_a_cast->value, add_b_cast->value);
            } else {
                return Add::make(add_a, add_b);
            }
        } else if (const Call *sub = Call::as_intrinsic(a, {Call::widening_sub})) {
            Expr sub_a = Cast::make(sub->type, sub->args[0]);
            Expr sub_b = Cast::make(sub->type, sub->args[1]);
            sub_a = distribute(sub_a, b);
            sub_b = distribute(sub_b, b);
            Expr negative_sub_b = lossless_negate(sub_b);
            if (negative_sub_b.defined()) {
                sub_b = negative_sub_b;
            }
            // If sub_a and sub_b are the same kind of cast, we should remake a widening sub.
            const Cast *sub_a_cast = sub_a.as<Cast>();
            const Cast *sub_b_cast = sub_b.as<Cast>();
            if (sub_a_cast && sub_b_cast &&
                sub_a_cast->value.type() == sub->args[0].type() &&
                sub_b_cast->value.type() == sub->args[1].type()) {
                if (negative_sub_b.defined()) {
                    return widening_add(sub_a_cast->value, sub_b_cast->value);
                } else {
                    return widening_sub(sub_a_cast->value, sub_b_cast->value);
                }
            } else {
                if (negative_sub_b.defined()) {
                    return Add::make(sub_a, sub_b);
                } else {
                    return Sub::make(sub_a, sub_b);
                }
            }
        } else if (const Call *mul = Call::as_intrinsic(a, {Call::widening_mul})) {
            Expr mul_a = Cast::make(mul->type, mul->args[0]);
            Expr mul_b = Cast::make(mul->type, mul->args[1]);
            mul_a = distribute(mul_a, b);
            if (const Cast *mul_a_cast = mul_a.as<Cast>()) {
                if (mul_a_cast->value.type() == mul->args[0].type()) {
                    return widening_mul(mul_a_cast->value, mul->args[1]);
                }
            }
            mul_b = distribute(mul_b, b);
            if (const Cast *mul_b_cast = mul_b.as<Cast>()) {
                if (mul_b_cast->value.type() == mul->args[1].type()) {
                    return widening_mul(mul->args[0], mul_b_cast->value);
                }
            }
        }
        return simplify(Mul::make(a, b));
    }

    // A constant left shift by `amount` is only rewritten as a multiply by
    // 2^amount when `amount` is smaller than `bits`, the bit width of the
    // type that holds the multiplier. Otherwise `1ull << amount` is
    // undefined behavior (amount >= 64), or the multiplier's only set bit
    // falls outside the type. Such shifts are left untouched, which always
    // preserves the program's meaning.
    static bool shift_fits_multiplier(uint64_t amount, int bits) {
        return amount < static_cast<uint64_t>(bits);
    }

    using IRMutator::visit;

    Expr visit(const Call *op) override {
        if (op->is_intrinsic(Call::shift_left)) {
            if (const uint64_t *const_b = as_const_uint(op->args[1])) {
                Expr a = op->args[0];
                // Only rewrite widening shifts.
                const Cast *cast_a = a.as<Cast>();
                bool is_widening_cast = cast_a && cast_a->type.bits() >= cast_a->value.type().bits() * 2;
                bool is_widening_op = is_widening_cast ||
                                      Call::as_intrinsic(a, {Call::widening_add, Call::widening_mul, Call::widening_sub});
                if (is_widening_op && shift_fits_multiplier(*const_b, a.type().bits())) {
                    const uint64_t const_m = 1ull << *const_b;
                    Expr b = make_const(a.type(), const_m);
                    return mutate(distribute(a, b));
                }
            }
        } else if (op->is_intrinsic(Call::widening_shift_left)) {
            if (const uint64_t *const_b = as_const_uint(op->args[1])) {
                if (shift_fits_multiplier(*const_b, op->type.bits())) {
                    const uint64_t const_m = 1ull << *const_b;
                    Expr b = make_const(op->type, const_m);
                    Expr a = Cast::make(op->type, op->args[0]);
                    return mutate(distribute(a, b));
                }
            }
        }
        return IRMutator::visit(op);
    }
};

// Try generating vgathers instead of shuffles.
// At present, we request VTCM memory with single page allocation flag for all
// store_in allocations. So it's always safe to generate a vgather.
// Expressions which generate vgathers are of the form:
//     out(x) = lut(foo(x))
// For vgathers out and lut should be in VTCM in a single page.
class ScatterGatherGenerator : public IRMutator {
    Scope<Interval> bounds;
    std::unordered_map<string, const Allocate *> allocations;

    using IRMutator::visit;

    Expr visit(const Call *op) override {
        if (op->is_intrinsic(Call::if_then_else) && op->args[0].type().is_vector()) {
            const Broadcast *b = op->args[0].as<Broadcast>();
            if (!b || b->value.type().is_vector()) {
                return op;
            }
        }
        return IRMutator::visit(op);
    }

    template<typename NodeType, typename T>
    NodeType visit_let(const T *op) {
        // We only care about vector lets.
        if (op->value.type().is_vector()) {
            bounds.push(op->name, bounds_of_expr_in_scope(op->value, bounds));
        }
        NodeType node = IRMutator::visit(op);
        if (op->value.type().is_vector()) {
            bounds.pop(op->name);
        }
        return node;
    }

    Expr visit(const Let *op) override {
        return visit_let<Expr>(op);
    }

    Stmt visit(const LetStmt *op) override {
        return visit_let<Stmt>(op);
    }

    Stmt visit(const Allocate *op) override {
        // Create a map of the allocation
        allocations[op->name] = op;
        return IRMutator::visit(op);
    }

    // Try to match expressions of the form:
    //     out(x) = lut(foo(x))
    // to generate vgathers. Here, out and lut should have
    // store_in(MemoryType::VTCM) directive. If a vgather is found return Call
    // Expr to vgather, otherwise Expr().
    Expr make_gather(const Load *op, Expr dst_base, Expr dst_index) {
        Type ty = op->type;
        const Allocate *alloc = allocations[op->name];
        // The lut should be in VTCM.
        if (!alloc || alloc->memory_type != MemoryType::VTCM) {
            return Expr();
        }
        // HVX has only 16 or 32-bit gathers. Predicated vgathers are not
        // supported yet.
        if (op->index.as<Ramp>() || !is_const_one(op->predicate) || !ty.is_vector() ||
            ty.bits() == 8) {
            return Expr();
        }
        Expr index = mutate(ty.bytes() * op->index);
        Interval index_bounds = bounds_of_expr_in_scope(index, bounds);
        if (ty.bits() == 16 && index_bounds.is_bounded()) {
            Expr index_span = span_of_bounds(index_bounds);
            index_span = common_subexpression_elimination(index_span);
            index_span = simplify(index_span);
            // We need to downcast the index values to 16 bit signed. So all the
            // the indices must be less than 1 << 15.
            if (!can_prove(index_span < std::numeric_limits<int16_t>::max())) {
                return Expr();
            }
        }
        // Calculate the size of the buffer lut in bytes.
        Expr size = ty.bytes();
        for (const auto &extent : alloc->extents) {
            size *= extent;
        }
        Expr src = Variable::make(Handle(), op->name);
        Expr new_index = mutate(cast(ty.with_code(Type::Int), index));
        dst_index = mutate(dst_index);

        return Call::make(ty, Call::hvx_gather, {std::move(dst_base), dst_index, src, size - 1, new_index},
                          Call::Intrinsic);
    }

    // Checks if the Store node can be replaced with a scatter_accumulate.
    // If yes, return new_value to be used for scatter-accumulate, else return
    // the input parameter value.
    Expr is_scatter_acc(const Store *op) {
        Expr lhs = Load::make(op->value.type(), op->name, op->index, Buffer<>(),
                              Parameter(), const_true(op->value.type().lanes()), op->alignment);
        Expr wild = Variable::make(op->value.type(), "*");
        vector<Expr> matches;
        if (expr_match(lhs + wild, op->value, matches) ||
            expr_match(wild + lhs, op->value, matches)) {
            // Scatter accumulate found.
            return matches[0];
        }
        return op->value;
    }

    Stmt visit(const Store *op) override {
        // HVX has only 16 or 32-bit gathers. Predicated vgathers are not
        // supported yet.
        Type ty = op->value.type();
        if (!is_const_one(op->predicate) || !ty.is_vector() || ty.bits() == 8) {
            return IRMutator::visit(op);
        }
        // To use vgathers, the destination address must be VTCM memory.
        const Allocate *alloc = allocations[op->name];
        if (!alloc || alloc->memory_type != MemoryType::VTCM) {
            return IRMutator::visit(op);
        }
        // The source for a gather must also be a buffer in VTCM.
        if (op->index.as<Ramp>() && op->value.as<Load>()) {
            // Check for vgathers
            Expr dst_base = Variable::make(Handle(), op->name);
            Expr dst_index = op->index.as<Ramp>()->base;
            Expr value = make_gather(op->value.as<Load>(), dst_base, dst_index);
            if (value.defined()) {
                // Found a vgather.
                // Function make_gather already mutates all the call arguements,
                // so no need to mutate again.
                return Evaluate::make(value);
            }
        }
        // Check for scatter/scatter-accumulate.
        if (op->index.as<Ramp>()) {
            return IRMutator::visit(op);
        }
        // Calculate the size of the buffer in bytes.
        Expr size = ty.bytes();
        for (const auto &extent : alloc->extents) {
            size *= extent;
        }
        // Check for scatter-acc.
        Expr value = is_scatter_acc(op);
        Call::IntrinsicOp intrinsic = Call::hvx_scatter;
        if (!value.same_as(op->value)) {
            // It's a scatter-accumulate
            intrinsic = Call::hvx_scatter_acc;
        }
        Expr buffer = Variable::make(Handle(), op->name);
        Expr index = mutate(cast(ty.with_code(Type::Int), ty.bytes() * op->index));
        value = mutate(value);
        Stmt scatter = Evaluate::make(Call::make(ty, intrinsic,
                                                 {buffer, size - 1, index, value}, Call::Intrinsic));
        return scatter;
    }
};

// Scatter-gather instructions on Hexagon are asynchronous and hence require a
// scatter-release store followed by a vector load from the same address. This
// stalls the pipeline until all previous scatter-gather operations have
// finished. These operations are also not ordered with respect to ordinary
// load and store operations.
class SyncronizationBarriers : public IRMutator {
    // Keep track of all scatter-gather operations in flight which could cause
    // a hazard in the future.
    std::map<string, vector<const Stmt *>> in_flight;
    // Trail of For Blocks to reach a stmt.
    vector<const Stmt *> curr_path;
    // Current Stmt being mutated.
    const Stmt *curr = nullptr;
    // Track where the Stmt generated a scatter-release.
    std::map<const Stmt *, Expr> sync;

    using IRMutator::visit;

    Expr visit(const Call *op) override {
        if (op->is_intrinsic(Call::hvx_scatter) ||
            op->is_intrinsic(Call::hvx_scatter_acc) ||
            op->is_intrinsic(Call::hvx_gather)) {
            string name = op->args[0].as<Variable>()->name;
            // Check if the scatter-gather encountered conflicts with any
            // previous operation. If yes, insert a scatter-release.
            check_hazard(name);
            in_flight[name] = curr_path;
        }
        return IRMutator::visit(op);
    }

    Stmt visit(const For *op) override {
        // Keep trail of the For blocks encoutered.
        curr_path.push_back(curr);
        Stmt s = IRMutator::visit(op);
        curr_path.pop_back();
        return s;
    }

    // Creates entry in sync map for the stmt requiring a
    // scatter-release instruction before it.
    void check_hazard(const string &name) {
        if (in_flight.find(name) == in_flight.end()) {
            return;
        }
        // Sync Needed. Add the scatter-release before the first different For
        // loop lock between the curr_path and the hazard src location.
        size_t min_size = std::min(in_flight[name].size(), curr_path.size());
        size_t i = 0;
        // Find the first different For loop block.
        for (; i < min_size; i++) {
            if (in_flight[name][i] != curr_path[i]) {
                break;
            }
        }
        if (i < curr_path.size()) {
            // Place scatter-release before the first different For loop block.
            sync[curr_path[i]] = Variable::make(Handle(), name);
        } else {
            // Need to add the scatter-release before the curr stmt.
            sync[curr] = Variable::make(Handle(), name);
        }
        in_flight.clear();
    }

    Expr visit(const Load *op) override {
        // Resolve scatter-load hazard.
        check_hazard(op->name);
        return IRMutator::visit(op);
    }

    Stmt visit(const Store *op) override {
        // Resolve scatter-store and gather-store hazards.
        check_hazard(op->name);
        return IRMutator::visit(op);
    }

public:
    using IRMutator::mutate;

    Stmt mutate(const Stmt &s) override {
        curr = &s;
        Stmt new_s = IRMutator::mutate(s);
        // Wrap the stmt with scatter-release if any hazard was detected.
        if (sync.find(&s) != sync.end()) {
            Stmt scatter_sync =
                Evaluate::make(Call::make(Int(32), Call::hvx_scatter_release, {sync[&s]}, Call::Intrinsic));
            return Block::make(scatter_sync, new_s);
        }
        return new_s;
    }
};

}  // namespace

Stmt optimize_hexagon_shuffles(const Stmt &s, int lut_alignment) {
    // Delegates to optimize_shuffles, which turns indirect and other
    // complicated loads into dynamic_shuffle (vlut) calls using the given
    // LUT alignment.
    Stmt shuffled = optimize_shuffles(s, lut_alignment);
    return shuffled;
}

Stmt scatter_gather_generator(Stmt s) {
    // Generates vscatter/vgather instructions (used when target >= v65).
    // All lets are substituted away before matching; scatter-release barriers
    // are then inserted, and CSE is run over the result.
    Stmt inlined = substitute_in_all_lets(s);
    Stmt gathered = ScatterGatherGenerator().mutate(inlined);
    Stmt with_barriers = SyncronizationBarriers().mutate(gathered);
    return common_subexpression_elimination(with_barriers);
}

Stmt optimize_hexagon_instructions(Stmt s, const Target &t) {
    // Intrinsic matching is redone because simplification ran again after
    // target-independent lowering finished.
    Stmt result = find_intrinsics(s);

    // Hexagon prefers widening shifts expressed as multiplies, in the hope of
    // forming compound widening multiplies.
    result = DistributeShiftsAsMuls().mutate(result);

    // Match VectorReduce nodes before OptimizePatterns so that patterns such
    // as (v0 + v1 * c) -> add_mpy do not rewrite them first.
    result = VectorReducePatterns().mutate(result);

    // Peephole optimization into HVX intrinsics; this may introduce
    // interleaves and deinterleaves.
    result = OptimizePatterns(t).mutate(result);

    // Remove redundant interleave/deinterleave pairs, given the natural
    // vector size in 8-bit lanes.
    const int native_lanes = t.natural_vector_size(Int(8));
    result = EliminateInterleaves(native_lanes).mutate(result);

    // Fuse any remaining interleaves with other operations where possible.
    return FuseInterleaves().mutate(result);
}

}  // namespace Internal
}  // namespace Halide
back to top