https://github.com/halide/Halide
Tip revision: f9e4c7878385f43cf88cca23d5bd663233e9e7da authored by Steven Johnson on 27 April 2021, 19:14:54 UTC
Add support for dynamic tensors to hannk (#5942)
Add support for dynamic tensors to hannk (#5942)
Tip revision: f9e4c78
IRMatch.h
#ifndef HALIDE_IR_MATCH_H
#define HALIDE_IR_MATCH_H
/** \file
* Defines a method to match a fragment of IR against a pattern containing wildcards
*/
#include <map>
#include <random>
#include <set>
#include <vector>
#include "IR.h"
#include "IREquality.h"
#include "IROperator.h"
namespace Halide {
namespace Internal {
/** Does the first expression have the same structure as the second?
* Variables in the first expression with the name * are interpreted
* as wildcards, and their matching equivalent in the second
* expression is placed in the vector given as the third argument.
* Wildcards require the types to match. For the type bits and width,
* a 0 indicates "match anything". So an Int(8, 0) will match 8-bit
* integer vectors of any width (including scalars), and a UInt(0, 0)
* will match any unsigned integer type.
*
* For example:
\code
Expr x = Variable::make(Int(32), "*");
expr_match(x + x, 3 + (2*k), result)
\endcode
* should return true, and set result[0] to 3 and
* result[1] to 2*k.
*/
bool expr_match(const Expr &pattern, const Expr &expr, std::vector<Expr> &result);
/** Does the first expression have the same structure as the second?
* Variables are matched consistently. The first time a variable is
* matched, it assumes the value of the matching part of the second
* expression. Subsequent matches must be equal to the first match.
* Matches are recorded in the map passed as the third argument, keyed
* by variable name.
*
* For example:
\code
Var x("x"), y("y");
expr_match(x*(x + y), a*(a + b), result)
\endcode
* should return true, and set result["x"] = a, and result["y"] = b.
*/
bool expr_match(const Expr &pattern, const Expr &expr, std::map<std::string, Expr> &result);
/** Rewrite the expression x to have `lanes` lanes. This is useful
* for substituting the results of expr_match into a pattern expression.
* Only the declaration appears in this part of the header. */
Expr with_lanes(const Expr &x, int lanes);
/** Self-test entry point associated with the matchers declared in this
* header. NOTE(review): presumably exercises expr_match; the definition
* is not visible here, so confirm what it actually covers. */
void expr_match_test();
/** An alternative template-metaprogramming approach to expression
* matching. Potentially more efficient. We lift the expression
* pattern into a type, and then use force-inlined functions to
* generate efficient matching and reconstruction code for any
* pattern. Pattern elements are either one of the classes in the
* namespace IRMatcher, or are non-null Exprs (represented as
* BaseExprNode &).
*
* Pattern elements that are fully specified by their pattern can be
* built into an expression using the make method. Some patterns,
* such as a broadcast that matches any number of lanes, don't have
* enough information to recreate an Expr.
*/
namespace IRMatcher {
// Number of slots in each of MatcherState's per-wildcard arrays. The
// Wild* patterns static_assert that their index lies in [0, max_wild).
constexpr int max_wild = 6;
// Type recorded for a constant that is bound from a raw int64_t instead
// of from an IR node (see WildConstInt::match(int64_t, ...)). Initialised
// with halide_type_int, 64 and 1 (presumably code, bits, lanes -- confirm
// against halide_type_t's field order).
static const halide_type_t i64_type = {halide_type_int, 64, 1};
/** To save stack space, the matcher objects are largely stateless and
* immutable. This state object is built up during matching and then
* consumed when constructing a replacement Expr.
*
* It holds two kinds of per-wildcard slot, each indexed by the wildcard
* number (0 .. max_wild - 1): a pointer to a bound IR node, and a bound
* scalar constant together with its type.
*/
struct MatcherState {
// Node bound to each general wildcard. Stored as a non-owning raw pointer.
const BaseExprNode *bindings[max_wild];
// Value bound to each constant wildcard; which union member is meaningful
// depends on the matching entry of bound_const_type.
halide_scalar_value_t bound_const[max_wild];
// values of the lanes field with special meaning.
// signed_integer_overflow is OR-ed into a type's lanes field by the
// constant folders when a signed operation overflows.
static constexpr uint16_t signed_integer_overflow = 0x8000;
// Mask covering every special flag bit that may be stored in lanes.
static constexpr uint16_t special_values_mask = 0x8000; // currently only one
// Type of each bound constant (lanes may carry the special flag bits above).
halide_type_t bound_const_type[max_wild];
// Record node n as the binding for wildcard i. No bounds check on i.
HALIDE_ALWAYS_INLINE
void set_binding(int i, const BaseExprNode &n) noexcept {
bindings[i] = &n;
}
// Return the node previously bound to wildcard i.
HALIDE_ALWAYS_INLINE
const BaseExprNode *get_binding(int i) const noexcept {
return bindings[i];
}
// Bind constant wildcard i to a signed integer value of type t.
HALIDE_ALWAYS_INLINE
void set_bound_const(int i, int64_t s, halide_type_t t) noexcept {
bound_const[i].u.i64 = s;
bound_const_type[i] = t;
}
// Bind constant wildcard i to an unsigned integer value of type t.
HALIDE_ALWAYS_INLINE
void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept {
bound_const[i].u.u64 = u;
bound_const_type[i] = t;
}
// Bind constant wildcard i to a floating-point value of type t.
HALIDE_ALWAYS_INLINE
void set_bound_const(int i, double f, halide_type_t t) noexcept {
bound_const[i].u.f64 = f;
bound_const_type[i] = t;
}
// Bind constant wildcard i to an already-packed scalar value of type t.
HALIDE_ALWAYS_INLINE
void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept {
bound_const[i] = val;
bound_const_type[i] = t;
}
// Copy out the value and type bound to constant wildcard i.
HALIDE_ALWAYS_INLINE
void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept {
val = bound_const[i];
type = bound_const_type[i];
}
// The constructor body is empty: it performs no explicit initialisation
// of the slot arrays, so a slot must be set before it is read.
HALIDE_ALWAYS_INLINE
// NOLINTNEXTLINE(modernize-use-equals-default): Can't use `= default`; clang-tidy complains about noexcept mismatch
MatcherState() noexcept {
}
};
// SFINAE helper. The defaulted second template argument only names a type
// when T (with any reference removed) declares a nested `pattern_tag`, so
// enable_if_pattern<T>::type can only be formed for pattern types.
template<typename T,
typename = typename std::remove_reference<T>::type::pattern_tag>
struct enable_if_pattern {
struct type {};
};
// Exposes, as bindings<T>::mask, the `binds` bitmask declared by pattern
// type T (reference qualifiers on T are stripped first). The mask records
// which wildcard slots the pattern binds.
template<typename T>
struct bindings {
constexpr static uint32_t mask = std::remove_reference<T>::type::binds;
};
// Build the Expr for a folded constant whose type carries one of the
// special flag bits in its lanes field. The flag bits are removed from
// the type before it is used. Kept out of line (HALIDE_NEVER_INLINE).
inline HALIDE_NEVER_INLINE Expr make_const_special_expr(halide_type_t ty) {
    // Remember which special bits were set, then strip them so that `ty`
    // holds a plain lane count again.
    const uint16_t special_bits = ty.lanes & MatcherState::special_values_mask;
    ty.lanes &= ~MatcherState::special_values_mask;

    const bool overflowed = (special_bits & MatcherState::signed_integer_overflow) != 0;
    if (!overflowed) {
        // unreachable: signed_integer_overflow is the only special value.
        return Expr();
    }
    return make_signed_integer_overflow(ty);
}
// Turn a folded scalar value plus its type into an Expr. Types whose lanes
// field carries a special flag are handed to make_const_special_expr;
// otherwise an immediate of the scalar type is created and, when the type
// has more than one lane, wrapped in a Broadcast.
HALIDE_ALWAYS_INLINE
Expr make_const_expr(halide_scalar_value_t val, halide_type_t ty) {
    if (ty.lanes & MatcherState::special_values_mask) {
        return make_const_special_expr(ty);
    }

    // Split the type into its lane count and a one-lane element type.
    const int lanes = ty.lanes;
    halide_type_t elem_type = ty;
    elem_type.lanes = 1;

    Expr result;
    switch (elem_type.code) {
    case halide_type_int:
        result = IntImm::make(elem_type, val.u.i64);
        break;
    case halide_type_uint:
        result = UIntImm::make(elem_type, val.u.u64);
        break;
    case halide_type_float:
    case halide_type_bfloat:
        result = FloatImm::make(elem_type, val.u.f64);
        break;
    default:
        // Unreachable
        return Expr();
    }

    if (lanes > 1) {
        result = Broadcast::make(result, lanes);
    }
    return result;
}
// Out-of-line part of equal(); only declared here. equal() calls it after
// establishing that the two nodes are distinct objects with the same type
// and the same node_type.
bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept;
// A fast version of expression equality that assumes a well-typed non-null expression tree.
// Cheap checks (identity, type, node kind) run first; equal_helper is only
// consulted when they cannot decide the answer.
HALIDE_ALWAYS_INLINE
bool equal(const BaseExprNode &a, const BaseExprNode &b) noexcept {
    // Early out: the very same node.
    if (&a == &b) {
        return true;
    }
    if (!(a.type == b.type)) {
        return false;
    }
    if (!(a.node_type == b.node_type)) {
        return false;
    }
    return equal_helper(a, b);
}
// A pattern that matches a specific expression
struct SpecificExpr {
    struct pattern_tag {};

    // The node to compare against. Held by reference, so the node must
    // outlive this pattern object.
    const BaseExprNode &expr;

    // Binds no wildcards.
    constexpr static uint32_t binds = 0;

    // What is the weakest and strongest IR node this could possibly be
    constexpr static IRNodeType min_node_type = IRNodeType::IntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::Shuffle;

    constexpr static bool canonical = true;
    constexpr static bool foldable = false;

    // Matches when e is equal to the held expression; state is untouched.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        const bool same = equal(expr, e);
        return same;
    }

    // Re-wrap the held node as an Expr; state and type_hint are unused.
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        const BaseExprNode *node = &expr;
        return Expr(node);
    }
};
// Print a SpecificExpr by printing the expression it refers to.
inline std::ostream &operator<<(std::ostream &s, const SpecificExpr &e) {
    return s << Expr(&e.expr);
}
// Matches and binds to a signed integer constant (an IntImm, or a Broadcast
// of one), using constant slot i of the MatcherState.
template<int i>
struct WildConstInt {
    struct pattern_tag {};
    constexpr static uint32_t binds = 1 << i;
    constexpr static IRNodeType min_node_type = IRNodeType::IntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::IntImm;
    constexpr static bool canonical = true;

    // Match against an IR node. If slot i is already bound (per the `bound`
    // mask) the node must have the same type and value as the binding;
    // otherwise the value is bound, recorded with the type of e itself
    // (so a Broadcast keeps its vector type).
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
        const BaseExprNode *op = &e;
        // Look through a broadcast to the scalar being broadcast.
        if (op->node_type == IRNodeType::Broadcast) {
            op = ((const Broadcast *)op)->value.get();
        }
        if (op->node_type != IRNodeType::IntImm) {
            return false;
        }
        int64_t value = ((const IntImm *)op)->value;
        if (bound & binds) {
            halide_scalar_value_t val;
            halide_type_t type;
            state.get_bound_const(i, val, type);
            return (halide_type_t)e.type == type && value == val.u.i64;
        }
        state.set_bound_const(i, value, e.type);
        return true;
    }

    // Match against a raw integer. A fresh binding is recorded with
    // i64_type; an existing binding must also have that type.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept {
        static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
        if (bound & binds) {
            halide_scalar_value_t val;
            halide_type_t type;
            state.get_bound_const(i, val, type);
            return type == i64_type && value == val.u.i64;
        }
        state.set_bound_const(i, value, i64_type);
        return true;
    }

    // Rebuild an Expr from the bound constant; type_hint is ignored because
    // the binding carries its own type.
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        halide_scalar_value_t val;
        halide_type_t type;
        state.get_bound_const(i, val, type);
        return make_const_expr(val, type);
    }

    constexpr static bool foldable = true;

    // Fetch the bound value and type for constant folding. Marked noexcept
    // like the equivalent member of the other wildcard patterns; it only
    // calls MatcherState::get_bound_const, which is itself noexcept.
    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        state.get_bound_const(i, val, ty);
    }
};
// Print a WildConstInt as "ci" followed by its slot index.
template<int i>
std::ostream &operator<<(std::ostream &s, const WildConstInt<i> &c) {
    return s << "ci" << i;
}
// Matches and binds to an unsigned integer constant (a UIntImm, or a
// Broadcast of one), using constant slot i of the MatcherState.
template<int i>
struct WildConstUInt {
    struct pattern_tag {};
    constexpr static uint32_t binds = 1 << i;
    constexpr static IRNodeType min_node_type = IRNodeType::UIntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::UIntImm;
    constexpr static bool canonical = true;
    constexpr static bool foldable = true;

    // Binds the value on first use; on later uses requires the same type
    // and value as the existing binding.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");

        // Look through a broadcast to the scalar being broadcast.
        const BaseExprNode *scalar = &e;
        if (scalar->node_type == IRNodeType::Broadcast) {
            scalar = static_cast<const Broadcast *>(scalar)->value.get();
        }
        if (scalar->node_type != IRNodeType::UIntImm) {
            return false;
        }

        const uint64_t value = static_cast<const UIntImm *>(scalar)->value;
        if (!(bound & binds)) {
            // First occurrence: record the value with the type of e itself.
            state.set_bound_const(i, value, e.type);
            return true;
        }

        halide_scalar_value_t prev_val;
        halide_type_t prev_type;
        state.get_bound_const(i, prev_val, prev_type);
        return static_cast<halide_type_t>(e.type) == prev_type && value == prev_val.u.u64;
    }

    // Rebuild an Expr from the bound constant; type_hint is ignored.
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        halide_scalar_value_t bound_val;
        halide_type_t bound_type;
        state.get_bound_const(i, bound_val, bound_type);
        return make_const_expr(bound_val, bound_type);
    }

    // Fetch the bound value and type for constant folding.
    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        state.get_bound_const(i, val, ty);
    }
};
// Print a WildConstUInt as "cu" followed by its slot index.
template<int i>
std::ostream &operator<<(std::ostream &s, const WildConstUInt<i> &c) {
    return s << "cu" << i;
}
// Matches and binds to a floating-point constant (a FloatImm, or a
// Broadcast of one), using constant slot i of the MatcherState.
template<int i>
struct WildConstFloat {
    struct pattern_tag {};
    constexpr static uint32_t binds = 1 << i;
    constexpr static IRNodeType min_node_type = IRNodeType::FloatImm;
    constexpr static IRNodeType max_node_type = IRNodeType::FloatImm;
    constexpr static bool canonical = true;
    constexpr static bool foldable = true;

