#ifndef HALIDE_IR_MATCH_H #define HALIDE_IR_MATCH_H /** \file * Defines a method to match a fragment of IR against a pattern containing wildcards */ #include #include #include #include #include "IR.h" #include "IREquality.h" #include "IROperator.h" namespace Halide { namespace Internal { /** Does the first expression have the same structure as the second? * Variables in the first expression with the name * are interpreted * as wildcards, and their matching equivalent in the second * expression is placed in the vector give as the third argument. * Wildcards require the types to match. For the type bits and width, * a 0 indicates "match anything". So an Int(8, 0) will match 8-bit * integer vectors of any width (including scalars), and a UInt(0, 0) * will match any unsigned integer type. * * For example: \code Expr x = Variable::make(Int(32), "*"); match(x + x, 3 + (2*k), result) \endcode * should return true, and set result[0] to 3 and * result[1] to 2*k. */ bool expr_match(const Expr &pattern, const Expr &expr, std::vector &result); /** Does the first expression have the same structure as the second? * Variables are matched consistently. The first time a variable is * matched, it assumes the value of the matching part of the second * expression. Subsequent matches must be equal to the first match. * * For example: \code Var x("x"), y("y"); match(x*(x + y), a*(a + b), result) \endcode * should return true, and set result["x"] = a, and result["y"] = b. */ bool expr_match(const Expr &pattern, const Expr &expr, std::map &result); /** Rewrite the expression x to have `lanes` lanes. This is useful * for substituting the results of expr_match into a pattern expression. */ Expr with_lanes(const Expr &x, int lanes); void expr_match_test(); /** An alternative template-metaprogramming approach to expression * matching. Potentially more efficient. 
We lift the expression * pattern into a type, and then use force-inlined functions to * generate efficient matching and reconstruction code for any * pattern. Pattern elements are either one of the classes in the * namespace IRMatcher, or are non-null Exprs (represented as * BaseExprNode &). * * Pattern elements that are fully specified by their pattern can be * built into an expression using the make method. Some patterns, * such as a broadcast that matches any number of lanes, don't have * enough information to recreate an Expr. */ namespace IRMatcher { constexpr int max_wild = 6; static const halide_type_t i64_type = {halide_type_int, 64, 1}; /** To save stack space, the matcher objects are largely stateless and * immutable. This state object is built up during matching and then * consumed when constructing a replacement Expr. */ struct MatcherState { const BaseExprNode *bindings[max_wild]; halide_scalar_value_t bound_const[max_wild]; // values of the lanes field with special meaning. 
static constexpr uint16_t signed_integer_overflow = 0x8000; static constexpr uint16_t special_values_mask = 0x8000; // currently only one halide_type_t bound_const_type[max_wild]; HALIDE_ALWAYS_INLINE void set_binding(int i, const BaseExprNode &n) noexcept { bindings[i] = &n; } HALIDE_ALWAYS_INLINE const BaseExprNode *get_binding(int i) const noexcept { return bindings[i]; } HALIDE_ALWAYS_INLINE void set_bound_const(int i, int64_t s, halide_type_t t) noexcept { bound_const[i].u.i64 = s; bound_const_type[i] = t; } HALIDE_ALWAYS_INLINE void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept { bound_const[i].u.u64 = u; bound_const_type[i] = t; } HALIDE_ALWAYS_INLINE void set_bound_const(int i, double f, halide_type_t t) noexcept { bound_const[i].u.f64 = f; bound_const_type[i] = t; } HALIDE_ALWAYS_INLINE void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept { bound_const[i] = val; bound_const_type[i] = t; } HALIDE_ALWAYS_INLINE void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept { val = bound_const[i]; type = bound_const_type[i]; } HALIDE_ALWAYS_INLINE // NOLINTNEXTLINE(modernize-use-equals-default): Can't use `= default`; clang-tidy complains about noexcept mismatch MatcherState() noexcept { } }; template::type::pattern_tag> struct enable_if_pattern { struct type {}; }; template struct bindings { constexpr static uint32_t mask = std::remove_reference::type::binds; }; inline HALIDE_NEVER_INLINE Expr make_const_special_expr(halide_type_t ty) { const uint16_t flags = ty.lanes & MatcherState::special_values_mask; ty.lanes &= ~MatcherState::special_values_mask; if (flags & MatcherState::signed_integer_overflow) { return make_signed_integer_overflow(ty); } // unreachable return Expr(); } HALIDE_ALWAYS_INLINE Expr make_const_expr(halide_scalar_value_t val, halide_type_t ty) { halide_type_t scalar_type = ty; if (scalar_type.lanes & MatcherState::special_values_mask) { return 
make_const_special_expr(scalar_type); } const int lanes = scalar_type.lanes; scalar_type.lanes = 1; Expr e; switch (scalar_type.code) { case halide_type_int: e = IntImm::make(scalar_type, val.u.i64); break; case halide_type_uint: e = UIntImm::make(scalar_type, val.u.u64); break; case halide_type_float: case halide_type_bfloat: e = FloatImm::make(scalar_type, val.u.f64); break; default: // Unreachable return Expr(); } if (lanes > 1) { e = Broadcast::make(e, lanes); } return e; } bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept; // A fast version of expression equality that assumes a well-typed non-null expression tree. HALIDE_ALWAYS_INLINE bool equal(const BaseExprNode &a, const BaseExprNode &b) noexcept { // Early out return (&a == &b) || ((a.type == b.type) && (a.node_type == b.node_type) && equal_helper(a, b)); } // A pattern that matches a specific expression struct SpecificExpr { struct pattern_tag {}; constexpr static uint32_t binds = 0; // What is the weakest and strongest IR node this could possibly be constexpr static IRNodeType min_node_type = IRNodeType::IntImm; constexpr static IRNodeType max_node_type = IRNodeType::Shuffle; constexpr static bool canonical = true; const BaseExprNode &expr; template HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept { return equal(expr, e); } HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const { return Expr(&expr); } constexpr static bool foldable = false; }; inline std::ostream &operator<<(std::ostream &s, const SpecificExpr &e) { s << Expr(&e.expr); return s; } template struct WildConstInt { struct pattern_tag {}; constexpr static uint32_t binds = 1 << i; constexpr static IRNodeType min_node_type = IRNodeType::IntImm; constexpr static IRNodeType max_node_type = IRNodeType::IntImm; constexpr static bool canonical = true; template HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept { 
static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index"); const BaseExprNode *op = &e; if (op->node_type == IRNodeType::Broadcast) { op = ((const Broadcast *)op)->value.get(); } if (op->node_type != IRNodeType::IntImm) { return false; } int64_t value = ((const IntImm *)op)->value; if (bound & binds) { halide_scalar_value_t val; halide_type_t type; state.get_bound_const(i, val, type); return (halide_type_t)e.type == type && value == val.u.i64; } state.set_bound_const(i, value, e.type); return true; } template HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept { static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index"); if (bound & binds) { halide_scalar_value_t val; halide_type_t type; state.get_bound_const(i, val, type); return type == i64_type && value == val.u.i64; } state.set_bound_const(i, value, i64_type); return true; } HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const { halide_scalar_value_t val; halide_type_t type; state.get_bound_const(i, val, type); return make_const_expr(val, type); } constexpr static bool foldable = true; HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const { state.get_bound_const(i, val, ty); } }; template std::ostream &operator<<(std::ostream &s, const WildConstInt &c) { s << "ci" << i; return s; } template struct WildConstUInt { struct pattern_tag {}; constexpr static uint32_t binds = 1 << i; constexpr static IRNodeType min_node_type = IRNodeType::UIntImm; constexpr static IRNodeType max_node_type = IRNodeType::UIntImm; constexpr static bool canonical = true; template HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept { static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index"); const BaseExprNode *op = &e; if (op->node_type == IRNodeType::Broadcast) { op = ((const Broadcast *)op)->value.get(); } if (op->node_type != 
IRNodeType::UIntImm) { return false; } uint64_t value = ((const UIntImm *)op)->value; if (bound & binds) { halide_scalar_value_t val; halide_type_t type; state.get_bound_const(i, val, type); return (halide_type_t)e.type == type && value == val.u.u64; } state.set_bound_const(i, value, e.type); return true; } HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const { halide_scalar_value_t val; halide_type_t type; state.get_bound_const(i, val, type); return make_const_expr(val, type); } constexpr static bool foldable = true; HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { state.get_bound_const(i, val, ty); } }; template std::ostream &operator<<(std::ostream &s, const WildConstUInt &c) { s << "cu" << i; return s; } template struct WildConstFloat { struct pattern_tag {}; constexpr static uint32_t binds = 1 << i; constexpr static IRNodeType min_node_type = IRNodeType::FloatImm; constexpr static IRNodeType max_node_type = IRNodeType::FloatImm; constexpr static bool canonical = true; template HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept { static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index"); const BaseExprNode *op = &e; if (op->node_type == IRNodeType::Broadcast) { op = ((const Broadcast *)op)->value.get(); } if (op->node_type != IRNodeType::FloatImm) { return false; } double value = ((const FloatImm *)op)->value; if (bound & binds) { halide_scalar_value_t val; halide_type_t type; state.get_bound_const(i, val, type); return (halide_type_t)e.type == type && value == val.u.f64; } state.set_bound_const(i, value, e.type); return true; } HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const { halide_scalar_value_t val; halide_type_t type; state.get_bound_const(i, val, type); return make_const_expr(val, type); } constexpr static bool foldable = true; HALIDE_ALWAYS_INLINE void 
make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { state.get_bound_const(i, val, ty); } }; template std::ostream &operator<<(std::ostream &s, const WildConstFloat &c) { s << "cf" << i; return s; } // Matches and binds to any constant Expr. Does not support constant-folding. template struct WildConst { struct pattern_tag {}; constexpr static uint32_t binds = 1 << i; constexpr static IRNodeType min_node_type = IRNodeType::IntImm; constexpr static IRNodeType max_node_type = IRNodeType::FloatImm; constexpr static bool canonical = true; template HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept { static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index"); const BaseExprNode *op = &e; if (op->node_type == IRNodeType::Broadcast) { op = ((const Broadcast *)op)->value.get(); } switch (op->node_type) { case IRNodeType::IntImm: return WildConstInt().template match(e, state); case IRNodeType::UIntImm: return WildConstUInt().template match(e, state); case IRNodeType::FloatImm: return WildConstFloat().template match(e, state); default: return false; } } template HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept { static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index"); return WildConstInt().template match(e, state); } HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const { halide_scalar_value_t val; halide_type_t type; state.get_bound_const(i, val, type); return make_const_expr(val, type); } constexpr static bool foldable = true; HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { state.get_bound_const(i, val, ty); } }; template std::ostream &operator<<(std::ostream &s, const WildConst &c) { s << "c" << i; return s; } // Matches and binds to any Expr template struct Wild { struct pattern_tag {}; constexpr static uint32_t binds = 1 
<< (i + 16); constexpr static IRNodeType min_node_type = IRNodeType::IntImm; constexpr static IRNodeType max_node_type = StrongestExprNodeType; constexpr static bool canonical = true; template HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept { if (bound & binds) { return equal(*state.get_binding(i), e); } state.set_binding(i, e); return true; } HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const { return state.get_binding(i); } constexpr static bool foldable = true; HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { const auto *e = state.get_binding(i); ty = e->type; switch (e->node_type) { case IRNodeType::UIntImm: val.u.u64 = ((const UIntImm *)e)->value; return; case IRNodeType::IntImm: val.u.i64 = ((const IntImm *)e)->value; return; case IRNodeType::FloatImm: val.u.f64 = ((const FloatImm *)e)->value; return; default: // The function is noexcept, so silent failure. You // shouldn't be calling this if you haven't already // checked it's going to be a constant (e.g. with // is_const, or because you manually bound a constant Expr // to the state). val.u.u64 = 0; } } }; template std::ostream &operator<<(std::ostream &s, const Wild &op) { s << "_" << i; return s; } // Matches a specific constant or broadcast of that constant. The // constant must be representable as an int64_t. 
struct IntLiteral {
    struct pattern_tag {};
    int64_t v;

    constexpr static uint32_t binds = 0;

    // What is the weakest and strongest IR node this could possibly be
    constexpr static IRNodeType min_node_type = IRNodeType::IntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::FloatImm;
    constexpr static bool canonical = true;

    HALIDE_ALWAYS_INLINE
    explicit IntLiteral(int64_t v)
        : v(v) {
    }

    // Matches an IntImm, UIntImm or FloatImm (or a broadcast of one) whose
    // value equals v after converting v to that node's value type.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        const BaseExprNode *op = &e;
        if (e.node_type == IRNodeType::Broadcast) {
            op = ((const Broadcast *)op)->value.get();
        }
        switch (op->node_type) {
        case IRNodeType::IntImm:
            return ((const IntImm *)op)->value == v;
        case IRNodeType::UIntImm:
            return ((const UIntImm *)op)->value == (uint64_t)v;
        case IRNodeType::FloatImm:
            return ((const FloatImm *)op)->value == (double)v;
        default:
            return false;
        }
    }

    // Matches a raw int64_t constant.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept {
        return v == val;
    }

    // Two literals match iff their values are equal.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept {
        return v == b.v;
    }

    // A literal carries no type of its own, so it is built with the caller's type hint.
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        return make_const(type_hint, v);
    }

    constexpr static bool foldable = true;

    // Writes v into the union member selected by the incoming ty.code;
    // ty itself is left unchanged.
    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        // Assume type is already correct
        switch (ty.code) {
        case halide_type_int:
            val.u.i64 = v;
            break;
        case halide_type_uint:
            val.u.u64 = (uint64_t)v;
            break;
        case halide_type_float:
        case halide_type_bfloat:
            val.u.f64 = (double)v;
            break;
        default:
            // Unreachable
            ;
        }
    }
};

