Revision 3657cf5f363fd64aeaf06432e62e3960800927b0 authored by Andrew Adams on 26 January 2024, 17:26:12 UTC, committed by GitHub on 26 January 2024, 17:26:12 UTC
* Fix bounds_of_nested_lanes

bounds_of_nested_lanes assumed that one layer of nested vectorization
could be removed at a time. When faced with the expression:

min(ramp(x8(a), x8(b), 5), x40(27))

It panicked because, on the left-hand side, it reduced the bounds to
x8(a) ... x8(a) + x8(b) * 4, while on the right-hand side it reduced the
bounds to 27. It then attempted to take a min of mismatched types.

In general we can't assume that binary operators on nested vectors have
the same nesting structure on both sides, so I just rewrote it to reduce
directly to a scalar.

Fixes #8038
1 parent 4590a09
Raw File
Simplify_Or.cpp
#include "Simplify_Internal.h"

namespace Halide {
namespace Internal {

// Simplify a logical Or node.
//
// Outline:
//  1. If `op` itself is contained in `truths`, the whole expression folds to
//     a true constant with op's lane count. NOTE(review): `truths` is
//     presumably the set of boolean facts known to hold in the enclosing
//     context; confirm against its declaration.
//  2. Both operands are simplified. nullptr is passed as their ExprInfo, so
//     no bounds information is collected for the operands.
//  3. The operands may be swapped according to should_commute(), presumably
//     to put them in a canonical order before pattern matching.
//  4. Rewrite rules are tried in source order. Because they are chained
//     with `||`, the first rule that matches wins and the rest are skipped.
//     - Results of the first group (an existing operand, `true`, or a
//       comparison against a folded constant) are returned directly.
//     - Results of the second group build new subexpressions. They are
//       passed through mutate() again, presumably because those may
//       simplify further.
//  5. If nothing matched, an Or of the simplified operands is returned.
//     The original node is reused when neither operand changed.
//
// `bounds` is only forwarded to the re-simplification of a second-group
// rewrite result; this function does not fill it in directly.
Expr Simplify::visit(const Or *op, ExprInfo *bounds) {
    if (truths.count(op)) {
        return const_true(op->type.lanes());
    }

    Expr a = mutate(op->a, nullptr);
    Expr b = mutate(op->b, nullptr);

    if (should_commute(a, b)) {
        std::swap(a, b);
    }

    auto rewrite = IRMatcher::rewriter(IRMatcher::or_op(a, b), op->type);

    // clang-format off
    // Rules whose result is an existing operand, `true`, or a single
    // comparison against a folded constant. These are returned without
    // another mutate() pass.
    if (EVAL_IN_LAMBDA
        // Constant operands and idempotence:
        // x || true -> true, x || false -> x, x || x -> x.
        (rewrite(x || true, b) ||
         rewrite(x || false, a) ||
         rewrite(x || x, a) ||

         // One operand already appears inside an Or on the other side, so
         // the whole expression equals that nested Or.
         rewrite((x || y) || x, a) ||
         rewrite(x || (x || y), b) ||
         rewrite((x || y) || y, a) ||
         rewrite(y || (x || y), b) ||

         // Same as above, with the repeated term nested two Ors deep.
         rewrite(((x || y) || z) || x, a) ||
         rewrite(x || ((x || y) || z), b) ||
         rewrite((z || (x || y)) || x, a) ||
         rewrite(x || (z || (x || y)), b) ||
         rewrite(((x || y) || z) || y, a) ||
         rewrite(y || ((x || y) || z), b) ||
         rewrite((z || (x || y)) || y, a) ||
         rewrite(y || (z || (x || y)), b) ||

         // Absorption: (x && y) implies x, so (x && y) || x == x.
         rewrite((x && y) || x, b) ||
         rewrite(x || (x && y), a) ||
         rewrite((x && y) || y, b) ||
         rewrite(y || (x && y), a) ||

         // A condition Or'd with its complement (possibly alongside an
         // extra term z) is always true.
         rewrite(x != y || x == y, true) ||
         rewrite(x != y || y == x, true) ||
         rewrite((z || x != y) || x == y, true) ||
         rewrite((z || x != y) || y == x, true) ||
         rewrite((x != y || z) || x == y, true) ||
         rewrite((x != y || z) || y == x, true) ||
         rewrite((z || x == y) || x != y, true) ||
         rewrite((z || x == y) || y != x, true) ||
         rewrite((x == y || z) || x != y, true) ||
         rewrite((x == y || z) || y != x, true) ||
         rewrite(x || !x, true) ||
         rewrite(!x || x, true) ||
         // NOTE(review): this rule and the unguarded float-compatible
         // range rules below treat comparisons as a total order (no NaN);
         // presumably the simplifier assumes that globally, so confirm.
         rewrite(y <= x || x < y, true) ||
         // If c0 != c1 then x == c1 already implies x != c0, so the
         // equality adds nothing: the result is just x != c0.
         rewrite(x != c0 || x == c1, a, c0 != c1) ||
         // The two comparisons together cover every value of x. The
         // `c1 <= c0 + 1` forms require a non-float x: for integers nothing
         // lies strictly between c0 and c0 + 1, but a float could.
         rewrite(x <= c0 || c1 <= x, true, !is_float(x) && c1 <= c0 + 1) ||
         rewrite(c1 <= x || x <= c0, true, !is_float(x) && c1 <= c0 + 1) ||
         rewrite(x <= c0 || c1 < x, true, c1 <= c0) ||
         rewrite(c1 <= x || x < c0, true, c1 <= c0) ||
         rewrite(x < c0 || c1 < x, true, c1 < c0) ||
         rewrite(c1 < x || x < c0, true, c1 < c0) ||
         // Two lower (or two upper) bounds against constants merge into a
         // single comparison against the looser constant.
         rewrite(c0 < x || c1 < x, fold(min(c0, c1)) < x) ||
         rewrite(c0 <= x || c1 <= x, fold(min(c0, c1)) <= x) ||
         rewrite(x < c0 || x < c1, x < fold(max(c0, c1))) ||
         rewrite(x <= c0 || x <= c1, x <= fold(max(c0, c1))))) {
        return rewrite.result;
    }
    // clang-format on

    // Rules that build new subexpressions; their results are simplified
    // again below.
    // Or of two broadcasts sharing the same c0 -> broadcast of the Or of
    // the broadcast values.
    if (rewrite(broadcast(x, c0) || broadcast(y, c0), broadcast(x || y, c0)) ||
        // y can be dropped from inside the And: when y is false the inner
        // (y || z) is just z, and when y is true both sides are true.
        rewrite((x && (y || z)) || y, (x && z) || y) ||
        rewrite((x && (z || y)) || y, (x && z) || y) ||
        rewrite(y || (x && (y || z)), y || (x && z)) ||
        rewrite(y || (x && (z || y)), y || (x && z)) ||

        // Same as above with the And operands in the other order.
        rewrite(((y || z) && x) || y, (z && x) || y) ||
        rewrite(((z || y) && x) || y, (z && x) || y) ||
        rewrite(y || ((y || z) && x), y || (z && x)) ||
        rewrite(y || ((z || y) && x), y || (z && x)) ||

        // (y && z) implies y, so that term is absorbed by the outer y.
        rewrite((x || (y && z)) || y, x || y) ||
        rewrite((x || (z && y)) || y, x || y) ||
        rewrite(y || (x || (y && z)), y || x) ||
        rewrite(y || (x || (z && y)), y || x) ||

        // Same absorption with the inner Or operands in the other order.
        rewrite(((y && z) || x) || y, x || y) ||
        rewrite(((z && y) || x) || y, x || y) ||
        rewrite(y || ((y && z) || x), y || x) ||
        rewrite(y || ((z && y) || x), y || x) ||

        // Factor out a conjunct shared by both sides.
        rewrite((x && y) || (x && z), x && (y || z)) ||
        rewrite((x && y) || (z && x), x && (y || z)) ||
        rewrite((y && x) || (x && z), x && (y || z)) ||
        rewrite((y && x) || (z && x), x && (y || z)) ||

        // Two comparisons sharing one side merge into one comparison
        // against the min/max of the other sides.
        rewrite(x < y || x < z, x < max(y, z)) ||
        rewrite(y < x || z < x, min(y, z) < x) ||
        rewrite(x <= y || x <= z, x <= max(y, z)) ||
        rewrite(y <= x || z <= x, min(y, z) <= x)) {

        return mutate(rewrite.result, bounds);
    }

    // No rule matched. Reuse the original node when the operands are
    // unchanged; otherwise rebuild the Or from the simplified operands.
    if (a.same_as(op->a) &&
        b.same_as(op->b)) {
        return op;
    } else {
        return Or::make(a, b);
    }
}

}  // namespace Internal
}  // namespace Halide
back to top