    // Binds the value on first use; on later uses requires the same type
    // and an equal (operator==) value as the existing binding.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");

        // Look through a broadcast to the scalar being broadcast.
        const BaseExprNode *scalar = &e;
        if (scalar->node_type == IRNodeType::Broadcast) {
            scalar = static_cast<const Broadcast *>(scalar)->value.get();
        }
        if (scalar->node_type != IRNodeType::FloatImm) {
            return false;
        }

        const double value = static_cast<const FloatImm *>(scalar)->value;
        if (!(bound & binds)) {
            // First occurrence: record the value with the type of e itself.
            state.set_bound_const(i, value, e.type);
            return true;
        }

        halide_scalar_value_t prev_val;
        halide_type_t prev_type;
        state.get_bound_const(i, prev_val, prev_type);
        return static_cast<halide_type_t>(e.type) == prev_type && value == prev_val.u.f64;
    }

    // Rebuild an Expr from the bound constant; type_hint is ignored.
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        halide_scalar_value_t bound_val;
        halide_type_t bound_type;
        state.get_bound_const(i, bound_val, bound_type);
        return make_const_expr(bound_val, bound_type);
    }

    // Fetch the bound value and type for constant folding.
    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        state.get_bound_const(i, val, ty);
    }
};
// Print a WildConstFloat as "cf" followed by its slot index.
template<int i>
std::ostream &operator<<(std::ostream &s, const WildConstFloat<i> &c) {
    return s << "cf" << i;
}
// Matches and binds to any constant Expr. Does not support constant-folding.
// Dispatches on the kind of immediate found (after looking through a
// Broadcast) to the type-specific wildcard sharing the same slot index.
template<int i>
struct WildConst {
    struct pattern_tag {};
    constexpr static uint32_t binds = 1 << i;
    constexpr static IRNodeType min_node_type = IRNodeType::IntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::FloatImm;
    constexpr static bool canonical = true;
    constexpr static bool foldable = true;

    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");

        // Find the scalar whose kind decides which matcher to use. The
        // original node e (not the scalar) is what gets forwarded.
        const BaseExprNode *scalar = &e;
        if (scalar->node_type == IRNodeType::Broadcast) {
            scalar = static_cast<const Broadcast *>(scalar)->value.get();
        }
        const IRNodeType kind = scalar->node_type;

        if (kind == IRNodeType::IntImm) {
            return WildConstInt<i>().template match<bound>(e, state);
        }
        if (kind == IRNodeType::UIntImm) {
            return WildConstUInt<i>().template match<bound>(e, state);
        }
        if (kind == IRNodeType::FloatImm) {
            return WildConstFloat<i>().template match<bound>(e, state);
        }
        return false;
    }

    // Matching a raw integer is delegated to the signed-integer wildcard.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept {
        static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
        return WildConstInt<i>().template match<bound>(e, state);
    }

    // Rebuild an Expr from the bound constant; type_hint is ignored.
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        halide_scalar_value_t bound_val;
        halide_type_t bound_type;
        state.get_bound_const(i, bound_val, bound_type);
        return make_const_expr(bound_val, bound_type);
    }

    // Fetch the bound value and type for constant folding.
    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        state.get_bound_const(i, val, ty);
    }
};
// Print a WildConst as "c" followed by its slot index.
template<int i>
std::ostream &operator<<(std::ostream &s, const WildConst<i> &c) {
    return s << "c" << i;
}
// Matches and binds to any Expr
// Uses node slot i of the MatcherState; its bit in the bindings mask sits
// 16 positions above the bit used by the constant wildcards.
template<int i>
struct Wild {
    struct pattern_tag {};
    constexpr static uint32_t binds = 1 << (i + 16);
    constexpr static IRNodeType min_node_type = IRNodeType::IntImm;
    constexpr static IRNodeType max_node_type = StrongestExprNodeType;
    constexpr static bool canonical = true;
    constexpr static bool foldable = true;

    // First occurrence binds e; later occurrences must equal the binding.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        if (!(bound & binds)) {
            state.set_binding(i, e);
            return true;
        }
        return equal(*state.get_binding(i), e);
    }

    // Return the bound node as an Expr; type_hint is ignored.
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        return state.get_binding(i);
    }

    // Extract the bound node's value for constant folding. Only meaningful
    // when the bound node is an immediate.
    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        const auto *node = state.get_binding(i);
        ty = node->type;

        const IRNodeType kind = node->node_type;
        if (kind == IRNodeType::UIntImm) {
            val.u.u64 = static_cast<const UIntImm *>(node)->value;
        } else if (kind == IRNodeType::IntImm) {
            val.u.i64 = static_cast<const IntImm *>(node)->value;
        } else if (kind == IRNodeType::FloatImm) {
            val.u.f64 = static_cast<const FloatImm *>(node)->value;
        } else {
            // The function is noexcept, so silent failure. You
            // shouldn't be calling this if you haven't already
            // checked it's going to be a constant (e.g. with
            // is_const, or because you manually bound a constant Expr
            // to the state).
            val.u.u64 = 0;
        }
    }
};
// Print a Wild as "_" followed by its slot index.
template<int i>
std::ostream &operator<<(std::ostream &s, const Wild<i> &op) {
    return s << "_" << i;
}
// Matches a specific constant or broadcast of that constant. The
// constant must be representable as an int64_t.
struct IntLiteral {
    struct pattern_tag {};

    // The literal value. It carries no type of its own; the type comes
    // from the surrounding pattern (type hints / the other operand).
    int64_t v;

    constexpr static uint32_t binds = 0;
    constexpr static IRNodeType min_node_type = IRNodeType::IntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::FloatImm;
    constexpr static bool canonical = true;
    constexpr static bool foldable = true;

    HALIDE_ALWAYS_INLINE
    explicit IntLiteral(int64_t v)
        : v(v) {
    }

    // Match an IR node: true when the node (or the value it broadcasts) is
    // an immediate equal to v converted to the immediate's value type.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        const BaseExprNode *scalar =
            (e.node_type == IRNodeType::Broadcast) ?
                static_cast<const Broadcast *>(&e)->value.get() :
                &e;

        const IRNodeType kind = scalar->node_type;
        if (kind == IRNodeType::IntImm) {
            return static_cast<const IntImm *>(scalar)->value == (int64_t)v;
        }
        if (kind == IRNodeType::UIntImm) {
            return static_cast<const UIntImm *>(scalar)->value == (uint64_t)v;
        }
        if (kind == IRNodeType::FloatImm) {
            return static_cast<const FloatImm *>(scalar)->value == (double)v;
        }
        return false;
    }

    // Match a raw integer by value.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept {
        return val == v;
    }

    // Match another literal pattern by value.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept {
        return b.v == v;
    }

    // Materialise the literal as a constant of the hinted type.
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        return make_const(type_hint, v);
    }

    // Store v into the union member selected by the incoming type code.
    // ty is read but never modified.
    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        // Assume type is already correct
        switch (ty.code) {
        case halide_type_float:
        case halide_type_bfloat:
            val.u.f64 = (double)v;
            break;
        case halide_type_uint:
            val.u.u64 = (uint64_t)v;
            break;
        case halide_type_int:
            val.u.i64 = v;
            break;
        default:
            // Unreachable
            break;
        }
    }
};
// unwrap() overload for integer literals: yields the raw int64_t value, so
// pattern-vs-pattern matching can use the match(int64_t, ...) overloads.
HALIDE_ALWAYS_INLINE int64_t unwrap(IntLiteral t) {
return t.v;
}
// Convert a provided pattern, expr, or constant int into the internal
// representation we use in the matcher trees.
// This overload handles arguments that are already patterns: the defaulted
// template parameter only resolves when the decayed T declares a nested
// pattern_tag, and the argument is returned unchanged.
template<typename T,
typename = typename std::decay<T>::type::pattern_tag>
HALIDE_ALWAYS_INLINE T pattern_arg(T t) {
return t;
}
// pattern_arg overload for plain integers: wraps the value in an IntLiteral.
HALIDE_ALWAYS_INLINE
IntLiteral pattern_arg(int64_t x) {
    return IntLiteral(x);
}
// Compile-time guard used by the pattern-building functions: if the decayed
// argument type is Expr, the argument must have been deduced as an lvalue
// reference. Has no runtime effect.
template<typename T>
HALIDE_ALWAYS_INLINE void assert_is_lvalue_if_expr() {
    using Bare = typename std::decay<T>::type;
    constexpr bool is_expr = std::is_same<Bare, Expr>::value;
    constexpr bool is_lvalue = std::is_lvalue_reference<T>::value;
    static_assert(!is_expr || is_lvalue,
                  "Exprs are captured by reference by IRMatcher objects and so must be lvalues");
}
// pattern_arg overload for Exprs: produces a SpecificExpr that refers to the
// Expr's underlying node. The Expr is dereferenced without a null check.
HALIDE_ALWAYS_INLINE SpecificExpr pattern_arg(const Expr &e) {
    const BaseExprNode &node = *e.get();
    return SpecificExpr{node};
}
// Helpers to deref SpecificExprs to const BaseExprNode & rather than
// passing them by value anywhere (incurring lots of refcounting)
// This generic overload is the identity: it returns any pattern other than
// SpecificExpr unchanged. The two defaulted template parameters remove it
// from overload resolution for non-patterns and for SpecificExpr.
template<typename T,
// T must be a pattern node
typename = typename std::decay<T>::type::pattern_tag,
// But T may not be SpecificExpr
typename = typename std::enable_if<!std::is_same<typename std::decay<T>::type, SpecificExpr>::value>::type>
HALIDE_ALWAYS_INLINE T unwrap(T t) {
return t;
}
// unwrap() overload for SpecificExpr: exposes the referenced IR node, so the
// match(const BaseExprNode &, ...) overloads are selected for it.
HALIDE_ALWAYS_INLINE
const BaseExprNode &unwrap(const SpecificExpr &e) {
return e.expr;
}
// Print an IntLiteral as its integer value.
inline std::ostream &operator<<(std::ostream &s, const IntLiteral &op) {
    return s << op.v;
}
// Primary templates for folding a binary operator Op over two constants.
// There is one overload per value representation (signed, unsigned,
// floating point); explicit specialisations per operator follow later in
// this header. The type is passed by non-const reference because some
// specialisations set the signed-overflow flag in its lanes field.
template<typename Op>
int64_t constant_fold_bin_op(halide_type_t &, int64_t, int64_t) noexcept;
template<typename Op>
uint64_t constant_fold_bin_op(halide_type_t &, uint64_t, uint64_t) noexcept;
template<typename Op>
double constant_fold_bin_op(halide_type_t &, double, double) noexcept;
// True for the node types this header treats as commutative binary
// operators. Written as a single return expression so it remains a valid
// C++11 constexpr function.
constexpr bool commutative(IRNodeType t) {
    return
        // arithmetic
        (t == IRNodeType::Add || t == IRNodeType::Mul) ||
        // logical
        (t == IRNodeType::And || t == IRNodeType::Or) ||
        // selection
        (t == IRNodeType::Min || t == IRNodeType::Max) ||
        // (in)equality
        (t == IRNodeType::EQ || t == IRNodeType::NE);
}
// Matches one of the binary operators
// Op is the IR node type (Add, Mul, ...); A and B are the sub-patterns for
// the two operands.
template<typename Op, typename A, typename B>
struct BinOp {
    struct pattern_tag {};
    A a;
    B b;

    // Everything bound by either operand.
    constexpr static uint32_t binds = bindings<A>::mask | bindings<B>::mask;

    constexpr static IRNodeType min_node_type = Op::_node_type;
    constexpr static IRNodeType max_node_type = Op::_node_type;

    // For commutative bin ops, we expect the weaker IR node type on
    // the right. That is, for the rule to be canonical it must be
    // possible that A is at least as strong as B.
    constexpr static bool canonical =
        A::canonical && B::canonical && (!commutative(Op::_node_type) || (A::max_node_type >= B::min_node_type));