HALIDE_ALWAYS_INLINE int64_t unwrap(IntLiteral t) {
    return t.v;
}

// Convert a provided pattern, expr, or constant int into the internal
// representation we use in the matcher trees.
template::type::pattern_tag> HALIDE_ALWAYS_INLINE T pattern_arg(T t) { return t; } HALIDE_ALWAYS_INLINE IntLiteral pattern_arg(int64_t x) { return IntLiteral{x}; } template HALIDE_ALWAYS_INLINE void assert_is_lvalue_if_expr() { static_assert(!std::is_same::type, Expr>::value || std::is_lvalue_reference::value, "Exprs are captured by reference by IRMatcher objects and so must be lvalues"); } HALIDE_ALWAYS_INLINE SpecificExpr pattern_arg(const Expr &e) { return {*e.get()}; } // Helpers to deref SpecificExprs to const BaseExprNode & rather than // passing them by value anywhere (incurring lots of refcounting) template::type::pattern_tag, // But T may not be SpecificExpr typename = typename std::enable_if::type, SpecificExpr>::value>::type> HALIDE_ALWAYS_INLINE T unwrap(T t) { return t; } HALIDE_ALWAYS_INLINE const BaseExprNode &unwrap(const SpecificExpr &e) { return e.expr; } inline std::ostream &operator<<(std::ostream &s, const IntLiteral &op) { s << op.v; return s; } template int64_t constant_fold_bin_op(halide_type_t &, int64_t, int64_t) noexcept; template uint64_t constant_fold_bin_op(halide_type_t &, uint64_t, uint64_t) noexcept; template double constant_fold_bin_op(halide_type_t &, double, double) noexcept; constexpr bool commutative(IRNodeType t) { return (t == IRNodeType::Add || t == IRNodeType::Mul || t == IRNodeType::And || t == IRNodeType::Or || t == IRNodeType::Min || t == IRNodeType::Max || t == IRNodeType::EQ || t == IRNodeType::NE); } // Matches one of the binary operators template struct BinOp { struct pattern_tag {}; A a; B b; constexpr static uint32_t binds = bindings::mask | bindings::mask; constexpr static IRNodeType min_node_type = Op::_node_type; constexpr static IRNodeType max_node_type = Op::_node_type; // For commutative bin ops, we expect the weaker IR node type on // the right. That is, for the rule to be canonical it must be // possible that A is at least as strong as B. 
constexpr static bool canonical = A::canonical && B::canonical && (!commutative(Op::_node_type) || (A::max_node_type >= B::min_node_type)); template HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept { if (e.node_type != Op::_node_type) { return false; } const Op &op = (const Op &)e; return (a.template match(*op.a.get(), state) && b.template match::mask>(*op.b.get(), state)); } template HALIDE_ALWAYS_INLINE bool match(const BinOp &op, MatcherState &state) const noexcept { return (std::is_same::value && a.template match(unwrap(op.a), state) && b.template match::mask>(unwrap(op.b), state)); } constexpr static bool foldable = A::foldable && B::foldable; HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { halide_scalar_value_t val_a, val_b; if (std::is_same::value) { b.make_folded_const(val_b, ty, state); if ((std::is_same::value && val_b.u.u64 == 0) || (std::is_same::value && val_b.u.u64 == 1)) { // Short circuit val = val_b; return; } const uint16_t l = ty.lanes; a.make_folded_const(val_a, ty, state); ty.lanes |= l; // Make sure the overflow bits are sticky } else { a.make_folded_const(val_a, ty, state); if ((std::is_same::value && val_a.u.u64 == 0) || (std::is_same::value && val_a.u.u64 == 1)) { // Short circuit val = val_a; return; } const uint16_t l = ty.lanes; b.make_folded_const(val_b, ty, state); ty.lanes |= l; } switch (ty.code) { case halide_type_int: val.u.i64 = constant_fold_bin_op(ty, val_a.u.i64, val_b.u.i64); break; case halide_type_uint: val.u.u64 = constant_fold_bin_op(ty, val_a.u.u64, val_b.u.u64); break; case halide_type_float: case halide_type_bfloat: val.u.f64 = constant_fold_bin_op(ty, val_a.u.f64, val_b.u.f64); break; default: // unreachable ; } } HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept { Expr ea, eb; if (std::is_same::value) { eb = b.make(state, type_hint); ea = a.make(state, 
eb.type()); } else { ea = a.make(state, type_hint); eb = b.make(state, ea.type()); } // We sometimes mix vectors and scalars in the rewrite rules, // so insert a broadcast if necessary. if (ea.type().is_vector() && !eb.type().is_vector()) { eb = Broadcast::make(eb, ea.type().lanes()); } if (eb.type().is_vector() && !ea.type().is_vector()) { ea = Broadcast::make(ea, eb.type().lanes()); } return Op::make(std::move(ea), std::move(eb)); } }; template uint64_t constant_fold_cmp_op(int64_t, int64_t) noexcept; template uint64_t constant_fold_cmp_op(uint64_t, uint64_t) noexcept; template uint64_t constant_fold_cmp_op(double, double) noexcept; // Matches one of the comparison operators template struct CmpOp { struct pattern_tag {}; A a; B b; constexpr static uint32_t binds = bindings::mask | bindings::mask; constexpr static IRNodeType min_node_type = Op::_node_type; constexpr static IRNodeType max_node_type = Op::_node_type; constexpr static bool canonical = (A::canonical && B::canonical && (!commutative(Op::_node_type) || A::max_node_type >= B::min_node_type) && (Op::_node_type != IRNodeType::GE) && (Op::_node_type != IRNodeType::GT)); template HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept { if (e.node_type != Op::_node_type) { return false; } const Op &op = (const Op &)e; return (a.template match(*op.a.get(), state) && b.template match::mask>(*op.b.get(), state)); } template HALIDE_ALWAYS_INLINE bool match(const CmpOp &op, MatcherState &state) const noexcept { return (std::is_same::value && a.template match(unwrap(op.a), state) && b.template match::mask>(unwrap(op.b), state)); } constexpr static bool foldable = A::foldable && B::foldable; HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { halide_scalar_value_t val_a, val_b; // If one side is an untyped const, evaluate the other side first to get a type hint. 
if (std::is_same::value) { b.make_folded_const(val_b, ty, state); const uint16_t l = ty.lanes; a.make_folded_const(val_a, ty, state); ty.lanes |= l; } else { a.make_folded_const(val_a, ty, state); const uint16_t l = ty.lanes; b.make_folded_const(val_b, ty, state); ty.lanes |= l; } switch (ty.code) { case halide_type_int: val.u.u64 = constant_fold_cmp_op(val_a.u.i64, val_b.u.i64); break; case halide_type_uint: val.u.u64 = constant_fold_cmp_op(val_a.u.u64, val_b.u.u64); break; case halide_type_float: case halide_type_bfloat: val.u.u64 = constant_fold_cmp_op(val_a.u.f64, val_b.u.f64); break; default: // unreachable ; } ty.code = halide_type_uint; ty.bits = 1; } HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const { // If one side is an untyped const, evaluate the other side first to get a type hint. Expr ea, eb; if (std::is_same::value) { eb = b.make(state, {}); ea = a.make(state, eb.type()); } else { ea = a.make(state, {}); eb = b.make(state, ea.type()); } // We sometimes mix vectors and scalars in the rewrite rules, // so insert a broadcast if necessary. 
if (ea.type().is_vector() && !eb.type().is_vector()) { eb = Broadcast::make(eb, ea.type().lanes()); } if (eb.type().is_vector() && !ea.type().is_vector()) { ea = Broadcast::make(ea, eb.type().lanes()); } return Op::make(std::move(ea), std::move(eb)); } }; template std::ostream &operator<<(std::ostream &s, const BinOp &op) { s << "(" << op.a << " + " << op.b << ")"; return s; } template std::ostream &operator<<(std::ostream &s, const BinOp &op) { s << "(" << op.a << " - " << op.b << ")"; return s; } template std::ostream &operator<<(std::ostream &s, const BinOp &op) { s << "(" << op.a << " * " << op.b << ")"; return s; } template std::ostream &operator<<(std::ostream &s, const BinOp &op) { s << "(" << op.a << " / " << op.b << ")"; return s; } template std::ostream &operator<<(std::ostream &s, const BinOp &op) { s << "(" << op.a << " && " << op.b << ")"; return s; } template std::ostream &operator<<(std::ostream &s, const BinOp &op) { s << "(" << op.a << " || " << op.b << ")"; return s; } template std::ostream &operator<<(std::ostream &s, const BinOp &op) { s << "min(" << op.a << ", " << op.b << ")"; return s; } template std::ostream &operator<<(std::ostream &s, const BinOp &op) { s << "max(" << op.a << ", " << op.b << ")"; return s; } template std::ostream &operator<<(std::ostream &s, const CmpOp &op) { s << "(" << op.a << " <= " << op.b << ")"; return s; } template std::ostream &operator<<(std::ostream &s, const CmpOp &op) { s << "(" << op.a << " < " << op.b << ")"; return s; } template std::ostream &operator<<(std::ostream &s, const CmpOp &op) { s << "(" << op.a << " >= " << op.b << ")"; return s; } template std::ostream &operator<<(std::ostream &s, const CmpOp &op) { s << "(" << op.a << " > " << op.b << ")"; return s; } template std::ostream &operator<<(std::ostream &s, const CmpOp &op) { s << "(" << op.a << " == " << op.b << ")"; return s; } template std::ostream &operator<<(std::ostream &s, const CmpOp &op) { s << "(" << op.a << " != " << op.b << ")"; return s; 
} template std::ostream &operator<<(std::ostream &s, const BinOp &op) { s << "(" << op.a << " % " << op.b << ")"; return s; } template HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp { assert_is_lvalue_if_expr(); assert_is_lvalue_if_expr(); return {pattern_arg(a), pattern_arg(b)}; } template HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b)) { assert_is_lvalue_if_expr(); assert_is_lvalue_if_expr(); return IRMatcher::operator+(a, b); } template<> HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op(halide_type_t &t, int64_t a, int64_t b) noexcept { t.lanes |= ((t.bits >= 32) && add_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0; int dead_bits = 64 - t.bits; // Drop the high bits then sign-extend them back return int64_t((uint64_t(a) + uint64_t(b)) << dead_bits) >> dead_bits; } template<> HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op(halide_type_t &t, uint64_t a, uint64_t b) noexcept { uint64_t ones = (uint64_t)(-1); return (a + b) & (ones >> (64 - t.bits)); } template<> HALIDE_ALWAYS_INLINE double constant_fold_bin_op(halide_type_t &t, double a, double b) noexcept { return a + b; } template HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp { assert_is_lvalue_if_expr(); assert_is_lvalue_if_expr(); return {pattern_arg(a), pattern_arg(b)}; } template HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b)) { assert_is_lvalue_if_expr(); assert_is_lvalue_if_expr(); return IRMatcher::operator-(a, b); } template<> HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op(halide_type_t &t, int64_t a, int64_t b) noexcept { t.lanes |= ((t.bits >= 32) && sub_would_overflow(t.bits, a, b)) ? 
MatcherState::signed_integer_overflow : 0; // Drop the high bits then sign-extend them back int dead_bits = 64 - t.bits; return int64_t((uint64_t(a) - uint64_t(b)) << dead_bits) >> dead_bits; } template<> HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op(halide_type_t &t, uint64_t a, uint64_t b) noexcept { uint64_t ones = (uint64_t)(-1); return (a - b) & (ones >> (64 - t.bits)); } template<> HALIDE_ALWAYS_INLINE double constant_fold_bin_op(halide_type_t &t, double a, double b) noexcept { return a - b; } template HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp { assert_is_lvalue_if_expr(); assert_is_lvalue_if_expr(); return {pattern_arg(a), pattern_arg(b)}; } template HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b)) { assert_is_lvalue_if_expr(); assert_is_lvalue_if_expr(); return IRMatcher::operator*(a, b); } template<> HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op(halide_type_t &t, int64_t a, int64_t b) noexcept { t.lanes |= ((t.bits >= 32) && mul_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0; int dead_bits = 64 - t.bits; // Drop the high bits then sign-extend them back return int64_t((uint64_t(a) * uint64_t(b)) << dead_bits) >> dead_bits; } template<> HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op(halide_type_t &t, uint64_t a, uint64_t b) noexcept { uint64_t ones = (uint64_t)(-1); return (a * b) & (ones >> (64 - t.bits)); } template<> HALIDE_ALWAYS_INLINE double constant_fold_bin_op(halide_type_t &t, double a, double b) noexcept { return a * b; } template HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp { assert_is_lvalue_if_expr(); assert_is_lvalue_if_expr(); return {pattern_arg(a), pattern_arg(b)}; } template HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b)) { return IRMatcher::operator/(a, b); } template<> HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op
(halide_type_t &t, int64_t a, int64_t b) noexcept { return div_imp(a, b); } template<> HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op
(halide_type_t &t, uint64_t a, uint64_t b) noexcept { return div_imp(a, b); }

// NOTE(review): in this copy every `template` below has lost its parameter
// list, and explicit template arguments that begin with an identifier (the
// operator tags of constant_fold_bin_op / constant_fold_cmp_op, the BinOp /
// CmpOp return types) are missing. As written this is not valid C++; confirm
// the exact spellings against the upstream Halide IRMatch.h.

// Constant folding on doubles for the hook above; delegates to div_imp.
template<> HALIDE_ALWAYS_INLINE double constant_fold_bin_op(halide_type_t &t, double a, double b) noexcept { return div_imp(a, b); }

// ---------------------------------------------------------------------------
// Mod
// ---------------------------------------------------------------------------

// `a % b` on pattern operands: wraps each operand with pattern_arg and returns
// a BinOp pattern node. assert_is_lvalue_if_expr() is called once per operand
// (helper defined elsewhere; presumably a compile-time guard against passing
// temporary Exprs -- confirm).
template HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp {
    assert_is_lvalue_if_expr();
    assert_is_lvalue_if_expr();
    return {pattern_arg(a), pattern_arg(b)};
}

// Named spelling of `a % b`; forwards to IRMatcher::operator%.
template HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b)) {
    assert_is_lvalue_if_expr();
    assert_is_lvalue_if_expr();
    return IRMatcher::operator%(a, b);
}

