#include "IREquality.h"
#include "IROperator.h"
#include "IRVisitor.h"

#include <cmath>
#include <type_traits>

namespace Halide {
namespace Internal {

using std::string;
using std::vector;

namespace {

/** The class that does the work of comparing two IR nodes.
 *
 * The comparison is lexicographic: the first difference encountered is
 * recorded in `result`, and from then on every compare_* helper is a
 * no-op that simply returns the recorded value. */
class IRComparer : public IRVisitor {
public:
    /** Different possible results of a comparison. Unknown should
     * only occur internally due to a cache miss. */
    enum CmpResult { Unknown,
                     Equal,
                     LessThan,
                     GreaterThan };

    /** The result of the comparison. Should be Equal, LessThan, or GreaterThan. */
    CmpResult result = Equal;

    /** Compare two expressions or statements and return the
     * result. If an earlier comparison has already found a
     * difference (result != Equal), that result is returned
     * immediately without inspecting the arguments. */
    // @{
    CmpResult compare_expr(const Expr &a, const Expr &b);
    CmpResult compare_stmt(const Stmt &a, const Stmt &b);
    // @}

    /** If the expressions you're comparing may contain many repeated
     * subexpressions, it's worth passing in a cache to use.
     * Currently this is only done in common-subexpression
     * elimination. */
    IRComparer(IRCompareCache *c = nullptr)
        : cache(c) {
    }

private:
    // Left-hand operand of the comparison in progress. compare_expr /
    // compare_stmt store `a` here and then visit `b`, so each visit
    // method receives the right-hand node as `op` and reads the
    // left-hand node from these members.
    Expr expr;
    Stmt stmt;

    // Optional cache of Expr pairs already found equal; may be null.
    IRCompareCache *cache;

    CmpResult compare_names(const std::string &a, const std::string &b);
    CmpResult compare_types(Type a, Type b);
    CmpResult compare_expr_vector(const std::vector<Expr> &a, const std::vector<Expr> &b);

    // Compare two things that already have a well-defined operator<
    template<typename T>
    CmpResult compare_scalar(T a, T b);