    // Match an IR node: it must be an Op node whose operands match a and b.
    // a is matched first, so b is matched with a's bindings marked as bound.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        if (e.node_type != Op::_node_type) {
            return false;
        }
        const Op &op = (const Op &)e;
        return (a.template match<bound>(*op.a.get(), state) &&
                b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
    }

    // Match another BinOp pattern structurally: same operator, and matching
    // operands (SpecificExprs and IntLiterals are unwrapped first).
    template<uint32_t bound, typename Op2, typename A2, typename B2>
    HALIDE_ALWAYS_INLINE bool match(const BinOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
        return (std::is_same<Op, Op2>::value &&
                a.template match<bound>(unwrap(op.a), state) &&
                b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
    }

    constexpr static bool foldable = A::foldable && B::foldable;

    // Constant-fold the operation. An IntLiteral operand has no type of its
    // own, so when A is a literal, b is evaluated first to establish ty.
    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        halide_scalar_value_t val_a, val_b;
        if (std::is_same<A, IntLiteral>::value) {
            b.make_folded_const(val_b, ty, state);
            if ((std::is_same<Op, And>::value && val_b.u.u64 == 0) ||
                (std::is_same<Op, Or>::value && val_b.u.u64 == 1)) {
                // Short circuit
                val = val_b;
                return;
            }
            const uint16_t l = ty.lanes;
            a.make_folded_const(val_a, ty, state);
            ty.lanes |= l;  // Make sure the overflow bits are sticky
        } else {
            a.make_folded_const(val_a, ty, state);
            if ((std::is_same<Op, And>::value && val_a.u.u64 == 0) ||
                (std::is_same<Op, Or>::value && val_a.u.u64 == 1)) {
                // Short circuit
                val = val_a;
                return;
            }
            const uint16_t l = ty.lanes;
            b.make_folded_const(val_b, ty, state);
            ty.lanes |= l;  // Keep any flag bits set while folding a.
        }
        // Fold in the representation selected by the operand type.
        switch (ty.code) {
        case halide_type_int:
            val.u.i64 = constant_fold_bin_op<Op>(ty, val_a.u.i64, val_b.u.i64);
            break;
        case halide_type_uint:
            val.u.u64 = constant_fold_bin_op<Op>(ty, val_a.u.u64, val_b.u.u64);
            break;
        case halide_type_float:
        case halide_type_bfloat:
            val.u.f64 = constant_fold_bin_op<Op>(ty, val_a.u.f64, val_b.u.f64);
            break;
        default:
            // unreachable
            ;
        }
    }

    // Build the Expr for this pattern from the bindings in state.
    // Deliberately not noexcept (matching CmpOp::make): it creates IR nodes
    // through Broadcast::make and Op::make, and a noexcept specifier here
    // would turn any exception escaping from them into std::terminate.
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        Expr ea, eb;
        // An IntLiteral takes its type from the other operand, so build the
        // non-literal side first.
        if (std::is_same<A, IntLiteral>::value) {
            eb = b.make(state, type_hint);
            ea = a.make(state, eb.type());
        } else {
            ea = a.make(state, type_hint);
            eb = b.make(state, ea.type());
        }
        // We sometimes mix vectors and scalars in the rewrite rules,
        // so insert a broadcast if necessary.
        if (ea.type().is_vector() && !eb.type().is_vector()) {
            eb = Broadcast::make(eb, ea.type().lanes());
        }
        if (eb.type().is_vector() && !ea.type().is_vector()) {
            ea = Broadcast::make(ea, eb.type().lanes());
        }
        return Op::make(std::move(ea), std::move(eb));
    }
};
// Primary templates for folding a comparison operator Op over two
// constants, one overload per value representation. Only declared here.
// The result is returned as a uint64_t, which CmpOp::make_folded_const
// stores as the value of a 1-bit unsigned constant.
template<typename Op>
uint64_t constant_fold_cmp_op(int64_t, int64_t) noexcept;
template<typename Op>
uint64_t constant_fold_cmp_op(uint64_t, uint64_t) noexcept;
template<typename Op>
uint64_t constant_fold_cmp_op(double, double) noexcept;
// Matches one of the comparison operators
// Op is the IR comparison node type (LT, EQ, ...); A and B are the
// sub-patterns for the two operands.
template<typename Op, typename A, typename B>
struct CmpOp {
struct pattern_tag {};
A a;
B b;
// Everything bound by either operand.
constexpr static uint32_t binds = bindings<A>::mask | bindings<B>::mask;
constexpr static IRNodeType min_node_type = Op::_node_type;
constexpr static IRNodeType max_node_type = Op::_node_type;
// Same operand-ordering rule as BinOp for commutative operators; in
// addition, GE and GT patterns are never considered canonical.
constexpr static bool canonical = (A::canonical &&
B::canonical &&
(!commutative(Op::_node_type) || A::max_node_type >= B::min_node_type) &&
(Op::_node_type != IRNodeType::GE) &&
(Op::_node_type != IRNodeType::GT));
// Match an IR node: it must be an Op node whose operands match a and b.
// a is matched first, so b is matched with a's bindings marked as bound.
template<uint32_t bound>
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
if (e.node_type != Op::_node_type) {
return false;
}
const Op &op = (const Op &)e;
return (a.template match<bound>(*op.a.get(), state) &&
b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
}
// Match another CmpOp pattern structurally: same operator and matching
// operands (operands are passed through unwrap() first).
template<uint32_t bound, typename Op2, typename A2, typename B2>
HALIDE_ALWAYS_INLINE bool match(const CmpOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
return (std::is_same<Op, Op2>::value &&
a.template match<bound>(unwrap(op.a), state) &&
b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
}
constexpr static bool foldable = A::foldable && B::foldable;
// Constant-fold the comparison. The operands are folded in their own
// type; afterwards ty is rewritten to a 1-bit unsigned type (its lanes
// field, including any flag bits, is left as accumulated).
HALIDE_ALWAYS_INLINE
void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
halide_scalar_value_t val_a, val_b;
// If one side is an untyped const, evaluate the other side first to get a type hint.
if (std::is_same<A, IntLiteral>::value) {
b.make_folded_const(val_b, ty, state);
// Remember lanes (and flag bits) from b before a overwrites ty.
const uint16_t l = ty.lanes;
a.make_folded_const(val_a, ty, state);
ty.lanes |= l;
} else {
a.make_folded_const(val_a, ty, state);
// Remember lanes (and flag bits) from a before b overwrites ty.
const uint16_t l = ty.lanes;
b.make_folded_const(val_b, ty, state);
ty.lanes |= l;
}
// Compare in the representation selected by the operand type.
switch (ty.code) {
case halide_type_int:
val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.i64, val_b.u.i64);
break;
case halide_type_uint:
val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.u64, val_b.u.u64);
break;
case halide_type_float:
case halide_type_bfloat:
val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.f64, val_b.u.f64);
break;
default:
// unreachable
;
}
// The result of a comparison is a 1-bit unsigned value.
ty.code = halide_type_uint;
ty.bits = 1;
}
// Build the comparison Expr from the bindings in state. The incoming
// type_hint is not used for the operands: the first operand built gets
// an empty hint ({}), and its type becomes the hint for the other.
HALIDE_ALWAYS_INLINE
Expr make(MatcherState &state, halide_type_t type_hint) const {
// If one side is an untyped const, evaluate the other side first to get a type hint.
Expr ea, eb;
if (std::is_same<A, IntLiteral>::value) {
eb = b.make(state, {});
ea = a.make(state, eb.type());
} else {
ea = a.make(state, {});
eb = b.make(state, ea.type());
}
// We sometimes mix vectors and scalars in the rewrite rules,
// so insert a broadcast if necessary.
if (ea.type().is_vector() && !eb.type().is_vector()) {
eb = Broadcast::make(eb, ea.type().lanes());
}
if (eb.type().is_vector() && !ea.type().is_vector()) {
ea = Broadcast::make(ea, eb.type().lanes());
}
return Op::make(std::move(ea), std::move(eb));
}
};
template<typename A, typename B>
std::ostream &operator<<(std::ostream &s, const BinOp<Add, A, B> &op) {
s << "(" << op.a << " + " << op.b << ")";
return s;
}
template<typename A, typename B>
std::ostream &operator<<(std::ostream &s, const BinOp<Sub, A, B> &op) {
s << "(" << op.a << " - " << op.b << ")";
return s;
}
template<typename A, typename B>
std::ostream &operator<<(std::ostream &s, const BinOp<Mul, A, B> &op) {
s << "(" << op.a << " * " << op.b << ")";
return s;
}
template<typename A, typename B>
std::ostream &operator<<(std::ostream &s, const BinOp<Div, A, B> &op) {
s << "(" << op.a << " / " << op.b << ")";
return s;
}
template<typename A, typename B>
std::ostream &operator<<(std::ostream &s, const BinOp<And, A, B> &op) {
s << "(" << op.a << " && " << op.b << ")";
return s;
}
template<typename A, typename B>
std::ostream &operator<<(std::ostream &s, const BinOp<Or, A, B> &op) {
s << "(" << op.a << " || " << op.b << ")";
return s;
}
template<typename A, typename B>
std::ostream &operator<<(std::ostream &s, const BinOp<Min, A, B> &op) {
s << "min(" << op.a << ", " << op.b << ")";
return s;
}
template<typename A, typename B>
std::ostream &operator<<(std::ostream &s, const BinOp<Max, A, B> &op) {
s << "max(" << op.a << ", " << op.b << ")";
return s;
}
template<typename A, typename B>
std::ostream &operator<<(std::ostream &s, const CmpOp<LE, A, B> &op) {
s << "(" << op.a << " <= " << op.b << ")";
return s;
}
template<typename A, typename B>
std::ostream &operator<<(std::ostream &s, const CmpOp<LT, A, B> &op) {
s << "(" << op.a << " < " << op.b << ")";
return s;
}
template<typename A, typename B>
std::ostream &operator<<(std::ostream &s, const CmpOp<GE, A, B> &op) {
s << "(" << op.a << " >= " << op.b << ")";
return s;
}
template<typename A, typename B>
std::ostream &operator<<(std::ostream &s, const CmpOp<GT, A, B> &op) {
s << "(" << op.a << " > " << op.b << ")";
return s;
}
template<typename A, typename B>
std::ostream &operator<<(std::ostream &s, const CmpOp<EQ, A, B> &op) {
s << "(" << op.a << " == " << op.b << ")";
return s;
}
template<typename A, typename B>
std::ostream &operator<<(std::ostream &s, const CmpOp<NE, A, B> &op) {
s << "(" << op.a << " != " << op.b << ")";
return s;
}
template<typename A, typename B>
std::ostream &operator<<(std::ostream &s, const BinOp<Mod, A, B> &op) {
s << "(" << op.a << " % " << op.b << ")";
return s;
}
// Build an addition pattern. Each operand is converted with pattern_arg, so
// it may be an existing pattern, an Expr (which must be an lvalue, checked
// at compile time), or an integer.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp<Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
assert_is_lvalue_if_expr<A>();
assert_is_lvalue_if_expr<B>();
return {pattern_arg(a), pattern_arg(b)};
}
// Named-function spelling of the addition pattern. The lvalue checks are
// repeated here because a and b are passed on as lvalues, so the checks
// inside operator+ cannot see whether the caller supplied an rvalue Expr.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b)) {
assert_is_lvalue_if_expr<A>();
assert_is_lvalue_if_expr<B>();
return IRMatcher::operator+(a, b);
}
// Signed addition. For types of 32 bits or more, an overflowing sum sets the
// signed_integer_overflow flag in t.lanes. The returned value is the sum
// wrapped to t.bits bits and sign-extended.
template<>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op<Add>(halide_type_t &t, int64_t a, int64_t b) noexcept {
    if ((t.bits >= 32) && add_would_overflow(t.bits, a, b)) {
        t.lanes |= MatcherState::signed_integer_overflow;
    }
    const int dead_bits = 64 - t.bits;
    // Add in unsigned arithmetic, then drop the high bits and
    // sign-extend them back.
    const uint64_t wrapped = uint64_t(a) + uint64_t(b);
    return int64_t(wrapped << dead_bits) >> dead_bits;
}

// Unsigned addition, truncated to the low t.bits bits.
template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op<Add>(halide_type_t &t, uint64_t a, uint64_t b) noexcept {
    const uint64_t all_ones = (uint64_t)(-1);
    const uint64_t low_bits_mask = all_ones >> (64 - t.bits);
    return (a + b) & low_bits_mask;
}

// Floating-point addition; t is not used.
template<>
HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Add>(halide_type_t &t, double a, double b) noexcept {
    const double sum = a + b;
    return sum;
}
// Build a subtraction pattern. Each operand is converted with pattern_arg,
// so it may be an existing pattern, an Expr (which must be an lvalue,
// checked at compile time), or an integer.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp<Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
assert_is_lvalue_if_expr<A>();
assert_is_lvalue_if_expr<B>();
return {pattern_arg(a), pattern_arg(b)};
}
// Named-function spelling of the subtraction pattern. The lvalue checks are
// repeated here because a and b are passed on as lvalues, so the checks
// inside operator- cannot see whether the caller supplied an rvalue Expr.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b)) {
assert_is_lvalue_if_expr<A>();
assert_is_lvalue_if_expr<B>();
return IRMatcher::operator-(a, b);
}
// Signed subtraction. For types of 32 bits or more, an overflowing
// difference sets the signed_integer_overflow flag in t.lanes. The returned
// value is the difference wrapped to t.bits bits and sign-extended.
template<>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op<Sub>(halide_type_t &t, int64_t a, int64_t b) noexcept {
    if ((t.bits >= 32) && sub_would_overflow(t.bits, a, b)) {
        t.lanes |= MatcherState::signed_integer_overflow;
    }
    const int dead_bits = 64 - t.bits;
    // Subtract in unsigned arithmetic, then drop the high bits and
    // sign-extend them back.
    const uint64_t wrapped = uint64_t(a) - uint64_t(b);
    return int64_t(wrapped << dead_bits) >> dead_bits;
}

// Unsigned subtraction, truncated to the low t.bits bits.
template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op<Sub>(halide_type_t &t, uint64_t a, uint64_t b) noexcept {
    const uint64_t all_ones = (uint64_t)(-1);
    const uint64_t low_bits_mask = all_ones >> (64 - t.bits);
    return (a - b) & low_bits_mask;
}

// Floating-point subtraction; t is not used.
template<>
HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Sub>(halide_type_t &t, double a, double b) noexcept {
    const double difference = a - b;
    return difference;
}
// Build a multiplication pattern. Each operand is converted with
// pattern_arg, so it may be an existing pattern, an Expr (which must be an
// lvalue, checked at compile time), or an integer.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp<Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
assert_is_lvalue_if_expr<A>();
assert_is_lvalue_if_expr<B>();
return {pattern_arg(a), pattern_arg(b)};
}
// Named-function spelling of the multiplication pattern. The lvalue checks
// are repeated here because a and b are passed on as lvalues, so the checks
// inside operator* cannot see whether the caller supplied an rvalue Expr.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b)) {
assert_is_lvalue_if_expr<A>();
assert_is_lvalue_if_expr<B>();
return IRMatcher::operator*(a, b);
}
// Signed multiplication. For types of 32 bits or more, an overflowing
// product sets the signed_integer_overflow flag in t.lanes. The returned
// value is the product wrapped to t.bits bits and sign-extended.
template<>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op<Mul>(halide_type_t &t, int64_t a, int64_t b) noexcept {
    if ((t.bits >= 32) && mul_would_overflow(t.bits, a, b)) {
        t.lanes |= MatcherState::signed_integer_overflow;
    }
    const int dead_bits = 64 - t.bits;
    // Multiply in unsigned arithmetic, then drop the high bits and
    // sign-extend them back.
    const uint64_t wrapped = uint64_t(a) * uint64_t(b);
    return int64_t(wrapped << dead_bits) >> dead_bits;
}

// Unsigned multiplication, truncated to the low t.bits bits.
template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op<Mul>(halide_type_t &t, uint64_t a, uint64_t b) noexcept {
    const uint64_t all_ones = (uint64_t)(-1);
    const uint64_t low_bits_mask = all_ones >> (64 - t.bits);
    return (a * b) & low_bits_mask;
}