// Constant folding for mod: every scalar representation delegates to mod_imp.
template<> HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op(halide_type_t &t, int64_t a, int64_t b) noexcept { return mod_imp(a, b); }
template<> HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op(halide_type_t &t, uint64_t a, uint64_t b) noexcept { return mod_imp(a, b); }
template<> HALIDE_ALWAYS_INLINE double constant_fold_bin_op(halide_type_t &t, double a, double b) noexcept { return mod_imp(a, b); }

// ---------------------------------------------------------------------------
// Min
// ---------------------------------------------------------------------------

// min(a, b) pattern builder; same shape as operator% above.
template HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp {
    assert_is_lvalue_if_expr();
    assert_is_lvalue_if_expr();
    return {pattern_arg(a), pattern_arg(b)};
}

// Constant folding for min uses std::min in each scalar representation.
template<> HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op(halide_type_t &t, int64_t a, int64_t b) noexcept { return std::min(a, b); }
template<> HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op(halide_type_t &t, uint64_t a, uint64_t b) noexcept { return std::min(a, b); }
template<> HALIDE_ALWAYS_INLINE double constant_fold_bin_op(halide_type_t &t, double a, double b) noexcept { return std::min(a, b); }

// ---------------------------------------------------------------------------
// Max
// ---------------------------------------------------------------------------