    void visit(const IntImm *) override;
    void visit(const UIntImm *) override;
    void visit(const FloatImm *) override;
    void visit(const StringImm *) override;
    void visit(const Cast *) override;
    void visit(const Reinterpret *) override;
    void visit(const Variable *) override;
    void visit(const Add *) override;
    void visit(const Sub *) override;
    void visit(const Mul *) override;
    void visit(const Div *) override;
    void visit(const Mod *) override;
    void visit(const Min *) override;
    void visit(const Max *) override;
    void visit(const EQ *) override;
    void visit(const NE *) override;
    void visit(const LT *) override;
    void visit(const LE *) override;
    void visit(const GT *) override;
    void visit(const GE *) override;
    void visit(const And *) override;
    void visit(const Or *) override;
    void visit(const Not *) override;
    void visit(const Select *) override;
    void visit(const Load *) override;
    void visit(const Ramp *) override;
    void visit(const Broadcast *) override;
    void visit(const Call *) override;
    void visit(const Let *) override;
    void visit(const LetStmt *) override;
    void visit(const AssertStmt *) override;
    void visit(const ProducerConsumer *) override;
    void visit(const For *) override;
    void visit(const Acquire *) override;
    void visit(const Store *) override;
    void visit(const Provide *) override;
    void visit(const Allocate *) override;
    void visit(const Free *) override;
    void visit(const Realize *) override;
    void visit(const Block *) override;
    void visit(const Fork *) override;
    void visit(const IfThenElse *) override;
    void visit(const Evaluate *) override;
    void visit(const Shuffle *) override;
    void visit(const Prefetch *) override;
    void visit(const Atomic *) override;
    void visit(const VectorReduce *) override;
};

template<typename T>
IRComparer::CmpResult IRComparer::compare_scalar(T a, T b) {
    if (result != Equal) {
        return result;
    }

    if constexpr (std::is_floating_point_v<T>) {
        // NaNs are equal to each other and less than non-nans.
        // Without this special case both `a < b` and `a > b` below are
        // false for a NaN, so a NaN would compare Equal to anything.
        if (std::isnan(a) && std::isnan(b)) {
            result = Equal;
            return result;
        }
        if (std::isnan(a)) {
            result = LessThan;
            return result;
        }
        if (std::isnan(b)) {
            result = GreaterThan;
            return result;
        }
    }

    if (a < b) {
        result = LessThan;
    } else if (a > b) {
        result = GreaterThan;
    }
    return result;
}

IRComparer::CmpResult IRComparer::compare_expr(const Expr &a, const Expr &b) {
    if (result != Equal) {
        return result;
    }

    if (a.same_as(b)) {
        result = Equal;
        return result;
    }

    // Undefined values are equal to each other and less than defined values
    if (!a.defined() && !b.defined()) {
        result = Equal;
        return result;
    }

    if (!a.defined()) {
        result = LessThan;
        return result;
    }

    if (!b.defined()) {
        result = GreaterThan;
        return result;
    }

    // If in the future we have hashes for Exprs, this is a good place
    // to compare the hashes:
    // if (compare_scalar(a.hash(), b.hash()) != Equal) {
    //     return result;
    // }

    if (compare_scalar(a->node_type, b->node_type) != Equal) {
        return result;
    }

    if (compare_types(a.type(), b.type()) != Equal) {
        return result;
    }

    // Check the cache - perhaps these exprs have already been compared and found equal.
    if (cache && cache->contains(a, b)) {
        result = Equal;
        return result;
    }

    // The node types matched above, so the visit method selected for
    // `b` can read `a` back out of `expr` as the same node type.
    expr = a;
    b.accept(this);

    if (cache && result == Equal) {
        cache->insert(a, b);
    }

    return result;
}

IRComparer::CmpResult IRComparer::compare_stmt(const Stmt &a, const Stmt &b) {
    if (result != Equal) {
        return result;
    }

    if (a.same_as(b)) {
        result = Equal;
        return result;
    }

    // Undefined statements are equal to each other and less than defined ones.
    if (!a.defined() && !b.defined()) {
        result = Equal;
        return result;
    }

    if (!a.defined()) {
        result = LessThan;
        return result;
    }

    if (!b.defined()) {
        result = GreaterThan;
        return result;
    }

    if (compare_scalar(a->node_type, b->node_type) != Equal) {
        return result;
    }

    // As in compare_expr: the node types match, so the visit method
    // for `b` reads `a` back out of `stmt`.
    stmt = a;
    b.accept(this);

    return result;
}

IRComparer::CmpResult IRComparer::compare_types(Type a, Type b) {
    if (result != Equal) {
        return result;
    }

    // Each compare_* call is a no-op once a difference has been
    // recorded, so these can be chained without checks in between.
    compare_scalar(a.code(), b.code());
    compare_scalar(a.bits(), b.bits());
    compare_scalar(a.lanes(), b.lanes());

    if (result != Equal) {
        return result;
    }

    const halide_handle_cplusplus_type *ha = a.handle_type;
    const halide_handle_cplusplus_type *hb = b.handle_type;

    if (ha == hb) {
        // Same handle type, or both not handles, or both void *
        return result;
    }

    if (ha == nullptr) {
        // void* < T*
        result = LessThan;
        return result;
    }

    if (hb == nullptr) {
        // T* > void*
        result = GreaterThan;
        return result;
    }

    // They're both non-void handle types with distinct type info
    // structs. We now need to distinguish between different C++
    // pointer types (e.g. char * vs const float *). It would be nice
    // if the structs were unique per C++ type. Then comparing the
    // pointers above would be sufficient. Unfortunately, different
    // shared libraries in the same process each create a distinct
    // struct for the same type. We therefore have to do a deep
    // comparison of the type info fields.
    compare_scalar(ha->reference_type, hb->reference_type);
    compare_names(ha->inner_name.name, hb->inner_name.name);
    compare_scalar(ha->inner_name.cpp_type_type, hb->inner_name.cpp_type_type);
    compare_scalar(ha->namespaces.size(), hb->namespaces.size());
    compare_scalar(ha->enclosing_types.size(), hb->enclosing_types.size());
    compare_scalar(ha->cpp_type_modifiers.size(), hb->cpp_type_modifiers.size());

    // The loops below index hb with ha's bounds; that is only safe
    // because the sizes are known to be equal once we get past this check.
    if (result != Equal) {
        return result;
    }

    for (size_t i = 0; i < ha->namespaces.size(); i++) {
        compare_names(ha->namespaces[i], hb->namespaces[i]);
    }

    if (result != Equal) {
        return result;
    }

    for (size_t i = 0; i < ha->enclosing_types.size(); i++) {
        compare_scalar(ha->enclosing_types[i].cpp_type_type,
                       hb->enclosing_types[i].cpp_type_type);
        compare_names(ha->enclosing_types[i].name,
                      hb->enclosing_types[i].name);
    }

    if (result != Equal) {
        return result;
    }

    for (size_t i = 0; i < ha->cpp_type_modifiers.size(); i++) {
        compare_scalar(ha->cpp_type_modifiers[i],
                       hb->cpp_type_modifiers[i]);
    }

    return result;
}

IRComparer::CmpResult IRComparer::compare_names(const string &a, const string &b) {
    if (result != Equal) {
        return result;
    }

    int string_cmp = a.compare(b);
    if (string_cmp < 0) {
        result = LessThan;
    } else if (string_cmp > 0) {
        result = GreaterThan;
    }

    return result;
}

IRComparer::CmpResult IRComparer::compare_expr_vector(const vector<Expr> &a, const vector<Expr> &b) {
    if (result != Equal) {
        return result;
    }

    // Shorter vectors order first; equal-length vectors are compared
    // element by element, stopping at the first difference.
    compare_scalar(a.size(), b.size());
    for (size_t i = 0; (i < a.size()) && result == Equal; i++) {
        compare_expr(a[i], b[i]);
    }

    return result;
}

void IRComparer::visit(const IntImm *op) {
    const IntImm *e = expr.as<IntImm>();
    compare_scalar(e->value, op->value);
}

void IRComparer::visit(const UIntImm *op) {
    const UIntImm *e = expr.as<UIntImm>();
    compare_scalar(e->value, op->value);
}

void IRComparer::visit(const FloatImm *op) {
    const FloatImm *e = expr.as<FloatImm>();
    compare_scalar(e->value, op->value);
}

void IRComparer::visit(const StringImm *op) {
    const StringImm *e = expr.as<StringImm>();
    compare_names(e->value, op->value);
}

void IRComparer::visit(const Cast *op) {
    // The expression types were already compared in compare_expr,
    // so only the operand remains.
    compare_expr(expr.as<Cast>()->value, op->value);
}

void IRComparer::visit(const Reinterpret *op) {
    compare_expr(expr.as<Reinterpret>()->value, op->value);
}

void IRComparer::visit(const Variable *op) {
    const Variable *e = expr.as<Variable>();
    compare_names(e->name, op->name);
}

namespace {
// Shared implementation for every two-operand node: compare the left
// operands, then the right operands. `expr` is taken by value so this
// helper keeps its own handle to the left-hand node while
// compare_expr reassigns IRComparer::expr during the recursion.
template<typename T>
void visit_binary_operator(IRComparer *cmp, const T *op, Expr expr) {
    const T *e = expr.as<T>();
    cmp->compare_expr(e->a, op->a);
    cmp->compare_expr(e->b, op->b);
}
}  // namespace

void IRComparer::visit(const Add *op) {
    visit_binary_operator(this, op, expr);
}
void IRComparer::visit(const Sub *op) {
    visit_binary_operator(this, op, expr);
}
void IRComparer::visit(const Mul *op) {
    visit_binary_operator(this, op, expr);
}
void IRComparer::visit(const Div *op) {
    visit_binary_operator(this, op, expr);
}
void IRComparer::visit(const Mod *op) {
    visit_binary_operator(this, op, expr);
}
void IRComparer::visit(const Min *op) {
    visit_binary_operator(this, op, expr);
}
void IRComparer::visit(const Max *op) {
    visit_binary_operator(this, op, expr);
}
void IRComparer::visit(const EQ *op) {
    visit_binary_operator(this, op, expr);
}
void IRComparer::visit(const NE *op) {
    visit_binary_operator(this, op, expr);
}
void IRComparer::visit(const LT *op) {
    visit_binary_operator(this, op, expr);
}
void IRComparer::visit(const LE *op) {
    visit_binary_operator(this, op, expr);
}
void IRComparer::visit(const GT *op) {
    visit_binary_operator(this, op, expr);
}
void IRComparer::visit(const GE *op) {
    visit_binary_operator(this, op, expr);
}
void IRComparer::visit(const And *op) {
    visit_binary_operator(this, op, expr);
}
void IRComparer::visit(const Or *op) {
    visit_binary_operator(this, op, expr);
}

void IRComparer::visit(const Not *op) {
    const Not *e = expr.as<Not>();
    compare_expr(e->a, op->a);
}

// Return type of the IRComparer::visit(const Select *) definition that follows.
void
IRComparer::visit(const Select *op) { const Select *e = expr.as