// Floating-point multiplication; t is not used.
template<>
HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mul>(halide_type_t &t, double a, double b) noexcept {
    const double product = a * b;
    return product;
}
// Build a division pattern. Each operand is converted with pattern_arg, so
// it may be an existing pattern, an Expr (which must be an lvalue, checked
// at compile time), or an integer.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp<Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
assert_is_lvalue_if_expr<A>();
assert_is_lvalue_if_expr<B>();
return {pattern_arg(a), pattern_arg(b)};
}
// Named-function spelling of the division pattern.
// The lvalue checks must be made here, as in add/sub/mul/mod: a and b are
// passed on to operator/ as lvalues, so the checks inside operator/ always
// succeed and could never reject an rvalue Expr handed to div().
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b)) {
    assert_is_lvalue_if_expr<A>();
    assert_is_lvalue_if_expr<B>();
    return IRMatcher::operator/(a, b);
}
// Constant folding for Div. All three representations delegate to div_imp
// (defined elsewhere); t is not read or modified, so no overflow flag is
// set and the result is not truncated to t.bits here.
template<>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op<Div>(halide_type_t &t, int64_t a, int64_t b) noexcept {
return div_imp(a, b);
}
template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op<Div>(halide_type_t &t, uint64_t a, uint64_t b) noexcept {
return div_imp(a, b);
}
template<>
HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Div>(halide_type_t &t, double a, double b) noexcept {
return div_imp(a, b);
}
// Build a modulo pattern. Each operand is converted with pattern_arg, so it
// may be an existing pattern, an Expr (which must be an lvalue, checked at
// compile time), or an integer.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp<Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
assert_is_lvalue_if_expr<A>();
assert_is_lvalue_if_expr<B>();
return {pattern_arg(a), pattern_arg(b)};
}
// Named-function spelling of the modulo pattern. The lvalue checks are
// repeated here because a and b are passed on as lvalues, so the checks
// inside operator% cannot see whether the caller supplied an rvalue Expr.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b)) {
assert_is_lvalue_if_expr<A>();
assert_is_lvalue_if_expr<B>();
return IRMatcher::operator%(a, b);
}
// Constant-fold a modulo. All three value representations defer to mod_imp,
// which is defined elsewhere.
template<>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op<Mod>(halide_type_t &t, int64_t a, int64_t b) noexcept {
    const int64_t remainder = mod_imp(a, b);
    return remainder;
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op<Mod>(halide_type_t &t, uint64_t a, uint64_t b) noexcept {
    const uint64_t remainder = mod_imp(a, b);
    return remainder;
}

template<>
HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mod>(halide_type_t &t, double a, double b) noexcept {
    const double remainder = mod_imp(a, b);
    return remainder;
}
// Build a BinOp<Min, ...> pattern for min(a, b). Each operand is wrapped by
// pattern_arg after assert_is_lvalue_if_expr has been run on its type.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp<Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    using Pattern = BinOp<Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))>;
    assert_is_lvalue_if_expr<A>();
    assert_is_lvalue_if_expr<B>();
    return Pattern{pattern_arg(a), pattern_arg(b)};
}
// Constant-fold a min: the smaller operand is returned and the type is not
// touched. One specialization per value representation.
template<>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op<Min>(halide_type_t &t, int64_t a, int64_t b) noexcept {
    return std::min<int64_t>(a, b);
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op<Min>(halide_type_t &t, uint64_t a, uint64_t b) noexcept {
    return std::min<uint64_t>(a, b);
}

template<>
HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Min>(halide_type_t &t, double a, double b) noexcept {
    return std::min<double>(a, b);
}
// Build a BinOp<Max, ...> pattern for max(a, b). The operands are forwarded
// into pattern_arg after assert_is_lvalue_if_expr has been run on their types.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp<Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    using Pattern = BinOp<Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))>;
    assert_is_lvalue_if_expr<A>();
    assert_is_lvalue_if_expr<B>();
    return Pattern{pattern_arg(std::forward<A>(a)), pattern_arg(std::forward<B>(b))};
}
// Constant-fold a max: the larger operand is returned and the type is not
// touched. One specialization per value representation.
template<>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op<Max>(halide_type_t &t, int64_t a, int64_t b) noexcept {
    return std::max<int64_t>(a, b);
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op<Max>(halide_type_t &t, uint64_t a, uint64_t b) noexcept {
    return std::max<uint64_t>(a, b);
}

template<>
HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Max>(halide_type_t &t, double a, double b) noexcept {
    return std::max<double>(a, b);
}
// Build a CmpOp<LT, ...> pattern for `a < b`; both operands are wrapped by
// pattern_arg.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp<LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    using Pattern = CmpOp<LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>;
    return Pattern{pattern_arg(a), pattern_arg(b)};
}
// Named spelling of IRMatcher::operator<.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b)) {
    auto pattern = IRMatcher::operator<(a, b);
    return pattern;
}
// Constant-fold `a < b`: yields 1 when the comparison holds, 0 otherwise.
template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<LT>(int64_t a, int64_t b) noexcept {
    return static_cast<uint64_t>(a < b);
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<LT>(uint64_t a, uint64_t b) noexcept {
    return static_cast<uint64_t>(a < b);
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<LT>(double a, double b) noexcept {
    return static_cast<uint64_t>(a < b);
}
// Build a CmpOp<GT, ...> pattern for `a > b`; both operands are wrapped by
// pattern_arg.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp<GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    using Pattern = CmpOp<GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>;
    return Pattern{pattern_arg(a), pattern_arg(b)};
}
// Named spelling of IRMatcher::operator>.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b)) {
    auto pattern = IRMatcher::operator>(a, b);
    return pattern;
}
// Constant-fold `a > b`: yields 1 when the comparison holds, 0 otherwise.
template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<GT>(int64_t a, int64_t b) noexcept {
    return static_cast<uint64_t>(a > b);
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<GT>(uint64_t a, uint64_t b) noexcept {
    return static_cast<uint64_t>(a > b);
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<GT>(double a, double b) noexcept {
    return static_cast<uint64_t>(a > b);
}
// Build a CmpOp<LE, ...> pattern for `a <= b`; both operands are wrapped by
// pattern_arg.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp<LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    using Pattern = CmpOp<LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>;
    return Pattern{pattern_arg(a), pattern_arg(b)};
}
// Named spelling of IRMatcher::operator<=.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b)) {
    auto pattern = IRMatcher::operator<=(a, b);
    return pattern;
}
// Constant-fold `a <= b`: yields 1 when the comparison holds, 0 otherwise.
template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<LE>(int64_t a, int64_t b) noexcept {
    return static_cast<uint64_t>(a <= b);
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<LE>(uint64_t a, uint64_t b) noexcept {
    return static_cast<uint64_t>(a <= b);
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<LE>(double a, double b) noexcept {
    return static_cast<uint64_t>(a <= b);
}
// Build a CmpOp<GE, ...> pattern for `a >= b`; both operands are wrapped by
// pattern_arg.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp<GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    using Pattern = CmpOp<GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>;
    return Pattern{pattern_arg(a), pattern_arg(b)};
}
// Named spelling of IRMatcher::operator>=.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b)) {
    auto pattern = IRMatcher::operator>=(a, b);
    return pattern;
}
// Constant-fold `a >= b`: yields 1 when the comparison holds, 0 otherwise.
template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<GE>(int64_t a, int64_t b) noexcept {
    return static_cast<uint64_t>(a >= b);
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<GE>(uint64_t a, uint64_t b) noexcept {
    return static_cast<uint64_t>(a >= b);
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<GE>(double a, double b) noexcept {
    return static_cast<uint64_t>(a >= b);
}
// Build a CmpOp<EQ, ...> pattern for `a == b`; both operands are wrapped by
// pattern_arg.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp<EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    using Pattern = CmpOp<EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))>;
    return Pattern{pattern_arg(a), pattern_arg(b)};
}
// Named spelling of IRMatcher::operator==.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b)) {
    auto pattern = IRMatcher::operator==(a, b);
    return pattern;
}
// Constant-fold `a == b`: yields 1 when the operands compare equal, 0 otherwise.
template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<EQ>(int64_t a, int64_t b) noexcept {
    return static_cast<uint64_t>(a == b);
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<EQ>(uint64_t a, uint64_t b) noexcept {
    return static_cast<uint64_t>(a == b);
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<EQ>(double a, double b) noexcept {
    return static_cast<uint64_t>(a == b);
}
// Build a CmpOp<NE, ...> pattern for `a != b`; both operands are wrapped by
// pattern_arg.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp<NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    using Pattern = CmpOp<NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>;
    return Pattern{pattern_arg(a), pattern_arg(b)};
}
// Named spelling of IRMatcher::operator!=.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b)) {
    auto pattern = IRMatcher::operator!=(a, b);
    return pattern;
}
// Constant-fold `a != b`: yields 1 when the operands differ, 0 otherwise.
template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<NE>(int64_t a, int64_t b) noexcept {
    return static_cast<uint64_t>(a != b);
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<NE>(uint64_t a, uint64_t b) noexcept {
    return static_cast<uint64_t>(a != b);
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<NE>(double a, double b) noexcept {
    return static_cast<uint64_t>(a != b);
}
// Build a BinOp<Or, ...> pattern for `a || b`; both operands are wrapped by
// pattern_arg.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp<Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    using Pattern = BinOp<Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))>;
    return Pattern{pattern_arg(a), pattern_arg(b)};
}
// Named spelling of IRMatcher::operator||.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b)) {
    auto pattern = IRMatcher::operator||(a, b);
    return pattern;
}
// Constant-fold a logical or. Only bit 0 of each operand participates, so the
// result is always 0 or 1.
template<>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op<Or>(halide_type_t &t, int64_t a, int64_t b) noexcept {
    return (a & 1) | (b & 1);
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op<Or>(halide_type_t &t, uint64_t a, uint64_t b) noexcept {
    return (a & 1) | (b & 1);
}

template<>
HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Or>(halide_type_t &t, double a, double b) noexcept {
    // Never expected to run: a logical or of floating-point values would be
    // a type mismatch. A placeholder value is returned.
    return 0;
}
// Build a BinOp<And, ...> pattern for `a && b`; both operands are wrapped by
// pattern_arg.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp<And, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    using Pattern = BinOp<And, decltype(pattern_arg(a)), decltype(pattern_arg(b))>;
    return Pattern{pattern_arg(a), pattern_arg(b)};
}
// Named spelling of IRMatcher::operator&&.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b)) {
    auto pattern = IRMatcher::operator&&(a, b);
    return pattern;
}
// Constant-fold a logical and. Only bit 0 of each operand participates, so
// the result is always 0 or 1.
template<>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op<And>(halide_type_t &t, int64_t a, int64_t b) noexcept {
    return (a & 1) & (b & 1);
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op<And>(halide_type_t &t, uint64_t a, uint64_t b) noexcept {
    return (a & 1) & (b & 1);
}

template<>
HALIDE_ALWAYS_INLINE double constant_fold_bin_op<And>(halide_type_t &t, double a, double b) noexcept {
    // Never expected to run; a placeholder value is returned.
    return 0;
}
// Compile-time bitwise OR of any number of 32-bit masks.
// Base case: OR-ing nothing yields the identity, zero.
constexpr inline uint32_t bitwise_or_reduce() {
    return uint32_t(0);
}

// Recursive case: fold the remaining masks, then OR in the first one.
template<typename... Args>
constexpr uint32_t bitwise_or_reduce(uint32_t first, Args... rest) {
    return bitwise_or_reduce(rest...) | first;
}
// Compile-time logical AND of any number of booleans.
// Base case: an empty conjunction is true.
constexpr inline bool and_reduce() {
    return true;
}

// Recursive case: false as soon as the first value is false, otherwise the
// conjunction of the remaining values.
template<typename... Args>
constexpr bool and_reduce(bool first, Args... rest) {
    return first ? and_reduce(rest...) : false;
}
// constexpr minimum of two ints.
// TODO: switch to std::min() once C++14 or later is required.
constexpr int const_min(int a, int b) {
    return b < a ? b : a;
}
// Pattern for a Call node that invokes one specific intrinsic. `intrin` names
// the intrinsic and `args` holds one sub-pattern per call argument.
template<typename... Args>
struct Intrin {
struct pattern_tag {};
Call::IntrinsicOp intrin;
std::tuple<Args...> args;
// Union of the binding masks of every argument pattern.
static constexpr uint32_t binds = bitwise_or_reduce((bindings<Args>::mask)...);
constexpr static IRNodeType min_node_type = IRNodeType::Call;
constexpr static IRNodeType max_node_type = IRNodeType::Call;
// Canonical only if every argument pattern is canonical.
constexpr static bool canonical = and_reduce((Args::canonical)...);
// Recursive step: match argument i, then continue with argument i + 1, adding
// argument i's bindings to the already-bound mask. This overload only exists
// while i < sizeof...(Args). Callers pass the literal 0, so when this (int)
// overload is viable it is preferred over the (double) terminator below.
template<int i,
uint32_t bound,
typename = typename std::enable_if<(i < sizeof...(Args))>::type>
HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept {
using T = decltype(std::get<i>(args));
return (std::get<i>(args).template match<bound>(*c.args[i].get(), state) &&
match_args<i + 1, bound | bindings<T>::mask>(0, c, state));
}
// Terminator: selected once the enable_if above removes the (int) overload,
// i.e. after every argument has been matched.
template<int i, uint32_t binds>
HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept {
return true;
}
// Matches when `e` is a Call to `intrin` and every argument matches its
// sub-pattern, checked left to right with short-circuiting.
template<uint32_t bound>
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
if (e.node_type != IRNodeType::Call) {
return false;
}
const Call &c = (const Call &)e;
return (c.is_intrinsic(intrin) && match_args<0, bound>(0, c, state));
}
// Recursive step for printing: stream argument i, add a ", " separator unless
// it is the last argument, then continue with argument i + 1. Uses the same
// int/double dummy-argument dispatch as match_args.
template<int i,
typename = typename std::enable_if<(i < sizeof...(Args))>::type>
HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const {
s << std::get<i>(args);
if (i + 1 < sizeof...(Args)) {
s << ", ";
}
print_args<i + 1>(0, s);
}
// Terminator for printing: nothing left to print.
template<int i>
HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const {
}
// Streams all argument patterns to `s`, comma separated.
HALIDE_ALWAYS_INLINE
void print_args(std::ostream &s) const {
print_args<0>(0, s);
}
// Builds the Expr for this intrinsic from the current matcher state.
// Arguments are made lazily in groups: the one-argument intrinsics return
// before a second argument is made, and the two-argument ones return before a
// third is made. An intrinsic not listed here raises an internal error.
HALIDE_ALWAYS_INLINE
Expr make(MatcherState &state, halide_type_t type_hint) const {
Expr arg0 = std::get<0>(args).make(state, type_hint);
if (intrin == Call::likely) {
return likely(arg0);
} else if (intrin == Call::likely_if_innermost) {
return likely_if_innermost(arg0);
} else if (intrin == Call::abs) {
return abs(arg0);
}
// The tuple index is clamped to the last valid argument so that std::get
// still compiles for patterns with fewer arguments.
Expr arg1 = std::get<const_min(1, sizeof...(Args) - 1)>(args).make(state, type_hint);
if (intrin == Call::absd) {
return absd(arg0, arg1);
} else if (intrin == Call::widening_add) {
return widening_add(arg0, arg1);
} else if (intrin == Call::widening_sub) {
return widening_sub(arg0, arg1);
} else if (intrin == Call::widening_mul) {
return widening_mul(arg0, arg1);
} else if (intrin == Call::saturating_add) {
return saturating_add(arg0, arg1);
} else if (intrin == Call::saturating_sub) {
return saturating_sub(arg0, arg1);
} else if (intrin == Call::halving_add) {
return halving_add(arg0, arg1);
} else if (intrin == Call::halving_sub) {
return halving_sub(arg0, arg1);
} else if (intrin == Call::rounding_halving_add) {
return rounding_halving_add(arg0, arg1);
} else if (intrin == Call::rounding_halving_sub) {
return rounding_halving_sub(arg0, arg1);
} else if (intrin == Call::shift_left) {
return arg0 << arg1;
} else if (intrin == Call::shift_right) {
return arg0 >> arg1;
} else if (intrin == Call::rounding_shift_left) {
return rounding_shift_left(arg0, arg1);
} else if (intrin == Call::rounding_shift_right) {
return rounding_shift_right(arg0, arg1);
}
// Same index clamping as for arg1.
Expr arg2 = std::get<const_min(2, sizeof...(Args) - 1)>(args).make(state, type_hint);
if (intrin == Call::mul_shift_right) {
return mul_shift_right(arg0, arg1, arg2);
} else if (intrin == Call::rounding_mul_shift_right) {
return rounding_mul_shift_right(arg0, arg1, arg2);
}
internal_error << "Unhandled intrinsic in IRMatcher: " << intrin;
return Expr();
}
// Intrinsic calls are never constant-folded by the matcher.
constexpr static bool foldable = false;
// Stores the intrinsic id and packs the argument patterns into the tuple.
HALIDE_ALWAYS_INLINE
Intrin(Call::IntrinsicOp intrin, Args... args) noexcept
: intrin(intrin), args(args...) {
}
};
// Print an intrinsic pattern as `<intrinsic>(<arg>, <arg>, ...)`.
template<typename... Args>
std::ostream &operator<<(std::ostream &s, const Intrin<Args...> &op) {
    s << op.intrin << "(";
    op.print_args(s);
    return s << ")";
}
// Build an Intrin pattern for an arbitrary intrinsic; every argument is
// wrapped by pattern_arg.
template<typename... Args>
HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin<decltype(pattern_arg(args))...> {
    using Pattern = Intrin<decltype(pattern_arg(args))...>;
    return Pattern{intrinsic_op, pattern_arg(args)...};
}
// Pattern for the Call::widening_add intrinsic applied to (a, b).
template<typename A, typename B>
auto widening_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    using Pattern = Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))>;
    return Pattern{Call::widening_add, pattern_arg(a), pattern_arg(b)};
}
// Pattern for the Call::widening_sub intrinsic applied to (a, b).
template<typename A, typename B>
auto widening_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    using Pattern = Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))>;
    return Pattern{Call::widening_sub, pattern_arg(a), pattern_arg(b)};
}
// Pattern for the Call::widening_mul intrinsic applied to (a, b).
template<typename A, typename B>
auto widening_mul(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    using Pattern = Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))>;
    return Pattern{Call::widening_mul, pattern_arg(a), pattern_arg(b)};
}
// Pattern for the Call::saturating_add intrinsic applied to (a, b).
template<typename A, typename B>
auto saturating_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    using Pattern = Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))>;
    return Pattern{Call::saturating_add, pattern_arg(a), pattern_arg(b)};
}
// Pattern for the Call::saturating_sub intrinsic applied to (a, b).
template<typename A, typename B>
auto saturating_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    using Pattern = Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))>;
    return Pattern{Call::saturating_sub, pattern_arg(a), pattern_arg(b)};
}
// Pattern for the Call::halving_add intrinsic applied to (a, b).
template<typename A, typename B>
auto halving_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    using Pattern = Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))>;
    return Pattern{Call::halving_add, pattern_arg(a), pattern_arg(b)};
}
// Pattern for the Call::halving_sub intrinsic applied to (a, b).
template<typename A, typename B>
auto halving_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    using Pattern = Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))>;
    return Pattern{Call::halving_sub, pattern_arg(a), pattern_arg(b)};
}
// Pattern for the Call::rounding_halving_add intrinsic applied to (a, b).
template<typename A, typename B>
auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    using Pattern = Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))>;
    return Pattern{Call::rounding_halving_add, pattern_arg(a), pattern_arg(b)};
}
// Pattern for the Call::rounding_halving_sub intrinsic applied to (a, b).
template<typename A, typename B>
auto rounding_halving_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    using Pattern = Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))>;
    return Pattern{Call::rounding_halving_sub, pattern_arg(a), pattern_arg(b)};
}
// Pattern for the Call::shift_left intrinsic applied to (a, b).
template<typename A, typename B>
auto shift_left(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    using Pattern = Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))>;
    return Pattern{Call::shift_left, pattern_arg(a), pattern_arg(b)};
}
// Pattern for the Call::shift_right intrinsic applied to (a, b).
template<typename A, typename B>
auto shift_right(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    using Pattern = Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))>;
    return Pattern{Call::shift_right, pattern_arg(a), pattern_arg(b)};
}
// Pattern for the Call::rounding_shift_left intrinsic applied to (a, b).
template<typename A, typename B>
auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    using Pattern = Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))>;
    return Pattern{Call::rounding_shift_left, pattern_arg(a), pattern_arg(b)};
}
// Pattern for the Call::rounding_shift_right intrinsic applied to (a, b).
template<typename A, typename B>
auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    using Pattern = Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))>;
    return Pattern{Call::rounding_shift_right, pattern_arg(a), pattern_arg(b)};
}
// Pattern for the Call::mul_shift_right intrinsic applied to (a, b, c).
template<typename A, typename B, typename C>
auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
    using Pattern = Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>;
    return Pattern{Call::mul_shift_right, pattern_arg(a), pattern_arg(b), pattern_arg(c)};
}
// Pattern for the Call::rounding_mul_shift_right intrinsic applied to (a, b, c).
template<typename A, typename B, typename C>
auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
    using Pattern = Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>;
    return Pattern{Call::rounding_mul_shift_right, pattern_arg(a), pattern_arg(b), pattern_arg(c)};
}
// Pattern for a logical negation (a Not IR node) of the sub-pattern `a`.
template<typename A>
struct NotOp {
    struct pattern_tag {};
    A a;