// max(a, b) pattern builder. Unlike min, the operands are passed through
// std::forward before pattern_arg.
template HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp {
    assert_is_lvalue_if_expr();
    assert_is_lvalue_if_expr();
    return {pattern_arg(std::forward(a)), pattern_arg(std::forward(b))};
}

// Constant folding for max uses std::max in each scalar representation.
template<> HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op(halide_type_t &t, int64_t a, int64_t b) noexcept { return std::max(a, b); }
template<> HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op(halide_type_t &t, uint64_t a, uint64_t b) noexcept { return std::max(a, b); }
template<> HALIDE_ALWAYS_INLINE double constant_fold_bin_op(halide_type_t &t, double a, double b) noexcept { return std::max(a, b); }

// ---------------------------------------------------------------------------
// Comparisons
//
// Each operator returns a CmpOp pattern node built from pattern_arg of both
// operands; unlike the arithmetic builders above, these do not call
// assert_is_lvalue_if_expr(). The named variants (lt, gt, le, ge, eq, ne)
// forward to the operator form. The constant_fold_cmp_op specializations
// return the comparison result converted to uint64_t, i.e. 0 or 1.
// ---------------------------------------------------------------------------

template HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp { return {pattern_arg(a), pattern_arg(b)}; }
template HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b)) { return IRMatcher::operator<(a, b); }
template<> HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op(int64_t a, int64_t b) noexcept { return a < b; }
template<> HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op(uint64_t a, uint64_t b) noexcept { return a < b; }
template<> HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op(double a, double b) noexcept { return a < b; }

template HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp { return {pattern_arg(a), pattern_arg(b)}; }
template HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b)) { return IRMatcher::operator>(a, b); }
template<> HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op(int64_t a, int64_t b) noexcept { return a > b; }
template<> HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op(uint64_t a, uint64_t b) noexcept { return a > b; }
template<> HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op(double a, double b) noexcept { return a > b; }

template HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp { return {pattern_arg(a), pattern_arg(b)}; }
template HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b)) { return IRMatcher::operator<=(a, b); }
template<> HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op(int64_t a, int64_t b) noexcept { return a <= b; }
template<> HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op(uint64_t a, uint64_t b) noexcept { return a <= b; }
template<> HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op(double a, double b) noexcept { return a <= b; }

template HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp { return {pattern_arg(a), pattern_arg(b)}; }
template HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b)) { return IRMatcher::operator>=(a, b); }
template<> HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op(int64_t a, int64_t b) noexcept { return a >= b; }
template<> HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op(uint64_t a, uint64_t b) noexcept { return a >= b; }
template<> HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op(double a, double b) noexcept { return a >= b; }

template HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp { return {pattern_arg(a), pattern_arg(b)}; }
template HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b)) { return IRMatcher::operator==(a, b); }
template<> HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op(int64_t a, int64_t b) noexcept { return a == b; }
template<> HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op(uint64_t a, uint64_t b) noexcept { return a == b; }
template<> HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op(double a, double b) noexcept { return a == b; }

template HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp { return {pattern_arg(a), pattern_arg(b)}; }
template HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b)) { return IRMatcher::operator!=(a, b); }
template<> HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op(int64_t a, int64_t b) noexcept { return a != b; }
template<> HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op(uint64_t a, uint64_t b) noexcept { return a != b; }
template<> HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op(double a, double b) noexcept { return a != b; }

// ---------------------------------------------------------------------------
// Logical or
// ---------------------------------------------------------------------------

