https://github.com/halide/Halide
Tip revision: f9e4c7878385f43cf88cca23d5bd663233e9e7da authored by Steven Johnson on 27 April 2021, 19:14:54 UTC
Add support for dynamic tensors to hannk (#5942)
Add support for dynamic tensors to hannk (#5942)
Tip revision: f9e4c78
Simplify_Shuffle.cpp
#include "Simplify_Internal.h"
namespace Halide {
namespace Internal {
using std::vector;
Expr Simplify::visit(const Shuffle *op, ExprInfo *bounds) {
    // Simplify a Shuffle node. Rewrites are tried in this order:
    //  1. extract_element of a Ramp / Broadcast -> a scalar expression;
    //  2. a shuffle of loads from one buffer whose shuffled index simplifies
    //     to a Ramp -> a single dense Load;
    //  3. a shuffle of broadcasts of one value -> a single Broadcast;
    //  4. interleave of ramps -> a Ramp; interleave of matching slices of
    //     the same vectors -> those vectors (re-sliced if needed);
    //  5. concat of ramps / of scalars in arithmetic progression -> a Ramp;
    //  6. a lane-reducing shuffle of a widening Cast -> a Cast of the shuffle;
    //  7. slice of a slice -> one slice; slice of a concat -> a slice of
    //     the smaller concat of only the operands the slice touches.
    // If nothing applies, the Shuffle is rebuilt from the simplified operands
    // (or returned as-is when no operand changed).
    //
    // When `bounds` is non-null and the operands are mutated (every path
    // except rule 1), it first receives the union of the operands' bounds and
    // alignment; rules that recurse via mutate(..., bounds) overwrite it.

    if (op->is_extract_element()) {
        // Extracting a single lane of a ramp or broadcast.
        if (const Ramp *r = op->vectors[0].as<Ramp>()) {
            return mutate(r->base + op->indices[0] * r->stride, bounds);
        } else if (const Broadcast *b = op->vectors[0].as<Broadcast>()) {
            return mutate(b->value, bounds);
        }
    }

    // Mutate the vectors, accumulating the union of their bounds.
    vector<Expr> new_vectors;
    new_vectors.reserve(op->vectors.size());
    bool changed = false;
    for (const Expr &vector : op->vectors) {
        ExprInfo v_bounds;
        Expr new_vector = mutate(vector, &v_bounds);
        if (!vector.same_as(new_vector)) {
            changed = true;
        }
        if (bounds) {
            if (new_vectors.empty()) {
                *bounds = v_bounds;
            } else {
                bounds->min_defined &= v_bounds.min_defined;
                bounds->max_defined &= v_bounds.max_defined;
                bounds->min = std::min(bounds->min, v_bounds.min);
                bounds->max = std::max(bounds->max, v_bounds.max);
                bounds->alignment = ModulusRemainder::unify(bounds->alignment, v_bounds.alignment);
            }
        }
        new_vectors.push_back(new_vector);
    }

    // Try to convert a load with shuffled indices into a
    // shuffle of a dense load.
    if (const Load *first_load = new_vectors[0].as<Load>()) {
        vector<Expr> load_predicates;
        vector<Expr> load_indices;
        bool unpredicated = true;
        for (const Expr &e : new_vectors) {
            const Load *load = e.as<Load>();
            // Every operand must be a load from the same buffer.
            if (!load || load->name != first_load->name) {
                break;
            }
            load_predicates.push_back(load->predicate);
            load_indices.push_back(load->index);
            unpredicated = unpredicated && is_const_one(load->predicate);
        }

        if (load_indices.size() == new_vectors.size()) {
            const int lanes = (int)op->indices.size();
            Expr shuffled_index = Shuffle::make(load_indices, op->indices);
            ExprInfo shuffled_index_info;
            shuffled_index = mutate(shuffled_index, &shuffled_index_info);
            // Only worthwhile if the shuffled index became a dense ramp.
            if (const Ramp *r = shuffled_index.as<Ramp>()) {
                ExprInfo base_info;
                mutate(r->base, &base_info);

                ModulusRemainder alignment =
                    ModulusRemainder::intersect(base_info.alignment, shuffled_index_info.alignment);

                Expr shuffled_predicate;
                if (unpredicated) {
                    shuffled_predicate = const_true(lanes);
                } else {
                    shuffled_predicate = Shuffle::make(load_predicates, op->indices);
                    shuffled_predicate = mutate(shuffled_predicate, nullptr);
                }
                Type t = first_load->type.with_lanes(lanes);
                return Load::make(t, first_load->name, shuffled_index, first_load->image,
                                  first_load->param, shuffled_predicate, alignment);
            }
        }
    }

    // Try to collapse a shuffle of broadcasts into a single
    // broadcast. Note that it doesn't matter what the indices
    // are.
    const Broadcast *b1 = new_vectors[0].as<Broadcast>();
    if (b1) {
        bool can_collapse = true;
        for (size_t i = 1; i < new_vectors.size() && can_collapse; i++) {
            if (const Broadcast *b2 = new_vectors[i].as<Broadcast>()) {
                Expr check = mutate(b1->value - b2->value, nullptr);
                can_collapse &= is_const_zero(check);
            } else {
                can_collapse = false;
            }
        }
        if (can_collapse) {
            if (op->indices.size() == 1) {
                return b1->value;
            } else {
                return Broadcast::make(b1->value, op->indices.size());
            }
        }
    }

    if (op->is_interleave()) {
        int terms = (int)new_vectors.size();

        // Try to collapse an interleave of ramps into a single ramp.
        const Ramp *r = new_vectors[0].as<Ramp>();
        if (r) {
            bool can_collapse = true;
            for (size_t i = 1; i < new_vectors.size() && can_collapse; i++) {
                // If we collapse these terms into a single ramp,
                // the new stride is going to be the old stride
                // divided by the number of terms, so the
                // difference between two adjacent terms in the
                // interleave needs to be a broadcast of the new
                // stride.
                Expr diff = mutate(new_vectors[i] - new_vectors[i - 1], nullptr);
                const Broadcast *b = diff.as<Broadcast>();
                if (b) {
                    Expr check = mutate(b->value * terms - r->stride, nullptr);
                    can_collapse &= is_const_zero(check);
                } else {
                    can_collapse = false;
                }
            }
            if (can_collapse) {
                return mutate(Ramp::make(r->base, r->stride / terms, r->lanes * terms), bounds);
            }
        }

        // Try to collapse an interleave of slices of vectors from
        // the same vector into a single vector.
        if (const Shuffle *first_shuffle = new_vectors[0].as<Shuffle>()) {
            if (first_shuffle->is_slice()) {
                bool can_collapse = true;
                for (size_t i = 0; i < new_vectors.size() && can_collapse; i++) {
                    const Shuffle *i_shuffle = new_vectors[i].as<Shuffle>();

                    // Check that the current shuffle is a slice...
                    if (!i_shuffle || !i_shuffle->is_slice()) {
                        can_collapse = false;
                        break;
                    }

                    // ... and that it is a slice in the right place...
                    // If the shuffle is a single element, we don't care what the stride is.
                    if (i_shuffle->slice_begin() != (int)i ||
                        (i_shuffle->indices.size() != 1 && i_shuffle->slice_stride() != terms)) {
                        can_collapse = false;
                        break;
                    }

                    if (i > 0) {
                        // ... and that the vectors being sliced are the same.
                        if (first_shuffle->vectors.size() != i_shuffle->vectors.size()) {
                            can_collapse = false;
                            break;
                        }

                        for (size_t j = 0; j < first_shuffle->vectors.size() && can_collapse; j++) {
                            if (!equal(first_shuffle->vectors[j], i_shuffle->vectors[j])) {
                                can_collapse = false;
                            }
                        }
                    }
                }

                if (can_collapse) {
                    // It's possible the slices didn't use all of the vector, in which case we need to slice it.
                    Expr result = Shuffle::make_concat(first_shuffle->vectors);
                    if (result.type().lanes() != op->type.lanes()) {
                        result = Shuffle::make_slice(result, 0, 1, op->type.lanes());
                    }
                    return result;
                }
            }
        }
    } else if (op->is_concat()) {
        // Try to collapse a concat of ramps into a single ramp.
        const Ramp *r = new_vectors[0].as<Ramp>();
        if (r) {
            bool can_collapse = true;
            for (size_t i = 1; i < new_vectors.size() && can_collapse; i++) {
                // Operands of differing widths cannot be subtracted; leave
                // diff undefined so the collapse is rejected below.
                Expr diff;
                if (new_vectors[i].type().lanes() == new_vectors[i - 1].type().lanes()) {
                    diff = mutate(new_vectors[i] - new_vectors[i - 1], nullptr);
                }

                const Broadcast *b = diff.as<Broadcast>();
                if (b) {
                    Expr check = mutate(b->value - r->stride * new_vectors[i - 1].type().lanes(), nullptr);
                    can_collapse &= is_const_zero(check);
                } else {
                    can_collapse = false;
                }
            }
            if (can_collapse) {
                return Ramp::make(r->base, r->stride, op->indices.size());
            }
        }

        // Try to collapse a concat of scalars into a ramp. Guard the size:
        // a single-operand shuffle (e.g. an identity shuffle of one scalar)
        // may also report is_concat(), and then new_vectors[1] does not exist.
        if (new_vectors.size() >= 2 &&
            new_vectors[0].type().is_scalar() && new_vectors[1].type().is_scalar()) {
            bool can_collapse = true;
            Expr stride = mutate(new_vectors[1] - new_vectors[0], nullptr);
            for (size_t i = 1; i < new_vectors.size() && can_collapse; i++) {
                if (!new_vectors[i].type().is_scalar()) {
                    can_collapse = false;
                    break;
                }

                Expr check = mutate(new_vectors[i] - new_vectors[i - 1] - stride, nullptr);
                if (!is_const_zero(check)) {
                    can_collapse = false;
                }
            }
            if (can_collapse) {
                return Ramp::make(new_vectors[0], stride, op->indices.size());
            }
        }
    }

    // Pull a widening cast outside of a slice
    if (new_vectors.size() == 1 &&
        op->type.lanes() < new_vectors[0].type().lanes()) {
        if (const Cast *cast = new_vectors[0].as<Cast>()) {
            if (cast->type.bits() > cast->value.type().bits()) {
                return mutate(Cast::make(cast->type.with_lanes(op->type.lanes()),
                                         Shuffle::make({cast->value}, op->indices)),
                              bounds);
            }
        }
    }

    if (op->is_slice() && (new_vectors.size() == 1)) {
        if (const Shuffle *inner_shuffle = new_vectors[0].as<Shuffle>()) {
            // Try to collapse a slice of slice.
            if (inner_shuffle->is_slice() && (inner_shuffle->vectors.size() == 1)) {
                // Indices of the slice are ramp, so nested slice is a1 * (a2 * x + b2) + b1 =
                // = a1 * a2 * x + a1 * b2 + b1.
                return Shuffle::make_slice(inner_shuffle->vectors[0],
                                           op->slice_begin() * inner_shuffle->slice_stride() + inner_shuffle->slice_begin(),
                                           op->slice_stride() * inner_shuffle->slice_stride(),
                                           op->indices.size());
            }
            // Check if we really need to concat all vectors before slicing.
            if (inner_shuffle->is_concat()) {
                // Nothing here guarantees a positive slice stride, so order
                // the endpoints; a descending slice has front() > back().
                int slice_min = std::min(op->indices.front(), op->indices.back());
                int slice_max = std::max(op->indices.front(), op->indices.back());
                int concat_index = 0;
                int new_slice_start = -1;
                vector<Expr> new_concat_vectors;
                for (const auto &v : inner_shuffle->vectors) {
                    // Keep the operand iff its lane range [concat_index, v_last]
                    // intersects [slice_min, slice_max]. Testing only the two
                    // endpoints would drop an operand that strictly contains
                    // the whole slice.
                    int v_last = concat_index + v.type().lanes() - 1;
                    if (std::max(concat_index, slice_min) <= std::min(v_last, slice_max)) {
                        if (new_slice_start < 0) {
                            new_slice_start = concat_index;
                        }
                        new_concat_vectors.push_back(v);
                    }

                    concat_index += v.type().lanes();
                }
                // The kept operands are contiguous in the concat, so the slice
                // only needs re-basing by the first kept operand's offset.
                if (!new_concat_vectors.empty() &&
                    new_concat_vectors.size() < inner_shuffle->vectors.size()) {
                    return Shuffle::make_slice(Shuffle::make_concat(new_concat_vectors),
                                               op->slice_begin() - new_slice_start,
                                               op->slice_stride(),
                                               op->indices.size());
                }
            }
        }
    }

    if (!changed) {
        return op;
    } else {
        return Shuffle::make(new_vectors, op->indices);
    }
}
// Rewrite a binary node T whose two operands are both slice Shuffles,
// T(Shuffle(as, idx), Shuffle(bs, idx)), into Shuffle({T(a_k, b_k)...}, idx):
// apply T to the unsliced inputs pairwise, then slice once.
// `e` must be a T whose operands are slice Shuffles (internal_assert'ed).
// The rewrite happens only when both slices use identical indices over the
// same number of input vectors with pairwise-equal types; otherwise `e` is
// returned unchanged.
template<typename T>
Expr Simplify::hoist_slice_vector(Expr e) {
    const T *node = e.as<T>();
    internal_assert(node);

    const Shuffle *lhs = node->a.template as<Shuffle>();
    const Shuffle *rhs = node->b.template as<Shuffle>();
    internal_assert(lhs && rhs &&
                    lhs->is_slice() &&
                    rhs->is_slice());

    // Both sides must pick the same lanes out of identically shaped inputs.
    if (lhs->indices != rhs->indices ||
        lhs->vectors.size() != rhs->vectors.size()) {
        return e;
    }

    const size_t count = lhs->vectors.size();
    for (size_t k = 0; k < count; ++k) {
        if (lhs->vectors[k].type() != rhs->vectors[k].type()) {
            return e;
        }
    }

    // Combine the unsliced inputs pairwise, then take the shared slice once.
    vector<Expr> combined;
    combined.reserve(count);
    for (size_t k = 0; k < count; ++k) {
        combined.push_back(T::make(lhs->vectors[k], rhs->vectors[k]));
    }
    return Shuffle::make(combined, lhs->indices);
}

template Expr Simplify::hoist_slice_vector<Add>(Expr);
template Expr Simplify::hoist_slice_vector<Sub>(Expr);
template Expr Simplify::hoist_slice_vector<Mul>(Expr);
template Expr Simplify::hoist_slice_vector<Min>(Expr);
template Expr Simplify::hoist_slice_vector<Max>(Expr);
} // namespace Internal
} // namespace Halide