    constexpr static uint32_t binds = bindings<A>::mask;

    constexpr static IRNodeType min_node_type = IRNodeType::Not;
    constexpr static IRNodeType max_node_type = IRNodeType::Not;
    constexpr static bool canonical = A::canonical;

    // Succeeds when `e` is a Not node whose operand matches `a`.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        if (e.node_type == IRNodeType::Not) {
            const Not &negation = (const Not &)e;
            return a.template match<bound>(*negation.a.get(), state);
        }
        return false;
    }

    // Pattern-against-pattern form: compares the two operand patterns.
    template<uint32_t bound, typename A2>
    HALIDE_ALWAYS_INLINE bool match(const NotOp<A2> &op, MatcherState &state) const noexcept {
        return a.template match<bound>(unwrap(op.a), state);
    }

    // Builds a Not node around the operand's Expr.
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        return Not::make(a.make(state, type_hint));
    }

    constexpr static bool foldable = A::foldable;

    // Folds the operand, then inverts it. Only bit 0 of the inverted value
    // is kept, so the result is 0 or 1.
    template<typename A1 = A>
    HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        a.make_folded_const(val, ty, state);
        val.u.u64 = (~val.u.u64) & 1;
    }
};
// Build a NotOp pattern for `!a`. The operand is wrapped by pattern_arg after
// assert_is_lvalue_if_expr has been run on its type.
template<typename A>
HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp<decltype(pattern_arg(a))> {
    using Pattern = NotOp<decltype(pattern_arg(a))>;
    assert_is_lvalue_if_expr<A>();
    return Pattern{pattern_arg(a)};
}
// Named spelling of IRMatcher::operator!. The operand check is repeated here
// before forwarding to the operator.
template<typename A>
HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a)) {
    assert_is_lvalue_if_expr<A>();
    auto pattern = IRMatcher::operator!(a);
    return pattern;
}
// Print a negation pattern as `!(<operand>)`.
template<typename A>
inline std::ostream &operator<<(std::ostream &s, const NotOp<A> &op) {
    return s << "!(" << op.a << ")";
}
// Pattern for a Select IR node: condition `c`, true value `t`, false value `f`.
template<typename C, typename T, typename F>
struct SelectOp {
    struct pattern_tag {};
    C c;
    T t;
    F f;

    constexpr static uint32_t binds = bindings<C>::mask | bindings<T>::mask | bindings<F>::mask;

    constexpr static IRNodeType min_node_type = IRNodeType::Select;
    constexpr static IRNodeType max_node_type = IRNodeType::Select;

    constexpr static bool canonical = C::canonical && T::canonical && F::canonical;

    // Matches a concrete Select node. The three sub-patterns are tried in
    // order (condition, true value, false value); each later one is told
    // which wildcards the earlier ones have already bound.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        if (e.node_type == Select::_node_type) {
            const Select &sel = (const Select &)e;
            if (!c.template match<bound>(*sel.condition.get(), state)) {
                return false;
            }
            if (!t.template match<bound | bindings<C>::mask>(*sel.true_value.get(), state)) {
                return false;
            }
            return f.template match<bound | bindings<C>::mask | bindings<T>::mask>(*sel.false_value.get(), state);
        }
        return false;
    }

    // Pattern-against-pattern form with the same ordering and mask propagation.
    template<uint32_t bound, typename C2, typename T2, typename F2>
    HALIDE_ALWAYS_INLINE bool match(const SelectOp<C2, T2, F2> &instance, MatcherState &state) const noexcept {
        if (!c.template match<bound>(unwrap(instance.c), state)) {
            return false;
        }
        if (!t.template match<bound | bindings<C>::mask>(unwrap(instance.t), state)) {
            return false;
        }
        return f.template match<bound | bindings<C>::mask | bindings<T>::mask>(unwrap(instance.f), state);
    }

    // Builds the Select. The condition gets an empty type hint; both value
    // branches receive the caller's hint.
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        return Select::make(c.make(state, {}),
                            t.make(state, type_hint),
                            f.make(state, type_hint));
    }

    constexpr static bool foldable = C::foldable && T::foldable && F::foldable;

    // Folds the condition, then folds only the branch it selects (bit 0 of
    // the folded condition decides). Special-value flags carried in the
    // condition's lanes field are merged into the result type.
    template<typename C1 = C>
    HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        halide_scalar_value_t c_val;
        halide_type_t c_ty;
        c.make_folded_const(c_val, c_ty, state);
        const bool take_true_branch = (c_val.u.u64 & 1) == 1;
        if (take_true_branch) {
            t.make_folded_const(val, ty, state);
        } else {
            f.make_folded_const(val, ty, state);
        }
        ty.lanes |= c_ty.lanes & MatcherState::special_values_mask;
    }
};
// Print a select pattern as `select(<cond>, <true>, <false>)`.
template<typename C, typename T, typename F>
std::ostream &operator<<(std::ostream &s, const SelectOp<C, T, F> &op) {
    return s << "select(" << op.c << ", " << op.t << ", " << op.f << ")";
}
// Build a SelectOp pattern. Each operand is wrapped by pattern_arg after
// assert_is_lvalue_if_expr has been run on its type.
template<typename C, typename T, typename F>
HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp<decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))> {
    using Pattern = SelectOp<decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))>;
    assert_is_lvalue_if_expr<C>();
    assert_is_lvalue_if_expr<T>();
    assert_is_lvalue_if_expr<F>();
    return Pattern{pattern_arg(c), pattern_arg(t), pattern_arg(f)};
}
// Pattern for a Broadcast IR node: the value pattern `a` replicated `lanes`
// times, where `lanes` is itself a pattern.
template<typename A, typename B>
struct BroadcastOp {
    struct pattern_tag {};
    A a;
    B lanes;

    constexpr static uint32_t binds = bindings<A>::mask | bindings<B>::mask;

    constexpr static IRNodeType min_node_type = IRNodeType::Broadcast;
    constexpr static IRNodeType max_node_type = IRNodeType::Broadcast;

    constexpr static bool canonical = A::canonical && B::canonical;

    // Matches a concrete Broadcast node: first the broadcast value against
    // `a`, then the lane count against `lanes`.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        if (e.node_type == Broadcast::_node_type) {
            const Broadcast &op = (const Broadcast &)e;
            // `a` is matched first, so its bindings must be reported as
            // already bound when matching `lanes`. Otherwise a wildcard
            // shared by both would be rebound instead of compared. This
            // mirrors the pattern-vs-pattern overload below.
            if (a.template match<bound>(*op.value.get(), state) &&
                lanes.template match<bound | bindings<A>::mask>(op.lanes, state)) {
                return true;
            }
        }
        return false;
    }

    // Pattern-against-pattern form: compares value and lane-count patterns.
    template<uint32_t bound, typename A2, typename B2>
    HALIDE_ALWAYS_INLINE bool match(const BroadcastOp<A2, B2> &op, MatcherState &state) const noexcept {
        return (a.template match<bound>(unwrap(op.a), state) &&
                lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state));
    }

    // Builds the broadcast. The lane count is obtained by folding `lanes`.
    // The hint passed to the value has its lanes divided by that count. A
    // lane count of 1 returns the value itself instead of a Broadcast node.
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        halide_scalar_value_t lanes_val;
        halide_type_t ty;
        lanes.make_folded_const(lanes_val, ty, state);
        int32_t l = (int32_t)lanes_val.u.i64;
        type_hint.lanes /= l;
        Expr val = a.make(state, type_hint);
        if (l == 1) {
            return val;
        } else {
            return Broadcast::make(std::move(val), l);
        }
    }

    constexpr static bool foldable = false;

    // Folds the value and overwrites the lane count of the resulting type
    // with the folded `lanes`, preserving any special-value flag bits.
    template<typename A1 = A>
    HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        halide_scalar_value_t lanes_val;
        halide_type_t lanes_ty;
        lanes.make_folded_const(lanes_val, lanes_ty, state);
        uint16_t l = (uint16_t)lanes_val.u.i64;
        a.make_folded_const(val, ty, state);
        ty.lanes = l | (ty.lanes & MatcherState::special_values_mask);
    }
};
// Print a broadcast pattern as `broadcast(<value>, <lanes>)`.
template<typename A, typename B>
inline std::ostream &operator<<(std::ostream &s, const BroadcastOp<A, B> &op) {
    return s << "broadcast(" << op.a << ", " << op.lanes << ")";
}
// Build a BroadcastOp pattern. `a` is checked by assert_is_lvalue_if_expr;
// `lanes` is taken by value. Both are wrapped by pattern_arg.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes))> {
    using Pattern = BroadcastOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes))>;
    assert_is_lvalue_if_expr<A>();
    return Pattern{pattern_arg(a), pattern_arg(lanes)};
}
// Pattern for a Ramp IR node: base `a`, stride `b`, and a lane-count pattern
// `lanes`.
template<typename A, typename B, typename C>
struct RampOp {
    struct pattern_tag {};
    A a;
    B b;
    C lanes;