// `a || b` builds a BinOp pattern node; or_op is the named spelling.
template HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp { return {pattern_arg(a), pattern_arg(b)}; }
template HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b)) { return IRMatcher::operator||(a, b); }
// Folding keeps only the low bit of (a | b), treating the operands as booleans.
template<> HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op(halide_type_t &t, int64_t a, int64_t b) noexcept { return (a | b) & 1; }
template<> HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op(halide_type_t &t, uint64_t a, uint64_t b) noexcept { return (a | b) & 1; } template<> HALIDE_ALWAYS_INLINE double constant_fold_bin_op(halide_type_t &t, double a, double b) noexcept { // Unreachable, as it would be a type mismatch. return 0; } template HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp { return {pattern_arg(a), pattern_arg(b)}; } template HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b)) { return IRMatcher::operator&&(a, b); } template<> HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op(halide_type_t &t, int64_t a, int64_t b) noexcept { return a & b & 1; } template<> HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op(halide_type_t &t, uint64_t a, uint64_t b) noexcept { return a & b & 1; } template<> HALIDE_ALWAYS_INLINE double constant_fold_bin_op(halide_type_t &t, double a, double b) noexcept { // Unreachable return 0; } constexpr inline uint32_t bitwise_or_reduce() { return 0; } template constexpr uint32_t bitwise_or_reduce(uint32_t first, Args... rest) { return first | bitwise_or_reduce(rest...); } constexpr inline bool and_reduce() { return true; } template constexpr bool and_reduce(bool first, Args... rest) { return first && and_reduce(rest...); } // TODO: this can be replaced with std::min() once we require C++14 or later constexpr int const_min(int a, int b) { return a < b ? 
a : b; } template struct Intrin { struct pattern_tag {}; Call::IntrinsicOp intrin; std::tuple args; static constexpr uint32_t binds = bitwise_or_reduce((bindings::mask)...); constexpr static IRNodeType min_node_type = IRNodeType::Call; constexpr static IRNodeType max_node_type = IRNodeType::Call; constexpr static bool canonical = and_reduce((Args::canonical)...); template::type> HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept { using T = decltype(std::get(args)); return (std::get(args).template match(*c.args[i].get(), state) && match_args::mask>(0, c, state)); } template HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept { return true; } template HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept { if (e.node_type != IRNodeType::Call) { return false; } const Call &c = (const Call &)e; return (c.is_intrinsic(intrin) && match_args<0, bound>(0, c, state)); } template::type> HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const { s << std::get(args); if (i + 1 < sizeof...(Args)) { s << ", "; } print_args(0, s); } template HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const { } HALIDE_ALWAYS_INLINE void print_args(std::ostream &s) const { print_args<0>(0, s); } HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const { Expr arg0 = std::get<0>(args).make(state, type_hint); if (intrin == Call::likely) { return likely(arg0); } else if (intrin == Call::likely_if_innermost) { return likely_if_innermost(arg0); } else if (intrin == Call::abs) { return abs(arg0); } Expr arg1 = std::get(args).make(state, type_hint); if (intrin == Call::absd) { return absd(arg0, arg1); } else if (intrin == Call::widening_add) { return widening_add(arg0, arg1); } else if (intrin == Call::widening_sub) { return widening_sub(arg0, arg1); } else if (intrin == Call::widening_mul) { return widening_mul(arg0, arg1); } else 
if (intrin == Call::saturating_add) { return saturating_add(arg0, arg1); } else if (intrin == Call::saturating_sub) { return saturating_sub(arg0, arg1); } else if (intrin == Call::halving_add) { return halving_add(arg0, arg1); } else if (intrin == Call::halving_sub) { return halving_sub(arg0, arg1); } else if (intrin == Call::rounding_halving_add) { return rounding_halving_add(arg0, arg1); } else if (intrin == Call::rounding_halving_sub) { return rounding_halving_sub(arg0, arg1); } else if (intrin == Call::shift_left) { return arg0 << arg1; } else if (intrin == Call::shift_right) { return arg0 >> arg1; } else if (intrin == Call::rounding_shift_left) { return rounding_shift_left(arg0, arg1); } else if (intrin == Call::rounding_shift_right) { return rounding_shift_right(arg0, arg1); } Expr arg2 = std::get(args).make(state, type_hint); if (intrin == Call::mul_shift_right) { return mul_shift_right(arg0, arg1, arg2); } else if (intrin == Call::rounding_mul_shift_right) { return rounding_mul_shift_right(arg0, arg1, arg2); } internal_error << "Unhandled intrinsic in IRMatcher: " << intrin; return Expr(); } constexpr static bool foldable = false; HALIDE_ALWAYS_INLINE Intrin(Call::IntrinsicOp intrin, Args... args) noexcept : intrin(intrin), args(args...) { } }; template std::ostream &operator<<(std::ostream &s, const Intrin &op) { s << op.intrin << "("; op.print_args(s); s << ")"; return s; } template HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... 
args) noexcept -> Intrin { return {intrinsic_op, pattern_arg(args)...}; } template auto widening_add(A &&a, B &&b) noexcept -> Intrin { return {Call::widening_add, pattern_arg(a), pattern_arg(b)}; } template auto widening_sub(A &&a, B &&b) noexcept -> Intrin { return {Call::widening_sub, pattern_arg(a), pattern_arg(b)}; } template auto widening_mul(A &&a, B &&b) noexcept -> Intrin { return {Call::widening_mul, pattern_arg(a), pattern_arg(b)}; } template auto saturating_add(A &&a, B &&b) noexcept -> Intrin { return {Call::saturating_add, pattern_arg(a), pattern_arg(b)}; } template auto saturating_sub(A &&a, B &&b) noexcept -> Intrin { return {Call::saturating_sub, pattern_arg(a), pattern_arg(b)}; } template auto halving_add(A &&a, B &&b) noexcept -> Intrin { return {Call::halving_add, pattern_arg(a), pattern_arg(b)}; } template auto halving_sub(A &&a, B &&b) noexcept -> Intrin { return {Call::halving_sub, pattern_arg(a), pattern_arg(b)}; } template auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin { return {Call::rounding_halving_add, pattern_arg(a), pattern_arg(b)}; } template auto rounding_halving_sub(A &&a, B &&b) noexcept -> Intrin { return {Call::rounding_halving_sub, pattern_arg(a), pattern_arg(b)}; } template auto shift_left(A &&a, B &&b) noexcept -> Intrin { return {Call::shift_left, pattern_arg(a), pattern_arg(b)}; } template auto shift_right(A &&a, B &&b) noexcept -> Intrin { return {Call::shift_right, pattern_arg(a), pattern_arg(b)}; } template auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin { return {Call::rounding_shift_left, pattern_arg(a), pattern_arg(b)}; } template auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin { return {Call::rounding_shift_right, pattern_arg(a), pattern_arg(b)}; } template auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin { return {Call::mul_shift_right, pattern_arg(a), pattern_arg(b), pattern_arg(c)}; } template auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin { 
return {Call::rounding_mul_shift_right, pattern_arg(a), pattern_arg(b), pattern_arg(c)}; } template struct NotOp { struct pattern_tag {}; A a; constexpr static uint32_t binds = bindings::mask; constexpr static IRNodeType min_node_type = IRNodeType::Not; constexpr static IRNodeType max_node_type = IRNodeType::Not; constexpr static bool canonical = A::canonical; template HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept { if (e.node_type != IRNodeType::Not) { return false; } const Not &op = (const Not &)e; return (a.template match(*op.a.get(), state)); } template HALIDE_ALWAYS_INLINE bool match(const NotOp &op, MatcherState &state) const noexcept { return a.template match(unwrap(op.a), state); } HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const { return Not::make(a.make(state, type_hint)); } constexpr static bool foldable = A::foldable; template HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { a.make_folded_const(val, ty, state); val.u.u64 = ~val.u.u64; val.u.u64 &= 1; } }; template HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp { assert_is_lvalue_if_expr(); return {pattern_arg(a)}; } template HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a)) { assert_is_lvalue_if_expr(); return IRMatcher::operator!(a); } template inline std::ostream &operator<<(std::ostream &s, const NotOp &op) { s << "!(" << op.a << ")"; return s; } template struct SelectOp { struct pattern_tag {}; C c; T t; F f; constexpr static uint32_t binds = bindings::mask | bindings::mask | bindings::mask; constexpr static IRNodeType min_node_type = IRNodeType::Select; constexpr static IRNodeType max_node_type = IRNodeType::Select; constexpr static bool canonical = C::canonical && T::canonical && F::canonical; template HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept { if 
(e.node_type != Select::_node_type) { return false; } const Select &op = (const Select &)e; return (c.template match(*op.condition.get(), state) && t.template match::mask>(*op.true_value.get(), state) && f.template match::mask | bindings::mask>(*op.false_value.get(), state)); } template HALIDE_ALWAYS_INLINE bool match(const SelectOp &instance, MatcherState &state) const noexcept { return (c.template match(unwrap(instance.c), state) && t.template match::mask>(unwrap(instance.t), state) && f.template match::mask | bindings::mask>(unwrap(instance.f), state)); } HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const { return Select::make(c.make(state, {}), t.make(state, type_hint), f.make(state, type_hint)); } constexpr static bool foldable = C::foldable && T::foldable && F::foldable; template HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { halide_scalar_value_t c_val, t_val, f_val; halide_type_t c_ty; c.make_folded_const(c_val, c_ty, state); if ((c_val.u.u64 & 1) == 1) { t.make_folded_const(val, ty, state); } else { f.make_folded_const(val, ty, state); } ty.lanes |= c_ty.lanes & MatcherState::special_values_mask; } }; template std::ostream &operator<<(std::ostream &s, const SelectOp &op) { s << "select(" << op.c << ", " << op.t << ", " << op.f << ")"; return s; } template HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp { assert_is_lvalue_if_expr(); assert_is_lvalue_if_expr(); assert_is_lvalue_if_expr(); return {pattern_arg(c), pattern_arg(t), pattern_arg(f)}; } template struct BroadcastOp { struct pattern_tag {}; A a; B lanes; constexpr static uint32_t binds = bindings::mask | bindings::mask; constexpr static IRNodeType min_node_type = IRNodeType::Broadcast; constexpr static IRNodeType max_node_type = IRNodeType::Broadcast; constexpr static bool canonical = A::canonical && B::canonical; template HALIDE_ALWAYS_INLINE bool match(const 
BaseExprNode &e, MatcherState &state) const noexcept { if (e.node_type == Broadcast::_node_type) { const Broadcast &op = (const Broadcast &)e; if (a.template match(*op.value.get(), state) && lanes.template match(op.lanes, state)) { return true; } } return false; } template HALIDE_ALWAYS_INLINE bool match(const BroadcastOp &op, MatcherState &state) const noexcept { return (a.template match(unwrap(op.a), state) && lanes.template match::mask>(unwrap(op.lanes), state)); } HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const { halide_scalar_value_t lanes_val; halide_type_t ty; lanes.make_folded_const(lanes_val, ty, state); int32_t l = (int32_t)lanes_val.u.i64; type_hint.lanes /= l; Expr val = a.make(state, type_hint); if (l == 1) { return val; } else { return Broadcast::make(std::move(val), l); } } constexpr static bool foldable = false; template HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { halide_scalar_value_t lanes_val; halide_type_t lanes_ty; lanes.make_folded_const(lanes_val, lanes_ty, state); uint16_t l = (uint16_t)lanes_val.u.i64; a.make_folded_const(val, ty, state); ty.lanes = l | (ty.lanes & MatcherState::special_values_mask); } }; template inline std::ostream &operator<<(std::ostream &s, const BroadcastOp &op) { s << "broadcast(" << op.a << ", " << op.lanes << ")"; return s; } template HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp { assert_is_lvalue_if_expr(); return {pattern_arg(a), pattern_arg(lanes)}; } template struct RampOp { struct pattern_tag {}; A a; B b; C lanes; constexpr static uint32_t binds = bindings::mask | bindings::mask | bindings::mask; constexpr static IRNodeType min_node_type = IRNodeType::Ramp; constexpr static IRNodeType max_node_type = IRNodeType::Ramp; constexpr static bool canonical = A::canonical && B::canonical && C::canonical; template HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, 
MatcherState &state) const noexcept { if (e.node_type != Ramp::_node_type) { return false; } const Ramp &op = (const Ramp &)e; if (a.template match(*op.base.get(), state) && b.template match::mask>(*op.stride.get(), state) && lanes.template match::mask | bindings::mask>(op.lanes, state)) { return true; } else { return false; } } template HALIDE_ALWAYS_INLINE bool match(const RampOp &op, MatcherState &state) const noexcept { return (a.template match(unwrap(op.a), state) && b.template match::mask>(unwrap(op.b), state) && lanes.template match::mask | bindings::mask>(unwrap(op.lanes), state)); } HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const { halide_scalar_value_t lanes_val; halide_type_t ty; lanes.make_folded_const(lanes_val, ty, state); int32_t l = (int32_t)lanes_val.u.i64; type_hint.lanes /= l; Expr ea, eb; eb = b.make(state, type_hint); ea = a.make(state, eb.type()); return Ramp::make(ea, eb, l); } constexpr static bool foldable = false; }; template std::ostream &operator<<(std::ostream &s, const RampOp &op) { s << "ramp(" << op.a << ", " << op.b << ", " << op.lanes << ")"; return s; } template HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp { assert_is_lvalue_if_expr(); assert_is_lvalue_if_expr(); assert_is_lvalue_if_expr(); return {pattern_arg(a), pattern_arg(b), pattern_arg(c)}; } template struct VectorReduceOp { struct pattern_tag {}; A a; B lanes; constexpr static uint32_t binds = bindings::mask; constexpr static IRNodeType min_node_type = IRNodeType::VectorReduce; constexpr static IRNodeType max_node_type = IRNodeType::VectorReduce; constexpr static bool canonical = A::canonical; template HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept { if (e.node_type == VectorReduce::_node_type) { const VectorReduce &op = (const VectorReduce &)e; if (op.op == reduce_op && a.template match(*op.value.get(), state) && lanes.template match::mask>(op.type.lanes(), state)) { 
return true; } } return false; } template HALIDE_ALWAYS_INLINE bool match(const VectorReduceOp &op, MatcherState &state) const noexcept { return (reduce_op == reduce_op_2 && a.template match(unwrap(op.a), state) && lanes.template match::mask>(unwrap(op.lanes), state)); } HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const { halide_scalar_value_t lanes_val; halide_type_t ty; lanes.make_folded_const(lanes_val, ty, state); int l = (int)lanes_val.u.i64; return VectorReduce::make(reduce_op, a.make(state, type_hint), l); } constexpr static bool foldable = false; }; template inline std::ostream &operator<<(std::ostream &s, const VectorReduceOp &op) { s << "vector_reduce(" << reduce_op << ", " << op.a << ", " << op.lanes << ")"; return s; } template HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp { assert_is_lvalue_if_expr(); return {pattern_arg(a), pattern_arg(lanes)}; } template HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp { assert_is_lvalue_if_expr(); return {pattern_arg(a), pattern_arg(lanes)}; } template HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp { assert_is_lvalue_if_expr(); return {pattern_arg(a), pattern_arg(lanes)}; } template HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp { assert_is_lvalue_if_expr(); return {pattern_arg(a), pattern_arg(lanes)}; } template HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp { assert_is_lvalue_if_expr(); return {pattern_arg(a), pattern_arg(lanes)}; } template struct NegateOp { struct pattern_tag {}; A a; constexpr static uint32_t binds = bindings::mask; constexpr static IRNodeType min_node_type = IRNodeType::Sub; constexpr static IRNodeType max_node_type = IRNodeType::Sub; constexpr static bool canonical = A::canonical; template HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept { if (e.node_type != Sub::_node_type) { return 
false; } const Sub &op = (const Sub &)e; return (a.template match(*op.b.get(), state) && is_const_zero(op.a)); } template HALIDE_ALWAYS_INLINE bool match(NegateOp &&p, MatcherState &state) const noexcept { return a.template match(unwrap(p.a), state); } HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const { Expr ea = a.make(state, type_hint); Expr z = make_zero(ea.type()); return Sub::make(std::move(z), std::move(ea)); } constexpr static bool foldable = A::foldable; template HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { a.make_folded_const(val, ty, state); int dead_bits = 64 - ty.bits; switch (ty.code) { case halide_type_int: if (ty.bits >= 32 && val.u.u64 && (val.u.u64 << (65 - ty.bits)) == 0) { // Trying to negate the most negative signed int for a no-overflow type. ty.lanes |= MatcherState::signed_integer_overflow; } else { // Negate, drop the high bits, and then sign-extend them back val.u.i64 = int64_t(uint64_t(-val.u.i64) << dead_bits) >> dead_bits; } break; case halide_type_uint: val.u.u64 = ((-val.u.u64) << dead_bits) >> dead_bits; break; case halide_type_float: case halide_type_bfloat: val.u.f64 = -val.u.f64; break; default: // unreachable ; } } }; template std::ostream &operator<<(std::ostream &s, const NegateOp &op) { s << "-" << op.a; return s; } template HALIDE_ALWAYS_INLINE auto operator-(A &&a) noexcept -> NegateOp { assert_is_lvalue_if_expr(); return {pattern_arg(a)}; } template HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a)) { assert_is_lvalue_if_expr(); return IRMatcher::operator-(a); } template struct CastOp { struct pattern_tag {}; Type t; A a; constexpr static uint32_t binds = bindings::mask; constexpr static IRNodeType min_node_type = IRNodeType::Cast; constexpr static IRNodeType max_node_type = IRNodeType::Cast; constexpr static bool canonical = A::canonical; template HALIDE_ALWAYS_INLINE bool 
match(const BaseExprNode &e, MatcherState &state) const noexcept { if (e.node_type != Cast::_node_type) { return false; } const Cast &op = (const Cast &)e; return (e.type == t && a.template match(*op.value.get(), state)); } template HALIDE_ALWAYS_INLINE bool match(const CastOp &op, MatcherState &state) const noexcept { return t == op.t && a.template match(unwrap(op.a), state); } HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const { return cast(t, a.make(state, {})); } constexpr static bool foldable = false; }; template std::ostream &operator<<(std::ostream &s, const CastOp &op) { s << "cast(" << op.t << ", " << op.a << ")"; return s; } template HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp { assert_is_lvalue_if_expr(); return {t, pattern_arg(a)}; } template struct Fold { struct pattern_tag {}; A a; constexpr static uint32_t binds = bindings::mask; constexpr static IRNodeType min_node_type = IRNodeType::IntImm; constexpr static IRNodeType max_node_type = IRNodeType::FloatImm; constexpr static bool canonical = true; HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept { halide_scalar_value_t c; halide_type_t ty = type_hint; a.make_folded_const(c, ty, state); // The result of the fold may have an underspecified type // (e.g. because it's from an int literal). Make the type code // and bits match the required type, if there is one (we can // tell from the bits field). 
if (type_hint.bits) { if (((int)ty.code == (int)halide_type_int) && ((int)type_hint.code == (int)halide_type_float)) { int64_t x = c.u.i64; c.u.f64 = (double)x; } ty.code = type_hint.code; ty.bits = type_hint.bits; } Expr e = make_const_expr(c, ty); return e; } constexpr static bool foldable = A::foldable; template HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { a.make_folded_const(val, ty, state); } }; template HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold { assert_is_lvalue_if_expr(); return {pattern_arg(a)}; } template std::ostream &operator<<(std::ostream &s, const Fold &op) { s << "fold(" << op.a << ")"; return s; } template struct Overflows { struct pattern_tag {}; A a; constexpr static uint32_t binds = bindings::mask; // This rule is a predicate, so it always evaluates to a boolean, // which has IRNodeType UIntImm constexpr static IRNodeType min_node_type = IRNodeType::UIntImm; constexpr static IRNodeType max_node_type = IRNodeType::UIntImm; constexpr static bool canonical = true; constexpr static bool foldable = A::foldable; template HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { a.make_folded_const(val, ty, state); ty.code = halide_type_uint; ty.bits = 64; val.u.u64 = (ty.lanes & MatcherState::special_values_mask) != 0; ty.lanes = 1; } }; template HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows { assert_is_lvalue_if_expr(); return {pattern_arg(a)}; } template std::ostream &operator<<(std::ostream &s, const Overflows &op) { s << "overflows(" << op.a << ")"; return s; } struct Overflow { struct pattern_tag {}; constexpr static uint32_t binds = 0; // Overflow is an intrinsic, represented as a Call node constexpr static IRNodeType min_node_type = IRNodeType::Call; constexpr static IRNodeType max_node_type = IRNodeType::Call; constexpr static bool canonical = true; template 
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept { if (e.node_type != Call::_node_type) { return false; } const Call &op = (const Call &)e; return (op.is_intrinsic(Call::signed_integer_overflow)); } HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const { type_hint.lanes |= MatcherState::signed_integer_overflow; return make_const_special_expr(type_hint); } constexpr static bool foldable = true; HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { val.u.u64 = 0; ty.lanes |= MatcherState::signed_integer_overflow; } }; inline std::ostream &operator<<(std::ostream &s, const Overflow &op) { s << "overflow()"; return s; } template struct IsConst { struct pattern_tag {}; constexpr static uint32_t binds = bindings::mask; // This rule is a boolean-valued predicate. Bools have type UIntImm. constexpr static IRNodeType min_node_type = IRNodeType::UIntImm; constexpr static IRNodeType max_node_type = IRNodeType::UIntImm; constexpr static bool canonical = true; A a; bool check_v; int64_t v; constexpr static bool foldable = true; template HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept { Expr e = a.make(state, {}); ty.code = halide_type_uint; ty.bits = 64; ty.lanes = 1; if (check_v) { val.u.u64 = ::Halide::Internal::is_const(e, v) ? 1 : 0; } else { val.u.u64 = ::Halide::Internal::is_const(e) ? 
1 : 0; } } }; template HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst { assert_is_lvalue_if_expr(); return {pattern_arg(a), false, 0}; } template HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst { assert_is_lvalue_if_expr(); return {pattern_arg(a), true, value}; } template std::ostream &operator<<(std::ostream &s, const IsConst &op) { if (op.check_v) { s << "is_const(" << op.a << ")"; } else { s << "is_const(" << op.a << ", " << op.v << ")"; } return s; } template struct CanProve { struct pattern_tag {}; A a; Prover *prover; // An existing simplifying mutator constexpr static uint32_t binds = bindings::mask; // This rule is a boolean-valued predicate. Bools have type UIntImm. constexpr static IRNodeType min_node_type = IRNodeType::UIntImm; constexpr static IRNodeType max_node_type = IRNodeType::UIntImm; constexpr static bool canonical = true; constexpr static bool foldable = true; // Includes a raw call to an inlined make method, so don't inline. HALIDE_NEVER_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const { Expr condition = a.make(state, {}); condition = prover->mutate(condition, nullptr); val.u.u64 = is_const_one(condition); ty.code = halide_type_uint; ty.bits = 1; ty.lanes = condition.type().lanes(); } }; template HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve { assert_is_lvalue_if_expr(); return {pattern_arg(a), p}; } template std::ostream &operator<<(std::ostream &s, const CanProve &op) { s << "can_prove(" << op.a << ")"; return s; } template struct IsFloat { struct pattern_tag {}; A a; constexpr static uint32_t binds = bindings::mask; // This rule is a boolean-valued predicate. Bools have type UIntImm. 
constexpr static IRNodeType min_node_type = IRNodeType::UIntImm; constexpr static IRNodeType max_node_type = IRNodeType::UIntImm; constexpr static bool canonical = true; constexpr static bool foldable = true; HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const { // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method. Type t = a.make(state, {}).type(); val.u.u64 = t.is_float(); ty.code = halide_type_uint; ty.bits = 1; ty.lanes = t.lanes(); } }; template HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat { assert_is_lvalue_if_expr(); return {pattern_arg(a)}; } template std::ostream &operator<<(std::ostream &s, const IsFloat &op) { s << "is_float(" << op.a << ")"; return s; } template struct IsInt { struct pattern_tag {}; A a; int bits; constexpr static uint32_t binds = bindings::mask; // This rule is a boolean-valued predicate. Bools have type UIntImm. constexpr static IRNodeType min_node_type = IRNodeType::UIntImm; constexpr static IRNodeType max_node_type = IRNodeType::UIntImm; constexpr static bool canonical = true; constexpr static bool foldable = true; HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const { // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method. 
Type t = a.make(state, {}).type(); val.u.u64 = t.is_int() && (bits == 0 || t.bits() == bits); ty.code = halide_type_uint; ty.bits = 1; ty.lanes = t.lanes(); } }; template HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits = 0) noexcept -> IsInt { assert_is_lvalue_if_expr(); return {pattern_arg(a), bits}; } template std::ostream &operator<<(std::ostream &s, const IsInt &op) { s << "is_int(" << op.a; if (op.bits > 0) { s << ", " << op.bits; } s << ")"; return s; } template struct IsUInt { struct pattern_tag {}; A a; int bits; constexpr static uint32_t binds = bindings::mask; // This rule is a boolean-valued predicate. Bools have type UIntImm. constexpr static IRNodeType min_node_type = IRNodeType::UIntImm; constexpr static IRNodeType max_node_type = IRNodeType::UIntImm; constexpr static bool canonical = true; constexpr static bool foldable = true; HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const { // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method. Type t = a.make(state, {}).type(); val.u.u64 = t.is_uint() && (bits == 0 || t.bits() == bits); ty.code = halide_type_uint; ty.bits = 1; ty.lanes = t.lanes(); } }; template HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits = 0) noexcept -> IsUInt { assert_is_lvalue_if_expr(); return {pattern_arg(a), bits}; } template std::ostream &operator<<(std::ostream &s, const IsUInt &op) { s << "is_uint(" << op.a; if (op.bits > 0) { s << ", " << op.bits; } s << ")"; return s; } template struct IsScalar { struct pattern_tag {}; A a; constexpr static uint32_t binds = bindings::mask; // This rule is a boolean-valued predicate. Bools have type UIntImm. 
constexpr static IRNodeType min_node_type = IRNodeType::UIntImm; constexpr static IRNodeType max_node_type = IRNodeType::UIntImm; constexpr static bool canonical = true; constexpr static bool foldable = true; HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const { // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method. Type t = a.make(state, {}).type(); val.u.u64 = t.is_scalar(); ty.code = halide_type_uint; ty.bits = 1; ty.lanes = t.lanes(); } }; template HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar { assert_is_lvalue_if_expr(); return {pattern_arg(a)}; } template std::ostream &operator<<(std::ostream &s, const IsScalar &op) { s << "is_scalar(" << op.a << ")"; return s; } template struct IsMaxValue { struct pattern_tag {}; A a; constexpr static uint32_t binds = bindings::mask; // This rule is a boolean-valued predicate. Bools have type UIntImm. constexpr static IRNodeType min_node_type = IRNodeType::UIntImm; constexpr static IRNodeType max_node_type = IRNodeType::UIntImm; constexpr static bool canonical = true; constexpr static bool foldable = true; HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const { // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method. 
a.make_folded_const(val, ty, state); const uint64_t max_bits = (uint64_t)(-1) >> (64 - ty.bits + (ty.code == halide_type_int)); if (ty.code == halide_type_uint || ty.code == halide_type_int) { val.u.u64 = (val.u.u64 == max_bits); } else { val.u.u64 = 0; } ty.code = halide_type_uint; ty.bits = 1; } }; template HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue { assert_is_lvalue_if_expr(); return {pattern_arg(a)}; } template std::ostream &operator<<(std::ostream &s, const IsMaxValue &op) { s << "is_max_value(" << op.a << ")"; return s; } template struct IsMinValue { struct pattern_tag {}; A a; constexpr static uint32_t binds = bindings::mask; // This rule is a boolean-valued predicate. Bools have type UIntImm. constexpr static IRNodeType min_node_type = IRNodeType::UIntImm; constexpr static IRNodeType max_node_type = IRNodeType::UIntImm; constexpr static bool canonical = true; constexpr static bool foldable = true; HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const { // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method. a.make_folded_const(val, ty, state); if (ty.code == halide_type_int) { const uint64_t min_bits = (uint64_t)(-1) << (ty.bits - 1); val.u.u64 = (val.u.u64 == min_bits); } else if (ty.code == halide_type_uint) { val.u.u64 = (val.u.u64 == 0); } else { val.u.u64 = 0; } ty.code = halide_type_uint; ty.bits = 1; } }; template HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue { assert_is_lvalue_if_expr(); return {pattern_arg(a)}; } template std::ostream &operator<<(std::ostream &s, const IsMinValue &op) { s << "is_min_value(" << op.a << ")"; return s; } // Verify properties of each rewrite rule. Currently just fuzz tests them. 
template::type::foldable && std::decay::type::foldable>::type> HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred, halide_type_t wildcard_type, halide_type_t output_type) noexcept { // We only validate the rules in the scalar case wildcard_type.lanes = output_type.lanes = 1; // Track which types this rule has been tested for before static std::set tested; if (!tested.insert(reinterpret_bits(wildcard_type)).second) { return; } // Print it in a form where it can be piped into a python/z3 validator debug(0) << "validate('" << before << "', '" << after << "', '" << pred << "', " << Type(wildcard_type) << ", " << Type(output_type) << ")\n"; // Substitute some random constants into the before and after // expressions and see if the rule holds true. This should catch // silly errors, but not necessarily corner cases. static std::mt19937_64 rng(0); MatcherState state; Expr exprs[max_wild]; for (int trials = 0; trials < 100; trials++) { // We want to test small constants more frequently than // large ones, otherwise we'll just get coverage of // overflow rules. 
int shift = (int)(rng() & (wildcard_type.bits - 1)); for (int i = 0; i < max_wild; i++) { // Bind all the exprs and constants switch (wildcard_type.code) { case halide_type_uint: { // Normalize to the type's range by adding zero uint64_t val = constant_fold_bin_op(wildcard_type, (uint64_t)rng() >> shift, 0); state.set_bound_const(i, val, wildcard_type); val = constant_fold_bin_op(wildcard_type, (uint64_t)rng() >> shift, 0); exprs[i] = make_const(wildcard_type, val); state.set_binding(i, *exprs[i].get()); } break; case halide_type_int: { int64_t val = constant_fold_bin_op(wildcard_type, (int64_t)rng() >> shift, 0); state.set_bound_const(i, val, wildcard_type); val = constant_fold_bin_op(wildcard_type, (int64_t)rng() >> shift, 0); exprs[i] = make_const(wildcard_type, val); } break; case halide_type_float: case halide_type_bfloat: { // Use a very narrow range of precise floats, so // that none of the rules a human is likely to // write have instabilities. double val = ((int64_t)(rng() & 15) - 8) / 2.0; state.set_bound_const(i, val, wildcard_type); val = ((int64_t)(rng() & 15) - 8) / 2.0; exprs[i] = make_const(wildcard_type, val); } break; default: return; // Don't care about handles } state.set_binding(i, *exprs[i].get()); } halide_scalar_value_t val_pred, val_before, val_after; halide_type_t type = output_type; if (!evaluate_predicate(pred, state)) { continue; } before.make_folded_const(val_before, type, state); uint16_t lanes = type.lanes; after.make_folded_const(val_after, type, state); lanes |= type.lanes; if (lanes & MatcherState::special_values_mask) { continue; } bool ok = true; switch (output_type.code) { case halide_type_uint: // Compare normalized representations ok &= (constant_fold_bin_op(output_type, val_before.u.u64, 0) == constant_fold_bin_op(output_type, val_after.u.u64, 0)); break; case halide_type_int: ok &= (constant_fold_bin_op(output_type, val_before.u.i64, 0) == constant_fold_bin_op(output_type, val_after.u.i64, 0)); break; case 
halide_type_float: case halide_type_bfloat: { double error = std::abs(val_before.u.f64 - val_after.u.f64); // We accept an equal bit pattern (e.g. inf vs inf), // a small floating point difference, or turning a nan into not-a-nan. ok &= (error < 0.01 || val_before.u.u64 == val_after.u.u64 || std::isnan(val_before.u.f64)); break; } default: return; } if (!ok) { debug(0) << "Fails with values:\n"; for (int i = 0; i < max_wild; i++) { halide_scalar_value_t val; state.get_bound_const(i, val, wildcard_type); debug(0) << " c" << i << ": " << make_const_expr(val, wildcard_type) << "\n"; } for (int i = 0; i < max_wild; i++) { debug(0) << " _" << i << ": " << Expr(state.get_binding(i)) << "\n"; } debug(0) << " Before: " << make_const_expr(val_before, output_type) << "\n"; debug(0) << " After: " << make_const_expr(val_after, output_type) << "\n"; debug(0) << val_before.u.u64 << " " << val_after.u.u64 << "\n"; internal_error; } } } template::type::foldable && std::decay::type::foldable)>::type> HALIDE_ALWAYS_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred, halide_type_t, halide_type_t, int dummy = 0) noexcept { // We can't verify rewrite rules that can't be constant-folded. } HALIDE_ALWAYS_INLINE bool evaluate_predicate(bool x, MatcherState &) noexcept { return x; } template::type> HALIDE_ALWAYS_INLINE bool evaluate_predicate(Pattern p, MatcherState &state) { halide_scalar_value_t c; halide_type_t ty = halide_type_of(); p.make_folded_const(c, ty, state); // Overflow counts as a failed predicate return (c.u.u64 != 0) && ((ty.lanes & MatcherState::special_values_mask) == 0); } // #defines for testing // Print all successful or failed matches #define HALIDE_DEBUG_MATCHED_RULES 0 #define HALIDE_DEBUG_UNMATCHED_RULES 0 // Set to true if you want to fuzz test every rewrite passed to // operator() to ensure the input and the output have the same value // for lots of random values of the wildcards. Run // correctness_simplify with this on. 
#define HALIDE_FUZZ_TEST_RULES 0 template struct Rewriter { Instance instance; Expr result; MatcherState state; halide_type_t output_type, wildcard_type; bool validate; HALIDE_ALWAYS_INLINE Rewriter(Instance instance, halide_type_t ot, halide_type_t wt) : instance(std::move(instance)), output_type(ot), wildcard_type(wt) { } template HALIDE_NEVER_INLINE void build_replacement(After after) { result = after.make(state, output_type); } template::type, typename = typename enable_if_pattern::type> HALIDE_ALWAYS_INLINE bool operator()(Before before, After after) { static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values"); static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form"); static_assert(After::canonical, "RHS of rewrite rule should be in canonical form"); #if HALIDE_FUZZ_TEST_RULES fuzz_test_rule(before, after, true, wildcard_type, output_type); #endif if (before.template match<0>(unwrap(instance), state)) { build_replacement(after); #if HALIDE_DEBUG_MATCHED_RULES debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n"; #endif return true; } else { #if HALIDE_DEBUG_UNMATCHED_RULES debug(0) << instance << " does not match " << before << "\n"; #endif return false; } } template::type> HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept { static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form"); if (before.template match<0>(unwrap(instance), state)) { result = after; #if HALIDE_DEBUG_MATCHED_RULES debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n"; #endif return true; } else { #if HALIDE_DEBUG_UNMATCHED_RULES debug(0) << instance << " does not match " << before << "\n"; #endif return false; } } template::type> HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept { static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form"); #if 
HALIDE_FUZZ_TEST_RULES fuzz_test_rule(before, IntLiteral(after), true, wildcard_type, output_type); #endif if (before.template match<0>(unwrap(instance), state)) { result = make_const(output_type, after); #if HALIDE_DEBUG_MATCHED_RULES debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n"; #endif return true; } else { #if HALIDE_DEBUG_UNMATCHED_RULES debug(0) << instance << " does not match " << before << "\n"; #endif return false; } } template::type, typename = typename enable_if_pattern::type, typename = typename enable_if_pattern::type> HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred) { static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold"); static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values"); static_assert((Before::binds & Predicate::binds) == Predicate::binds, "Rule predicate uses unbound values"); static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form"); static_assert(After::canonical, "RHS of rewrite rule should be in canonical form"); #if HALIDE_FUZZ_TEST_RULES fuzz_test_rule(before, after, pred, wildcard_type, output_type); #endif if (before.template match<0>(unwrap(instance), state) && evaluate_predicate(pred, state)) { build_replacement(after); #if HALIDE_DEBUG_MATCHED_RULES debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n"; #endif return true; } else { #if HALIDE_DEBUG_UNMATCHED_RULES debug(0) << instance << " does not match " << before << "\n"; #endif return false; } } template::type, typename = typename enable_if_pattern::type> HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred) { static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold"); static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form"); if 
(before.template match<0>(unwrap(instance), state) && evaluate_predicate(pred, state)) { result = after; #if HALIDE_DEBUG_MATCHED_RULES debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n"; #endif return true; } else { #if HALIDE_DEBUG_UNMATCHED_RULES debug(0) << instance << " does not match " << before << "\n"; #endif return false; } } template::type, typename = typename enable_if_pattern::type> HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred) { static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold"); static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form"); #if HALIDE_FUZZ_TEST_RULES fuzz_test_rule(before, IntLiteral(after), pred, wildcard_type, output_type); #endif if (before.template match<0>(unwrap(instance), state) && evaluate_predicate(pred, state)) { result = make_const(output_type, after); #if HALIDE_DEBUG_MATCHED_RULES debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n"; #endif return true; } else { #if HALIDE_DEBUG_UNMATCHED_RULES debug(0) << instance << " does not match " << before << "\n"; #endif return false; } } }; /** Construct a rewriter for the given instance, which may be a pattern * with concrete expressions as leaves, or just an expression. The * second optional argument (wildcard_type) is a hint as to what the * type of the wildcards is likely to be. If omitted it uses the same * type as the expression itself. They are not required to be this * type, but the rule will only be tested for wildcards of that type * when testing is enabled. * * The rewriter can be used to check to see if the instance is one of * some number of patterns and if so rewrite it into another form, * using its operator() method. See Simplify.cpp for a bunch of * example usage. 
* * Important: Any Exprs in patterns are captured by reference, not by * value, so ensure they outlive the rewriter. */ // @{ template::type> HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter { return {pattern_arg(instance), output_type, wildcard_type}; } template::type> HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type) noexcept -> Rewriter { return {pattern_arg(instance), output_type, output_type}; } HALIDE_ALWAYS_INLINE auto rewriter(const Expr &e, halide_type_t wildcard_type) noexcept -> Rewriter { return {pattern_arg(e), e.type(), wildcard_type}; } HALIDE_ALWAYS_INLINE auto rewriter(const Expr &e) noexcept -> Rewriter { return {pattern_arg(e), e.type(), e.type()}; } // @} } // namespace IRMatcher } // namespace Internal } // namespace Halide #endif