    constexpr static uint32_t binds = bindings<A>::mask | bindings<B>::mask | bindings<C>::mask;

    constexpr static IRNodeType min_node_type = IRNodeType::Ramp;
    constexpr static IRNodeType max_node_type = IRNodeType::Ramp;

    constexpr static bool canonical = A::canonical && B::canonical && C::canonical;

    // Matches a concrete Ramp node: base, then stride, then lane count. Each
    // later sub-pattern is told which wildcards are already bound.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        if (e.node_type == Ramp::_node_type) {
            const Ramp &ramp_node = (const Ramp &)e;
            return a.template match<bound>(*ramp_node.base.get(), state) &&
                   b.template match<bound | bindings<A>::mask>(*ramp_node.stride.get(), state) &&
                   lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(ramp_node.lanes, state);
        }
        return false;
    }

    // Pattern-against-pattern form with the same ordering and mask propagation.
    template<uint32_t bound, typename A2, typename B2, typename C2>
    HALIDE_ALWAYS_INLINE bool match(const RampOp<A2, B2, C2> &op, MatcherState &state) const noexcept {
        if (!a.template match<bound>(unwrap(op.a), state)) {
            return false;
        }
        if (!b.template match<bound | bindings<A>::mask>(unwrap(op.b), state)) {
            return false;
        }
        return lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(unwrap(op.lanes), state);
    }

    // Builds the ramp. The lane count comes from folding `lanes`; the hint
    // has its lanes divided by that count. The stride is made first and its
    // type is then used as the hint for the base.
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        halide_scalar_value_t folded_lanes;
        halide_type_t ty;
        lanes.make_folded_const(folded_lanes, ty, state);
        int32_t lane_count = (int32_t)folded_lanes.u.i64;
        type_hint.lanes /= lane_count;
        Expr stride = b.make(state, type_hint);
        Expr base = a.make(state, stride.type());
        return Ramp::make(base, stride, lane_count);
    }

    constexpr static bool foldable = false;
};
// Print a ramp pattern as `ramp(<base>, <stride>, <lanes>)`.
template<typename A, typename B, typename C>
std::ostream &operator<<(std::ostream &s, const RampOp<A, B, C> &op) {
    return s << "ramp(" << op.a << ", " << op.b << ", " << op.lanes << ")";
}
// Build a RampOp pattern. Each operand is wrapped by pattern_arg after
// assert_is_lvalue_if_expr has been run on its type.
template<typename A, typename B, typename C>
HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
    using Pattern = RampOp<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>;
    assert_is_lvalue_if_expr<A>();
    assert_is_lvalue_if_expr<B>();
    assert_is_lvalue_if_expr<C>();
    return Pattern{pattern_arg(a), pattern_arg(b), pattern_arg(c)};
}
// Pattern for a VectorReduce IR node with the fixed operator `reduce_op`.
// `a` is the reduced value and `lanes` is a pattern for the lane count of
// the result type.
template<typename A, typename B, VectorReduce::Operator reduce_op>
struct VectorReduceOp {
    struct pattern_tag {};
    A a;
    B lanes;

    constexpr static uint32_t binds = bindings<A>::mask;

    constexpr static IRNodeType min_node_type = IRNodeType::VectorReduce;
    constexpr static IRNodeType max_node_type = IRNodeType::VectorReduce;
    constexpr static bool canonical = A::canonical;

    // Matches a concrete VectorReduce node: the operator must be equal, then
    // the value must match `a`, then the node's result lane count must
    // match `lanes`.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        if (e.node_type != VectorReduce::_node_type) {
            return false;
        }
        const VectorReduce &reduction = (const VectorReduce &)e;
        return reduction.op == reduce_op &&
               a.template match<bound>(*reduction.value.get(), state) &&
               lanes.template match<bound | bindings<A>::mask>(reduction.type.lanes(), state);
    }

    // Pattern-against-pattern form: operators must agree, then the value and
    // lane-count patterns are compared.
    template<uint32_t bound, typename A2, typename B2, VectorReduce::Operator reduce_op_2>
    HALIDE_ALWAYS_INLINE bool match(const VectorReduceOp<A2, B2, reduce_op_2> &op, MatcherState &state) const noexcept {
        if (reduce_op != reduce_op_2) {
            return false;
        }
        return a.template match<bound>(unwrap(op.a), state) &&
               lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state);
    }

    // Builds the reduction. The lane count comes from folding `lanes`; the
    // value is made with the caller's type hint unchanged.
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        halide_scalar_value_t folded_lanes;
        halide_type_t ty;
        lanes.make_folded_const(folded_lanes, ty, state);
        int lane_count = (int)folded_lanes.u.i64;
        return VectorReduce::make(reduce_op, a.make(state, type_hint), lane_count);
    }

    constexpr static bool foldable = false;
};
// Print a reduction pattern as `vector_reduce(<op>, <value>, <lanes>)`.
template<typename A, typename B, VectorReduce::Operator reduce_op>
inline std::ostream &operator<<(std::ostream &s, const VectorReduceOp<A, B, reduce_op> &op) {
    return s << "vector_reduce(" << reduce_op << ", " << op.a << ", " << op.lanes << ")";
}
// Build a VectorReduceOp pattern using VectorReduce::Add.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add> {
    using Pattern = VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add>;
    assert_is_lvalue_if_expr<A>();
    return Pattern{pattern_arg(a), pattern_arg(lanes)};
}
// Build a VectorReduceOp pattern using VectorReduce::Min.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min> {
    using Pattern = VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min>;
    assert_is_lvalue_if_expr<A>();
    return Pattern{pattern_arg(a), pattern_arg(lanes)};
}
// Build a VectorReduceOp pattern using VectorReduce::Max.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max> {
    using Pattern = VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max>;
    assert_is_lvalue_if_expr<A>();
    return Pattern{pattern_arg(a), pattern_arg(lanes)};
}
// Build a VectorReduceOp pattern using VectorReduce::And.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And> {
    using Pattern = VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And>;
    assert_is_lvalue_if_expr<A>();
    return Pattern{pattern_arg(a), pattern_arg(lanes)};
}
// Build a VectorReduceOp pattern using VectorReduce::Or.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or> {
    using Pattern = VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or>;
    assert_is_lvalue_if_expr<A>();
    return Pattern{pattern_arg(a), pattern_arg(lanes)};
}
// Pattern for arithmetic negation, which appears in the IR as a Sub node
// whose left operand is a constant zero (`0 - a`).
template<typename A>
struct NegateOp {
    struct pattern_tag {};
    A a;

    constexpr static uint32_t binds = bindings<A>::mask;

    constexpr static IRNodeType min_node_type = IRNodeType::Sub;
    constexpr static IRNodeType max_node_type = IRNodeType::Sub;

    constexpr static bool canonical = A::canonical;

    // Matches a Sub node whose right operand matches `a` and whose left
    // operand is a constant zero. The right operand is matched first.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        if (e.node_type == Sub::_node_type) {
            const Sub &sub = (const Sub &)e;
            return a.template match<bound>(*sub.b.get(), state) && is_const_zero(sub.a);
        }
        return false;
    }

    // Pattern-against-pattern form: compares the negated operands.
    template<uint32_t bound, typename A2>
    HALIDE_ALWAYS_INLINE bool match(NegateOp<A2> &&p, MatcherState &state) const noexcept {
        return a.template match<bound>(unwrap(p.a), state);
    }

    // Builds `0 - a`, with the zero taking the operand's type.
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        Expr operand = a.make(state, type_hint);
        Expr zero = make_zero(operand.type());
        return Sub::make(std::move(zero), std::move(operand));
    }

    constexpr static bool foldable = A::foldable;

    // Folds the operand and negates it according to the folded type's code.
    template<typename A1 = A>
    HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        a.make_folded_const(val, ty, state);
        const int dead_bits = 64 - ty.bits;
        switch (ty.code) {
        case halide_type_int: {
            // For 32-bit and wider signed types, the most negative value
            // cannot be negated. It is the one non-zero value whose bits
            // below the sign bit are all zero.
            const bool negating_most_negative =
                ty.bits >= 32 && val.u.u64 && (val.u.u64 << (65 - ty.bits)) == 0;
            if (negating_most_negative) {
                ty.lanes |= MatcherState::signed_integer_overflow;
            } else {
                // Negate, discard the bits above ty.bits, then sign-extend
                // back to 64 bits.
                val.u.i64 = int64_t(uint64_t(-val.u.i64) << dead_bits) >> dead_bits;
            }
            break;
        }
        case halide_type_uint:
            // Negate modulo 2^bits: shift the high bits out, then back in
            // as zeros.
            val.u.u64 = ((-val.u.u64) << dead_bits) >> dead_bits;
            break;
        case halide_type_float:
        case halide_type_bfloat:
            val.u.f64 = -val.u.f64;
            break;
        default:
            // Not expected to be reached.
            ;
        }
    }
};
// Print a negation pattern as `-<operand>`.
template<typename A>
std::ostream &operator<<(std::ostream &s, const NegateOp<A> &op) {
    return s << "-" << op.a;
}
// Build a NegateOp pattern for unary `-a`. The operand is wrapped by
// pattern_arg after assert_is_lvalue_if_expr has been run on its type.
template<typename A>
HALIDE_ALWAYS_INLINE auto operator-(A &&a) noexcept -> NegateOp<decltype(pattern_arg(a))> {
    using Pattern = NegateOp<decltype(pattern_arg(a))>;
    assert_is_lvalue_if_expr<A>();
    return Pattern{pattern_arg(a)};
}
// Named spelling of the unary IRMatcher::operator-. The operand check is
// repeated here before forwarding to the operator.
template<typename A>
HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a)) {
    assert_is_lvalue_if_expr<A>();
    auto pattern = IRMatcher::operator-(a);
    return pattern;
}
// Pattern for a Cast IR node to the fixed type `t` of the sub-pattern `a`.
template<typename A>
struct CastOp {
    struct pattern_tag {};
    Type t;
    A a;

    constexpr static uint32_t binds = bindings<A>::mask;

    constexpr static IRNodeType min_node_type = IRNodeType::Cast;
    constexpr static IRNodeType max_node_type = IRNodeType::Cast;
    constexpr static bool canonical = A::canonical;

    // Matches a Cast node whose result type equals `t` and whose operand
    // matches `a`. The type is compared before the operand is matched.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        if (e.node_type == Cast::_node_type) {
            const Cast &cast_node = (const Cast &)e;
            return e.type == t && a.template match<bound>(*cast_node.value.get(), state);
        }
        return false;
    }

    // Pattern-against-pattern form: target types must be equal, then the
    // operand patterns are compared.
    template<uint32_t bound, typename A2>
    HALIDE_ALWAYS_INLINE bool match(const CastOp<A2> &op, MatcherState &state) const noexcept {
        if (!(t == op.t)) {
            return false;
        }
        return a.template match<bound>(unwrap(op.a), state);
    }

    // Builds the cast. The operand is made with an empty type hint; the
    // caller's hint is not used.
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        return cast(t, a.make(state, {}));
    }

    constexpr static bool foldable = false;
};
// Print a cast pattern as `cast(<type>, <operand>)`.
template<typename A>
std::ostream &operator<<(std::ostream &s, const CastOp<A> &op) {
    return s << "cast(" << op.t << ", " << op.a << ")";
}
// Build a CastOp pattern that casts `a` to type `t`. The operand is wrapped
// by pattern_arg after assert_is_lvalue_if_expr has been run on its type.
template<typename A>
HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp<decltype(pattern_arg(a))> {
    using Pattern = CastOp<decltype(pattern_arg(a))>;
    assert_is_lvalue_if_expr<A>();
    return Pattern{t, pattern_arg(a)};
}
// Wrapper that evaluates the sub-pattern `a` to a constant when an Expr is
// made from it. It has no match() of its own.
template<typename A>
struct Fold {
    struct pattern_tag {};
    A a;

    constexpr static uint32_t binds = bindings<A>::mask;

    constexpr static IRNodeType min_node_type = IRNodeType::IntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::FloatImm;
    constexpr static bool canonical = true;

    // Folds `a` to a scalar and turns it into a constant Expr.
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const noexcept {
        halide_scalar_value_t folded;
        halide_type_t folded_ty = type_hint;
        a.make_folded_const(folded, folded_ty, state);

        // Folding can leave the type underspecified (for example when the
        // value came from an int literal). A non-zero bits field in the hint
        // means a specific type is required, so the code and bits are forced
        // to match it.
        if (type_hint.bits) {
            const bool int_result_for_float_hint =
                ((int)folded_ty.code == (int)halide_type_int) &&
                ((int)type_hint.code == (int)halide_type_float);
            if (int_result_for_float_hint) {
                // Convert the integer payload to its double representation.
                const int64_t as_int = folded.u.i64;
                folded.u.f64 = (double)as_int;
            }
            folded_ty.code = type_hint.code;
            folded_ty.bits = type_hint.bits;
        }

        return make_const_expr(folded, folded_ty);
    }

    constexpr static bool foldable = A::foldable;

    // Folding a Fold is the same as folding what it wraps.
    template<typename A1 = A>
    HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        a.make_folded_const(val, ty, state);
    }
};
// Wrap `a` in a Fold. The operand is wrapped by pattern_arg after
// assert_is_lvalue_if_expr has been run on its type.
template<typename A>
HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold<decltype(pattern_arg(a))> {
    using Pattern = Fold<decltype(pattern_arg(a))>;
    assert_is_lvalue_if_expr<A>();
    return Pattern{pattern_arg(a)};
}
// Print a fold wrapper as `fold(<operand>)`.
template<typename A>
std::ostream &operator<<(std::ostream &s, const Fold<A> &op) {
    return s << "fold(" << op.a << ")";
}
// Predicate pattern: folds `a` and reports whether any special-value flag
// was raised in the lanes field of the folded type.
template<typename A>
struct Overflows {
    struct pattern_tag {};
    A a;

    constexpr static uint32_t binds = bindings<A>::mask;

    // As a predicate this always evaluates to a boolean, whose IRNodeType
    // is UIntImm.
    constexpr static IRNodeType min_node_type = IRNodeType::UIntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::UIntImm;
    constexpr static bool canonical = true;

    constexpr static bool foldable = A::foldable;

    // Folds the operand, then replaces the result with a 64-bit unsigned
    // scalar holding 1 if a special-value flag was set and 0 otherwise.
    template<typename A1 = A>
    HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        a.make_folded_const(val, ty, state);
        const bool special_value_seen = (ty.lanes & MatcherState::special_values_mask) != 0;
        ty.code = halide_type_uint;
        ty.bits = 64;
        ty.lanes = 1;
        val.u.u64 = special_value_seen;
    }
};
// Wrap the pattern operand `a` in an Overflows predicate.
template<typename A>
HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows<decltype(pattern_arg(a))> {
    assert_is_lvalue_if_expr<A>();
    using Operand = decltype(pattern_arg(a));
    return Overflows<Operand>{pattern_arg(a)};
}
// Stream an Overflows pattern in the form "overflows(<operand>)".
template<typename A>
std::ostream &operator<<(std::ostream &s, const Overflows<A> &op) {
    return s << "overflows(" << op.a << ")";
}
// Pattern for the signed_integer_overflow intrinsic. It matches a Call
// to that intrinsic, and when made or folded it sets the
// MatcherState::signed_integer_overflow bit in the type's lanes field.
struct Overflow {
    struct pattern_tag {};

    constexpr static uint32_t binds = 0;

    // The overflow intrinsic is represented as a Call node.
    constexpr static IRNodeType min_node_type = IRNodeType::Call;
    constexpr static IRNodeType max_node_type = IRNodeType::Call;
    constexpr static bool canonical = true;
    constexpr static bool foldable = true;

    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        const bool is_call = (e.node_type == Call::_node_type);
        return is_call && ((const Call &)e).is_intrinsic(Call::signed_integer_overflow);
    }

    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        // Flag the overflow in the (by-value) type hint, then build from it.
        type_hint.lanes |= MatcherState::signed_integer_overflow;
        return make_const_special_expr(type_hint);
    }

    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        ty.lanes |= MatcherState::signed_integer_overflow;
        val.u.u64 = 0;
    }
};
// Stream an Overflow pattern as the fixed text "overflow()".
inline std::ostream &operator<<(std::ostream &s, const Overflow &op) {
    return s << "overflow()";
}
// Predicate pattern: builds the operand `a` and reports whether
// ::Halide::Internal::is_const holds for the resulting Expr.
template<typename A>
struct IsConst {
    struct pattern_tag {};
    A a;

    constexpr static uint32_t binds = bindings<A>::mask;

    // Boolean-valued predicate; bools have IRNodeType UIntImm.
    constexpr static IRNodeType min_node_type = IRNodeType::UIntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::UIntImm;
    constexpr static bool canonical = true;
    constexpr static bool foldable = true;

    template<typename A1 = A>
    HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        const Expr built = a.make(state, {});
        val.u.u64 = ::Halide::Internal::is_const(built) ? 1 : 0;
        ty.code = halide_type_uint;
        ty.bits = 64;
        ty.lanes = 1;
    }
};
// Wrap the pattern operand `a` in an IsConst predicate.
template<typename A>
HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst<decltype(pattern_arg(a))> {
    assert_is_lvalue_if_expr<A>();
    using Operand = decltype(pattern_arg(a));
    return IsConst<Operand>{pattern_arg(a)};
}
// Stream an IsConst pattern in the form "is_const(<operand>)".
template<typename A>
std::ostream &operator<<(std::ostream &s, const IsConst<A> &op) {
    return s << "is_const(" << op.a << ")";
}
// Predicate pattern: builds the operand `a`, passes it through
// prover->mutate, and reports whether the outcome is the constant one.
template<typename A, typename Prover>
struct CanProve {
    struct pattern_tag {};
    A a;
    Prover *prover;  // An existing simplifying mutator

    constexpr static uint32_t binds = bindings<A>::mask;

    // Boolean-valued predicate; bools have IRNodeType UIntImm.
    constexpr static IRNodeType min_node_type = IRNodeType::UIntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::UIntImm;
    constexpr static bool canonical = true;
    constexpr static bool foldable = true;

    // This contains a raw call to an inlined make method, so it is
    // deliberately kept out of line.
    HALIDE_NEVER_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const {
        Expr to_prove = a.make(state, {});
        const Expr simplified = prover->mutate(to_prove, nullptr);
        ty.code = halide_type_uint;
        ty.bits = 1;
        ty.lanes = simplified.type().lanes();
        val.u.u64 = is_const_one(simplified);
    }
};
// Wrap the pattern operand `a` and the prover `p` in a CanProve predicate.
template<typename A, typename Prover>
HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve<decltype(pattern_arg(a)), Prover> {
    assert_is_lvalue_if_expr<A>();
    using Operand = decltype(pattern_arg(a));
    return CanProve<Operand, Prover>{pattern_arg(a), p};
}
// Stream a CanProve pattern in the form "can_prove(<operand>)".
template<typename A, typename Prover>
std::ostream &operator<<(std::ostream &s, const CanProve<A, Prover> &op) {
    return s << "can_prove(" << op.a << ")";
}
// Predicate pattern: true when the type of the built operand `a`
// satisfies Type::is_float().
template<typename A>
struct IsFloat {
    struct pattern_tag {};
    A a;

    constexpr static uint32_t binds = bindings<A>::mask;

    // Boolean-valued predicate; bools have IRNodeType UIntImm.
    constexpr static IRNodeType min_node_type = IRNodeType::UIntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::UIntImm;
    constexpr static bool canonical = true;
    constexpr static bool foldable = true;

    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const {
        // The operand is almost certainly a very simple pattern (e.g. a
        // wild), so calling its make method inline here is fine.
        const Type operand_type = a.make(state, {}).type();
        ty.code = halide_type_uint;
        ty.bits = 1;
        ty.lanes = operand_type.lanes();
        val.u.u64 = operand_type.is_float();
    }
};
// Wrap the pattern operand `a` in an IsFloat predicate.
template<typename A>
HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat<decltype(pattern_arg(a))> {
    assert_is_lvalue_if_expr<A>();
    using Operand = decltype(pattern_arg(a));
    return IsFloat<Operand>{pattern_arg(a)};
}
// Stream an IsFloat pattern in the form "is_float(<operand>)".
template<typename A>
std::ostream &operator<<(std::ostream &s, const IsFloat<A> &op) {
    return s << "is_float(" << op.a << ")";
}
// Predicate pattern: true when the type of the built operand `a`
// satisfies Type::is_int() and, if `bits` is non-zero, has exactly
// that many bits.
template<typename A>
struct IsInt {
    struct pattern_tag {};
    A a;
    int bits;

    constexpr static uint32_t binds = bindings<A>::mask;

    // Boolean-valued predicate; bools have IRNodeType UIntImm.
    constexpr static IRNodeType min_node_type = IRNodeType::UIntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::UIntImm;
    constexpr static bool canonical = true;
    constexpr static bool foldable = true;

    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const {
        // The operand is almost certainly a very simple pattern (e.g. a
        // wild), so calling its make method inline here is fine.
        const Type operand_type = a.make(state, {}).type();
        // bits == 0 means "any width".
        const bool width_ok = (bits == 0) || (operand_type.bits() == bits);
        ty.code = halide_type_uint;
        ty.bits = 1;
        ty.lanes = operand_type.lanes();
        val.u.u64 = operand_type.is_int() && width_ok;
    }
};
// Wrap the pattern operand `a` in an IsInt predicate. The default
// bits == 0 accepts any bit width.
template<typename A>
HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits = 0) noexcept -> IsInt<decltype(pattern_arg(a))> {
    assert_is_lvalue_if_expr<A>();
    using Operand = decltype(pattern_arg(a));
    return IsInt<Operand>{pattern_arg(a), bits};
}
// Stream an IsInt pattern as "is_int(<operand>)", or as
// "is_int(<operand>, <bits>)" when a positive bit width was given.
template<typename A>
std::ostream &operator<<(std::ostream &s, const IsInt<A> &op) {
    s << "is_int(" << op.a;
    const bool has_width = (op.bits > 0);
    if (has_width) {
        s << ", " << op.bits;
    }
    return s << ")";
}
// Predicate pattern: true when the type of the built operand `a`
// satisfies Type::is_uint() and, if `bits` is non-zero, has exactly
// that many bits.
template<typename A>
struct IsUInt {
    struct pattern_tag {};
    A a;
    int bits;

    constexpr static uint32_t binds = bindings<A>::mask;

    // Boolean-valued predicate; bools have IRNodeType UIntImm.
    constexpr static IRNodeType min_node_type = IRNodeType::UIntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::UIntImm;
    constexpr static bool canonical = true;
    constexpr static bool foldable = true;

    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const {
        // The operand is almost certainly a very simple pattern (e.g. a
        // wild), so calling its make method inline here is fine.
        const Type operand_type = a.make(state, {}).type();
        // bits == 0 means "any width".
        const bool width_ok = (bits == 0) || (operand_type.bits() == bits);
        ty.code = halide_type_uint;
        ty.bits = 1;
        ty.lanes = operand_type.lanes();
        val.u.u64 = operand_type.is_uint() && width_ok;
    }
};
// Wrap the pattern operand `a` in an IsUInt predicate. The default
// bits == 0 accepts any bit width.
template<typename A>
HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits = 0) noexcept -> IsUInt<decltype(pattern_arg(a))> {
    assert_is_lvalue_if_expr<A>();
    using Operand = decltype(pattern_arg(a));
    return IsUInt<Operand>{pattern_arg(a), bits};
}
// Stream an IsUInt pattern as "is_uint(<operand>)", or as
// "is_uint(<operand>, <bits>)" when a positive bit width was given.
template<typename A>
std::ostream &operator<<(std::ostream &s, const IsUInt<A> &op) {
    s << "is_uint(" << op.a;
    const bool has_width = (op.bits > 0);
    if (has_width) {
        s << ", " << op.bits;
    }
    return s << ")";
}
// Predicate pattern: true when the type of the built operand `a`
// satisfies Type::is_scalar().
template<typename A>
struct IsScalar {
    struct pattern_tag {};
    A a;

    constexpr static uint32_t binds = bindings<A>::mask;

    // Boolean-valued predicate; bools have IRNodeType UIntImm.
    constexpr static IRNodeType min_node_type = IRNodeType::UIntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::UIntImm;
    constexpr static bool canonical = true;
    constexpr static bool foldable = true;

    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const {
        // The operand is almost certainly a very simple pattern (e.g. a
        // wild), so calling its make method inline here is fine.
        const Type operand_type = a.make(state, {}).type();
        ty.code = halide_type_uint;
        ty.bits = 1;
        ty.lanes = operand_type.lanes();
        val.u.u64 = operand_type.is_scalar();
    }
};
// Wrap the pattern operand `a` in an IsScalar predicate.
template<typename A>
HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar<decltype(pattern_arg(a))> {
    assert_is_lvalue_if_expr<A>();
    using Operand = decltype(pattern_arg(a));
    return IsScalar<Operand>{pattern_arg(a)};
}
// Predicate pattern: folds the operand `a` to a constant and reports
// whether that constant is the largest value of its (signed or
// unsigned) integer type. For any other type code the result is false.
template<typename A>
struct IsMaxValue {
    struct pattern_tag {};
    A a;

    constexpr static uint32_t binds = bindings<A>::mask;

    // This rule is a boolean-valued predicate. Bools have type UIntImm.
    constexpr static IRNodeType min_node_type = IRNodeType::UIntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::UIntImm;
    constexpr static bool canonical = true;

    constexpr static bool foldable = true;

    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const {
        // Fold the operand; its value and type arrive in val and ty.
        a.make_folded_const(val, ty, state);
        if (ty.code == halide_type_uint || ty.code == halide_type_int) {
            // All ones in the low ty.bits bits, with one bit fewer for
            // signed types to exclude the sign bit. The mask is only
            // computed for integer type codes.
            const uint64_t max_bits = (uint64_t)(-1) >> (64 - ty.bits + (ty.code == halide_type_int));
            val.u.u64 = (val.u.u64 == max_bits);
        } else {
            val.u.u64 = 0;
        }
        // The result is a bool. ty.lanes is left as the operand fold set it.
        ty.code = halide_type_uint;
        ty.bits = 1;
    }
};

// Stream an IsMaxValue pattern in the form "is_max_value(<operand>)",
// so that rules using it can be printed like the other predicates.
template<typename A>
std::ostream &operator<<(std::ostream &s, const IsMaxValue<A> &op) {
    s << "is_max_value(" << op.a << ")";
    return s;
}
// Wrap the pattern operand `a` in an IsMaxValue predicate.
template<typename A>
HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue<decltype(pattern_arg(a))> {
    assert_is_lvalue_if_expr<A>();
    using Operand = decltype(pattern_arg(a));
    return IsMaxValue<Operand>{pattern_arg(a)};
}
// Predicate pattern: folds the operand `a` to a constant and reports
// whether that constant is the smallest value of its integer type
// (zero for unsigned). For any other type code the result is false.
template<typename A>
struct IsMinValue {
    struct pattern_tag {};
    A a;

    constexpr static uint32_t binds = bindings<A>::mask;

    // This rule is a boolean-valued predicate. Bools have type UIntImm.
    constexpr static IRNodeType min_node_type = IRNodeType::UIntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::UIntImm;
    constexpr static bool canonical = true;

    constexpr static bool foldable = true;

    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const {
        // Fold the operand; its value and type arrive in val and ty.
        a.make_folded_const(val, ty, state);
        if (ty.code == halide_type_int) {
            // Ones from bit (ty.bits - 1) upward: the 64-bit pattern
            // compared against for the most negative value.
            const uint64_t min_bits = (uint64_t)(-1) << (ty.bits - 1);
            val.u.u64 = (val.u.u64 == min_bits);
        } else if (ty.code == halide_type_uint) {
            val.u.u64 = (val.u.u64 == 0);
        } else {
            val.u.u64 = 0;
        }
        // The result is a bool. ty.lanes is left as the operand fold set it.
        ty.code = halide_type_uint;
        ty.bits = 1;
    }
};

// Stream an IsMinValue pattern in the form "is_min_value(<operand>)",
// so that rules using it can be printed like the other predicates.
template<typename A>
std::ostream &operator<<(std::ostream &s, const IsMinValue<A> &op) {
    s << "is_min_value(" << op.a << ")";
    return s;
}
// Wrap the pattern operand `a` in an IsMinValue predicate.
template<typename A>
HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue<decltype(pattern_arg(a))> {
    assert_is_lvalue_if_expr<A>();
    using Operand = decltype(pattern_arg(a));
    return IsMinValue<Operand>{pattern_arg(a)};
}
// Stream an IsScalar pattern in the form "is_scalar(<operand>)".
template<typename A>
std::ostream &operator<<(std::ostream &s, const IsScalar<A> &op) {
    return s << "is_scalar(" << op.a << ")";
}
// Verify properties of each rewrite rule. Currently just fuzz tests them.
//
// This overload is selected when both `before` and `after` are
// foldable. It binds random constants to every wildcard, evaluates
// `pred`, and for trials where the predicate holds checks that
// constant-folding `before` and `after` gives the same value. A
// mismatch is reported via debug(0) followed by internal_error. Each
// distinct wildcard type is tested only once per template
// instantiation.
template<typename Before,
         typename After,
         typename Predicate,
         typename = typename std::enable_if<std::decay<Before>::type::foldable &&
                                            std::decay<After>::type::foldable>::type>
HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred,
                                        halide_type_t wildcard_type, halide_type_t output_type) noexcept {

    // We only validate the rules in the scalar case
    wildcard_type.lanes = output_type.lanes = 1;

    // Track which types this rule has been tested for before
    static std::set<uint32_t> tested;

    if (!tested.insert(reinterpret_bits<uint32_t>(wildcard_type)).second) {
        return;
    }

    // Print it in a form where it can be piped into a python/z3 validator
    debug(0) << "validate('" << before << "', '" << after << "', '" << pred << "', " << Type(wildcard_type) << ", " << Type(output_type) << ")\n";

    // Substitute some random constants into the before and after
    // expressions and see if the rule holds true. This should catch
    // silly errors, but not necessarily corner cases.
    static std::mt19937_64 rng(0);
    MatcherState state;

    Expr exprs[max_wild];

    for (int trials = 0; trials < 100; trials++) {
        // We want to test small constants more frequently than
        // large ones, otherwise we'll just get coverage of
        // overflow rules.
        int shift = (int)(rng() & (wildcard_type.bits - 1));

        for (int i = 0; i < max_wild; i++) {
            // Bind all the exprs and constants. Each case sets the
            // bound constant and fills exprs[i]; the expression binding
            // itself is installed once, after the switch.
            switch (wildcard_type.code) {
            case halide_type_uint: {
                // Normalize to the type's range by adding zero
                uint64_t val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
                state.set_bound_const(i, val, wildcard_type);
                val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
                exprs[i] = make_const(wildcard_type, val);
            } break;
            case halide_type_int: {
                int64_t val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
                state.set_bound_const(i, val, wildcard_type);
                val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
                exprs[i] = make_const(wildcard_type, val);
            } break;
            case halide_type_float:
            case halide_type_bfloat: {
                // Use a very narrow range of precise floats, so
                // that none of the rules a human is likely to
                // write have instabilities.
                double val = ((int64_t)(rng() & 15) - 8) / 2.0;
                state.set_bound_const(i, val, wildcard_type);
                val = ((int64_t)(rng() & 15) - 8) / 2.0;
                exprs[i] = make_const(wildcard_type, val);
            } break;
            default:
                return;  // Don't care about handles
            }
            state.set_binding(i, *exprs[i].get());
        }

        halide_scalar_value_t val_before, val_after;
        halide_type_t type = output_type;
        if (!evaluate_predicate(pred, state)) {
            continue;
        }
        before.make_folded_const(val_before, type, state);
        uint16_t lanes = type.lanes;
        after.make_folded_const(val_after, type, state);
        lanes |= type.lanes;

        // Skip the trial if either fold set any special-value bit in
        // the lanes field.
        if (lanes & MatcherState::special_values_mask) {
            continue;
        }

        bool ok = true;
        switch (output_type.code) {
        case halide_type_uint:
            // Compare normalized representations
            ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.u64, 0) ==
                   constant_fold_bin_op<Add>(output_type, val_after.u.u64, 0));
            break;
        case halide_type_int:
            ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.i64, 0) ==
                   constant_fold_bin_op<Add>(output_type, val_after.u.i64, 0));
            break;
        case halide_type_float:
        case halide_type_bfloat: {
            double error = std::abs(val_before.u.f64 - val_after.u.f64);
            // We accept an equal bit pattern (e.g. inf vs inf),
            // a small floating point difference, or turning a nan into not-a-nan.
            ok &= (error < 0.01 ||
                   val_before.u.u64 == val_after.u.u64 ||
                   std::isnan(val_before.u.f64));
            break;
        }
        default:
            return;
        }

        if (!ok) {
            debug(0) << "Fails with values:\n";
            for (int i = 0; i < max_wild; i++) {
                halide_scalar_value_t val;
                state.get_bound_const(i, val, wildcard_type);
                debug(0) << " c" << i << ": " << make_const_expr(val, wildcard_type) << "\n";
            }
            for (int i = 0; i < max_wild; i++) {
                debug(0) << " _" << i << ": " << Expr(state.get_binding(i)) << "\n";
            }
            debug(0) << " Before: " << make_const_expr(val_before, output_type) << "\n";
            debug(0) << " After: " << make_const_expr(val_after, output_type) << "\n";
            debug(0) << val_before.u.u64 << " " << val_after.u.u64 << "\n";
            internal_error;
        }
    }
}
// Fallback overload of fuzz_test_rule, selected when `before` and
// `after` are not both foldable. It does nothing. The extra defaulted
// `int dummy` parameter gives it a signature distinct from the
// foldable overload.
template<typename Before,
typename After,
typename Predicate,
typename = typename std::enable_if<!(std::decay<Before>::type::foldable &&
std::decay<After>::type::foldable)>::type>
HALIDE_ALWAYS_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred,
halide_type_t, halide_type_t, int dummy = 0) noexcept {
// We can't verify rewrite rules that can't be constant-folded.
}
// Evaluate a predicate that is already a plain bool: the value is
// returned unchanged and the matcher state is not consulted.
HALIDE_ALWAYS_INLINE
bool evaluate_predicate(bool x, MatcherState &) noexcept {
return x;
}
// Evaluate a pattern-valued predicate by constant-folding it against
// the current matcher state. The predicate holds only if the folded
// value is non-zero and no special-value bit was set in the folded
// type's lanes field.
template<typename Pattern,
         typename = typename enable_if_pattern<Pattern>::type>
HALIDE_ALWAYS_INLINE bool evaluate_predicate(Pattern p, MatcherState &state) {
    halide_scalar_value_t folded;
    halide_type_t folded_type = halide_type_of<bool>();
    p.make_folded_const(folded, folded_type, state);
    // Overflow counts as a failed predicate
    const bool hit_special_value = (folded_type.lanes & MatcherState::special_values_mask) != 0;
    if (hit_special_value) {
        return false;
    }
    return folded.u.u64 != 0;
}
// #defines for testing
// Print all successful or failed matches
#define HALIDE_DEBUG_MATCHED_RULES 0
#define HALIDE_DEBUG_UNMATCHED_RULES 0
// Set to true if you want to fuzz test every rewrite passed to
// operator() to ensure the input and the output have the same value
// for lots of random values of the wildcards. Run
// correctness_simplify with this on.
#define HALIDE_FUZZ_TEST_RULES 0
/** Tries rewrite rules against a single instance. Each operator()
 * overload takes a "before" pattern, a replacement (a pattern, an Expr,
 * or an integer constant) and optionally a predicate. If the instance
 * matches the before pattern (and the predicate, when given, evaluates
 * to true), the replacement is stored in `result` and the call returns
 * true; otherwise it returns false. */
template<typename Instance>
struct Rewriter {
    Instance instance;
    Expr result;
    MatcherState state;
    halide_type_t output_type, wildcard_type;
    // Not read or written by any member of this struct. It is given a
    // default value so a Rewriter never carries an indeterminate bool.
    bool validate = false;

    HALIDE_ALWAYS_INLINE
    Rewriter(Instance instance, halide_type_t ot, halide_type_t wt)
        : instance(std::move(instance)), output_type(ot), wildcard_type(wt) {
    }

    // Build the replacement expression from a pattern, using the
    // current bindings in `state`, and store it in `result`.
    template<typename After>
    HALIDE_NEVER_INLINE void build_replacement(After after) {
        result = after.make(state, output_type);
    }

    // Rule: before -> after, where after is a pattern.
    template<typename Before,
             typename After,
             typename = typename enable_if_pattern<Before>::type,
             typename = typename enable_if_pattern<After>::type>
    HALIDE_ALWAYS_INLINE bool operator()(Before before, After after) {
        static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
        static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
        static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
#if HALIDE_FUZZ_TEST_RULES
        fuzz_test_rule(before, after, true, wildcard_type, output_type);
#endif
        if (before.template match<0>(unwrap(instance), state)) {
            // Build the replacement first so the trace below prints the
            // new result rather than whatever `result` held before.
            build_replacement(after);
#if HALIDE_DEBUG_MATCHED_RULES
            debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
#endif
            return true;
        } else {
#if HALIDE_DEBUG_UNMATCHED_RULES
            debug(0) << instance << " does not match " << before << "\n";
#endif
            return false;
        }
    }

    // Rule: before -> after, where after is an existing Expr.
    template<typename Before,
             typename = typename enable_if_pattern<Before>::type>
    HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept {
        static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
        if (before.template match<0>(unwrap(instance), state)) {
            result = after;
#if HALIDE_DEBUG_MATCHED_RULES
            debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
#endif
            return true;
        } else {
#if HALIDE_DEBUG_UNMATCHED_RULES
            debug(0) << instance << " does not match " << before << "\n";
#endif
            return false;
        }
    }

    // Rule: before -> after, where after is an integer constant that is
    // converted to output_type with make_const.
    template<typename Before,
             typename = typename enable_if_pattern<Before>::type>
    HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept {
        static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
#if HALIDE_FUZZ_TEST_RULES
        fuzz_test_rule(before, IntLiteral(after), true, wildcard_type, output_type);
#endif
        if (before.template match<0>(unwrap(instance), state)) {
            result = make_const(output_type, after);
#if HALIDE_DEBUG_MATCHED_RULES
            debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
#endif
            return true;
        } else {
#if HALIDE_DEBUG_UNMATCHED_RULES
            debug(0) << instance << " does not match " << before << "\n";
#endif
            return false;
        }
    }

    // Rule: before -> after when pred, where after is a pattern.
    template<typename Before,
             typename After,
             typename Predicate,
             typename = typename enable_if_pattern<Before>::type,
             typename = typename enable_if_pattern<After>::type,
             typename = typename enable_if_pattern<Predicate>::type>
    HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred) {
        static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
        static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
        static_assert((Before::binds & Predicate::binds) == Predicate::binds, "Rule predicate uses unbound values");
        static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
        static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");

#if HALIDE_FUZZ_TEST_RULES
        fuzz_test_rule(before, after, pred, wildcard_type, output_type);
#endif
        if (before.template match<0>(unwrap(instance), state) &&
            evaluate_predicate(pred, state)) {
            // Build the replacement first so the trace below prints the
            // new result rather than whatever `result` held before.
            build_replacement(after);
#if HALIDE_DEBUG_MATCHED_RULES
            debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
#endif
            return true;
        } else {
#if HALIDE_DEBUG_UNMATCHED_RULES
            debug(0) << instance << " does not match " << before << "\n";
#endif
            return false;
        }
    }

    // Rule: before -> after when pred, where after is an existing Expr.
    template<typename Before,
             typename Predicate,
             typename = typename enable_if_pattern<Before>::type,
             typename = typename enable_if_pattern<Predicate>::type>
    HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred) {
        static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
        static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");

        if (before.template match<0>(unwrap(instance), state) &&
            evaluate_predicate(pred, state)) {
            result = after;
#if HALIDE_DEBUG_MATCHED_RULES
            debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
#endif
            return true;
        } else {
#if HALIDE_DEBUG_UNMATCHED_RULES
            debug(0) << instance << " does not match " << before << "\n";
#endif
            return false;
        }
    }

    // Rule: before -> after when pred, where after is an integer
    // constant that is converted to output_type with make_const.
    template<typename Before,
             typename Predicate,
             typename = typename enable_if_pattern<Before>::type,
             typename = typename enable_if_pattern<Predicate>::type>
    HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred) {
        static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
        static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
#if HALIDE_FUZZ_TEST_RULES
        fuzz_test_rule(before, IntLiteral(after), pred, wildcard_type, output_type);
#endif
        if (before.template match<0>(unwrap(instance), state) &&
            evaluate_predicate(pred, state)) {
            result = make_const(output_type, after);
#if HALIDE_DEBUG_MATCHED_RULES
            debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
#endif
            return true;
        } else {
#if HALIDE_DEBUG_UNMATCHED_RULES
            debug(0) << instance << " does not match " << before << "\n";
#endif
            return false;
        }
    }
};
/** Construct a rewriter for the given instance, which may be a pattern
* with concrete expressions as leaves, or just an expression. The
* second optional argument (wildcard_type) is a hint as to what the
* type of the wildcards is likely to be. If omitted it uses the same
* type as the expression itself. They are not required to be this
* type, but the rule will only be tested for wildcards of that type
* when testing is enabled.
*
* The rewriter can be used to check to see if the instance is one of
* some number of patterns and if so rewrite it into another form,
* using its operator() method. See Simplify.cpp for a bunch of
* example usage.
*
* Important: Any Exprs in patterns are captured by reference, not by
* value, so ensure they outlive the rewriter.
*/
// @{
template<typename Instance,
         typename = typename enable_if_pattern<Instance>::type>
HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
    // Pattern instance with explicit output and wildcard types.
    using Result = Rewriter<decltype(pattern_arg(instance))>;
    return Result{pattern_arg(instance), output_type, wildcard_type};
}
template<typename Instance,
         typename = typename enable_if_pattern<Instance>::type>
HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
    // Pattern instance; the wildcard type defaults to the output type.
    using Result = Rewriter<decltype(pattern_arg(instance))>;
    return Result{pattern_arg(instance), output_type, output_type};
}
HALIDE_ALWAYS_INLINE
auto rewriter(const Expr &e, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(e))> {
    // Expr instance; the output type is the expression's own type.
    using Result = Rewriter<decltype(pattern_arg(e))>;
    return Result{pattern_arg(e), e.type(), wildcard_type};
}
HALIDE_ALWAYS_INLINE
auto rewriter(const Expr &e) noexcept -> Rewriter<decltype(pattern_arg(e))> {
    // Expr instance; the expression's own type serves as both the
    // output type and the wildcard type.
    using Result = Rewriter<decltype(pattern_arg(e))>;
    return Result{pattern_arg(e), e.type(), e.type()};
}
// @}
} // namespace IRMatcher
} // namespace Internal
} // namespace Halide
#endif