#include "AssociativeOpsTable.h"
#include "IRPrinter.h"

#include <mutex>

namespace Halide {
namespace Internal {

using std::map;
using std::string;
using std::vector;

namespace {

// Kind of IR node found at the root of the first expression of a tuple.
// Used as one component of the lookup key for the pattern tables.
enum class RootExpr {
    Add = 0,
    Mul = 1,
    Max = 2,
    Min = 3,
    Sub = 4,
    Select = 5,
    And = 6,
    Or = 7,
    Cast = 8,
    Unknown = 9,  // Not supported IR type
};

// Scalar value types that can appear in a table key.
enum class ValType {
    UInt1 = 0,
    UInt8 = 1,
    UInt16 = 2,
    UInt32 = 3,
    UInt64 = 4,
    Int8 = 5,
    Int16 = 6,
    Int32 = 7,
    Int64 = 8,
    Float16 = 9,
    Float32 = 10,
    Float64 = 11,
    // General type (including all previous types). This must not share a
    // value with any concrete type: TableKey compares raw ValType values, so
    // a collision makes the general key and the concrete key identical and
    // get_ops_table_helper() then runs the general populate function twice
    // for that type, duplicating every pattern in the table.
    All = 12,
};

static_assert(ValType::All != ValType::Float64,
              "ValType::All must be distinct from every concrete ValType");

// Map a scalar, non-handle Halide type to the matching ValType.
// Bit widths other than the ones listed for each type class trip an
// internal_assert.
ValType convert_halide_type_to_val_type(const Type &halide_t) {
    internal_assert(halide_t.is_scalar() && !halide_t.is_handle());

    ValType val_t;
    if (halide_t.is_uint()) {
        if (halide_t.bits() == 1) {  // Bool
            val_t = ValType::UInt1;
        } else if (halide_t.bits() == 8) {
            val_t = ValType::UInt8;
        } else if (halide_t.bits() == 16) {
            val_t = ValType::UInt16;
        } else if (halide_t.bits() == 32) {
            val_t = ValType::UInt32;
        } else {
            internal_assert(halide_t.bits() == 64);
            val_t = ValType::UInt64;
        }
    } else if (halide_t.is_int()) {
        if (halide_t.bits() == 8) {
            val_t = ValType::Int8;
        } else if (halide_t.bits() == 16) {
            val_t = ValType::Int16;
        } else if (halide_t.bits() == 32) {
            val_t = ValType::Int32;
        } else {
            internal_assert(halide_t.bits() == 64);
            val_t = ValType::Int64;
        }
    } else {
        internal_assert(halide_t.is_float());
        if (halide_t.bits() == 16) {
            val_t = ValType::Float16;
        } else if (halide_t.bits() == 32) {
            val_t = ValType::Float32;
        } else {
            internal_assert(halide_t.bits() == 64);
            val_t = ValType::Float64;
        }
    }
    return val_t;
}

// Element-wise version of convert_halide_type_to_val_type().
vector<ValType> convert_halide_types_to_val_types(const vector<Type> &halide_types) {
    vector<ValType> val_types(halide_types.size());
    for (size_t i = 0; i < halide_types.size(); ++i) {
        val_types[i] = convert_halide_type_to_val_type(halide_types[i]);
    }
    return val_types;
}

// Key for both the populate-function dispatch map and the cache of populated
// tables: the operand types, the root IR node kind, and the tuple size.
struct TableKey {
    vector<ValType> types;
    RootExpr root;
    size_t dim;

    TableKey(ValType t, RootExpr r, size_t d)
        : types({t}), root(r), dim(d) {
    }
    TableKey(const vector<ValType> &t, RootExpr r, size_t d)
        : types(t), root(r), dim(d) {
    }

    bool operator==(const TableKey &other) const {
        return (types == other.types) && (root == other.root) && (dim == other.dim);
    }

    // Lexicographic ordering on (types, root, dim), as required for use as a
    // std::map key.
    bool operator<(const TableKey &other) const {
        if (types < other.types) {
            return true;
        } else if (types > other.types) {
            return false;
        }
        if (root < other.root) {
            return true;
        } else if (root > other.root) {
            return false;
        }
        return (dim < other.dim);
    }
};

// Cache of already-populated tables. Filled lazily by get_ops_table_helper();
// the caller is responsible for serializing access.
static map<TableKey, vector<AssociativePattern>> pattern_tables;

// Declares, for tuple element `index` of type `t`, the variables x<i>, y<i>,
// k<i> and the constants zero_<i>, one_<i>, neg_one_<i>, tmax_<i>, tmin_<i>
// used to spell out the patterns below.
#define declare_vars(t, index)                                      \
    Expr x##index = Variable::make(t, "x" + std::to_string(index)); \
    Expr y##index = Variable::make(t, "y" + std::to_string(index)); \
    Expr k##index = Variable::make(t, "k" + std::to_string(index)); \
    Expr zero_##index = make_const(t, 0);                           \
    Expr one_##index = make_const(t, 1);                            \
    Expr neg_one_##index = make_const(t, -1);                       \
    Expr tmax_##index = t.max();                                    \
    Expr tmin_##index = t.min();

// Asserts that `types` has exactly one element and declares its variables.
#define declare_vars_single(types)      \
    internal_assert(types.size() == 1); \
    declare_vars(types[0], 0)

// Asserts that `types` has exactly two elements and declares the variables
// for both of them.
#define declare_vars_double(types)      \
    internal_assert(types.size() == 2); \
    declare_vars(types[0], 0)           \
    declare_vars(types[1], 1)

// Each populate_ops_table_* function appends patterns to `table`. A pattern
// is built from (operator expression(s), identity value(s), bool).
// NOTE(review): the trailing bool presumably marks the operator as
// commutative; confirm against the AssociativePattern constructor.

// ---- Single-element tuples, any type ----

void populate_ops_table_single_general_add(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
    table.emplace_back(x0 + y0, zero_0, true);
}

void populate_ops_table_single_general_mul(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
    table.emplace_back(x0 * y0, one_0, true);
}

void populate_ops_table_single_general_max(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
    table.emplace_back(max(x0, y0), tmin_0, true);
}

void populate_ops_table_single_general_min(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
    table.emplace_back(min(x0, y0), tmax_0, true);
}

// No general single-element patterns are registered for a Sub root.
void populate_ops_table_single_general_sub(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
}

// No general single-element patterns are registered for a Select root.
void populate_ops_table_single_general_select(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
}

// ---- Two-element tuples, any type ----

void populate_ops_table_double_general_add(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_double(types);
    // x0 is combined with both y0 and y1, so both elements must share a type.
    if (types[0] == types[1]) {
        table.push_back({{x0 + y0, x0 + y1}, {zero_0, zero_1}, true});
    }
}

// No general two-element patterns are registered for a Mul root.
void populate_ops_table_double_general_mul(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_double(types);
}

void populate_ops_table_double_general_max(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_double(types);
    // Max of the first element; the second element follows whichever side won.
    table.push_back({{max(x0, y0), select(y0 < x0, x1, y1)}, {tmin_0, zero_1}, true});
}

void populate_ops_table_double_general_min(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_double(types);
    // Min of the first element; the second element follows whichever side won.
    table.push_back({{min(x0, y0), select(x0 < y0, x1, y1)}, {tmax_0, zero_1}, true});
}

void populate_ops_table_double_general_sub(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_double(types);
    // Elements of both tuple slots are mixed, so both must share a type.
    // The expressions have the shape of a complex multiplication; the second
    // entry is the same operator with the operands of two products swapped.
    if (types[0] == types[1]) {
        table.push_back({{x0 * y0 - x1 * y1, x1 * y0 + x0 * y1}, {one_0, zero_1}, true});
        table.push_back({{x0 * y0 - y1 * x1, x1 * y0 + y1 * x0}, {one_0, zero_1}, true});
    }
}

// No general two-element patterns are registered for a Select root.
void populate_ops_table_double_general_select(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_double(types);
}

// ---- Single-element tuples, type-specific ----

void populate_ops_table_single_uint1_and(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
    table.emplace_back(x0 && y0, one_0, true);
}

void populate_ops_table_single_uint1_or(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
    table.emplace_back(x0 || y0, zero_0, true);
}

void populate_ops_table_single_uint8_cast(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
    // Add in a wider type, clamp against k0, then narrow back to uint8.
    Expr k0_uint16 = Variable::make(UInt(16), "k0");
    Expr k0_uint32 = Variable::make(UInt(32), "k0");
    Expr k0_uint64 = Variable::make(UInt(64), "k0");
    table.emplace_back(cast<uint8_t>(min(cast<uint16_t>(x0 + y0), k0_uint16)), zero_0, true);
    table.emplace_back(cast<uint8_t>(min(cast<uint32_t>(x0 + y0), k0_uint32)), zero_0, true);
    table.emplace_back(cast<uint8_t>(min(cast<uint64_t>(x0 + y0), k0_uint64)), zero_0, true);
}

void populate_ops_table_single_uint8_select(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
    table.emplace_back(select(x0 > tmax_0 - y0, tmax_0, y0), zero_0, true);  // Saturating add
    table.emplace_back(select(x0 < -y0, y0, tmax_0), zero_0, true);          // Saturating add
}

void populate_ops_table_single_uint16_cast(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
    // Add in a wider type, clamp against k0, then narrow back to uint16.
    Expr k0_uint32 = Variable::make(UInt(32), "k0");
    Expr k0_uint64 = Variable::make(UInt(64), "k0");
    table.emplace_back(cast<uint16_t>(min(cast<uint32_t>(x0 + y0), k0_uint32)), zero_0, true);
    table.emplace_back(cast<uint16_t>(min(cast<uint64_t>(x0 + y0), k0_uint64)), zero_0, true);
}

void populate_ops_table_single_uint16_select(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
    table.emplace_back(select(x0 > tmax_0 - y0, tmax_0, y0), zero_0, true);  // Saturating add
    table.emplace_back(select(x0 < -y0, y0, tmax_0), zero_0, true);          // Saturating add
}

void populate_ops_table_single_uint32_cast(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
    // Add in a wider type, clamp against k0, then narrow back to uint32.
    Expr k0_uint64 = Variable::make(UInt(64), "k0");
    table.emplace_back(cast<uint32_t>(min(cast<uint64_t>(x0 + y0), k0_uint64)), zero_0, true);
}

void populate_ops_table_single_uint32_select(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
    table.emplace_back(select(x0 > tmax_0 - y0, tmax_0, y0), zero_0, true);  // Saturating add
    table.emplace_back(select(x0 < -y0, y0, tmax_0), zero_0, true);          // Saturating add
}

// Dispatch map from key to the function that fills in the matching patterns.
// General entries are keyed with the single type ValType::All regardless of
// tuple size; type-specific entries are keyed with the concrete type.
static const map<TableKey, void (*)(const vector<Type> &types, vector<AssociativePattern> &)> val_type_to_populate_luts_fn = {
    {TableKey(ValType::All, RootExpr::Add, 1), &populate_ops_table_single_general_add},
    {TableKey(ValType::All, RootExpr::Mul, 1), &populate_ops_table_single_general_mul},
    {TableKey(ValType::All, RootExpr::Max, 1), &populate_ops_table_single_general_max},
    {TableKey(ValType::All, RootExpr::Min, 1), &populate_ops_table_single_general_min},
    {TableKey(ValType::All, RootExpr::Sub, 1), &populate_ops_table_single_general_sub},
    {TableKey(ValType::All, RootExpr::Select, 1), &populate_ops_table_single_general_select},

    {TableKey(ValType::All, RootExpr::Add, 2), &populate_ops_table_double_general_add},
    {TableKey(ValType::All, RootExpr::Mul, 2), &populate_ops_table_double_general_mul},
    {TableKey(ValType::All, RootExpr::Max, 2), &populate_ops_table_double_general_max},
    {TableKey(ValType::All, RootExpr::Min, 2), &populate_ops_table_double_general_min},
    {TableKey(ValType::All, RootExpr::Sub, 2), &populate_ops_table_double_general_sub},
    {TableKey(ValType::All, RootExpr::Select, 2), &populate_ops_table_double_general_select},

    {TableKey(ValType::UInt1, RootExpr::And, 1), &populate_ops_table_single_uint1_and},
    {TableKey(ValType::UInt1, RootExpr::Or, 1), &populate_ops_table_single_uint1_or},

    {TableKey(ValType::UInt8, RootExpr::Cast, 1), &populate_ops_table_single_uint8_cast},
    {TableKey(ValType::UInt8, RootExpr::Select, 1), &populate_ops_table_single_uint8_select},

    {TableKey(ValType::UInt16, RootExpr::Cast, 1), &populate_ops_table_single_uint16_cast},
    {TableKey(ValType::UInt16, RootExpr::Select, 1), &populate_ops_table_single_uint16_select},

    {TableKey(ValType::UInt32, RootExpr::Cast, 1), &populate_ops_table_single_uint32_cast},
    {TableKey(ValType::UInt32, RootExpr::Select, 1), &populate_ops_table_single_uint32_select},
};

// Return the cached table for (types, root, dim), building it on first use
// from the general patterns followed by the type-specific ones.
// Performs no locking of its own; see the mutex in get_ops_table().
const vector<AssociativePattern> &get_ops_table_helper(const vector<Type> &types, RootExpr root, size_t dim) {
    const TableKey gen_key(ValType::All, root, dim);
    const TableKey key(convert_halide_types_to_val_types(types), root, dim);

    const auto cached = pattern_tables.find(key);
    if (cached != pattern_tables.end()) {
        return cached->second;
    }

    // Populate the table if we haven't done so previously
    vector<AssociativePattern> &table = pattern_tables[key];

    // Populate the general associative op LUT
    const auto gen_iter = val_type_to_populate_luts_fn.find(gen_key);
    if (gen_iter != val_type_to_populate_luts_fn.end()) {
        gen_iter->second(types, table);
    }

    // Populate the type-specific associative op LUT
    const auto iter = val_type_to_populate_luts_fn.find(key);
    if (iter != val_type_to_populate_luts_fn.end()) {
        iter->second(types, table);
    }
    return table;
}

// Format a list of types as "{t0, t1, ...}" for debug output.
std::string print_types(const vector<Type> &types) {
    std::ostringstream stream;
    stream << "{";
    for (size_t i = 0; i < types.size(); ++i) {
        if (i > 0) {
            stream << ", ";
        }
        stream << types[i];
    }
    stream << "}";
    return stream.str();
}

}  // anonymous namespace

// Return the table of associative patterns registered for the given tuple of
// expressions.
//
// The table is selected by the IR node kind at the root of exprs[0], the
// types of all the expressions, and the tuple size. `exprs` must be
// non-empty. An empty table is returned when the tuple has more than two
// elements or the root of exprs[0] is not one of the supported node kinds.
// The returned reference points at function-local or file-level static
// storage.
const vector<AssociativePattern> &get_ops_table(const vector<Expr> &exprs) {
    internal_assert(!exprs.empty());

    static const vector<AssociativePattern> empty;
    if (exprs.size() > 2) {
        debug(5) << "Returning empty table since tuple size is larger than 2\n";
        return empty;
    }

    vector<Type> types(exprs.size());
    for (size_t i = 0; i < exprs.size(); ++i) {
        types[i] = exprs[i].type();
    }

    // Only the first expression decides which root table is consulted.
    RootExpr root = RootExpr::Unknown;
    if (exprs[0].as<Halide::Internal::Add>()) {
        debug(5) << "Returning Add root table for type " << print_types(types) << "\n";
        root = RootExpr::Add;
    } else if (exprs[0].as<Halide::Internal::Sub>()) {
        debug(5) << "Returning Sub root table for type " << print_types(types) << "\n";
        root = RootExpr::Sub;
    } else if (exprs[0].as<Halide::Internal::Mul>()) {
        debug(5) << "Returning Mul root table for type " << print_types(types) << "\n";
        root = RootExpr::Mul;
    } else if (exprs[0].as<Halide::Internal::Min>()) {
        debug(5) << "Returning Min root table for type " << print_types(types) << "\n";
        root = RootExpr::Min;
    } else if (exprs[0].as<Halide::Internal::Max>()) {
        debug(5) << "Returning Max root table for type " << print_types(types) << "\n";
        root = RootExpr::Max;
    } else if (exprs[0].as<Halide::Internal::Select>()) {
        debug(5) << "Returning Select root table for type " << print_types(types) << "\n";
        root = RootExpr::Select;
    } else if (exprs[0].as<Halide::Internal::And>()) {
        debug(5) << "Returning And root table for type " << print_types(types) << "\n";
        root = RootExpr::And;
    } else if (exprs[0].as<Halide::Internal::Or>()) {
        debug(5) << "Returning Or root table for type " << print_types(types) << "\n";
        root = RootExpr::Or;
    } else if (exprs[0].as<Halide::Internal::Cast>()) {
        debug(5) << "Returning Cast root table for type " << print_types(types) << "\n";
        root = RootExpr::Cast;
    }

    if (root != RootExpr::Unknown) {
        // get_ops_table_helper() lazily initializes the table, so ensure
        // that multiple threads can't try to do so at the same time.
        static std::mutex ops_table_lock;
        std::lock_guard<std::mutex> lock_guard(ops_table_lock);

        const vector<AssociativePattern> &table = get_ops_table_helper(types, root, exprs.size());
        debug(7) << "Table size: " << table.size() << "\n";
        for (const auto &p : table) {
            debug(7) << p;
        }
        return table;
    }
    debug(5) << "Returning empty table\n";
    return empty;
}

}  // namespace Internal
}  // namespace Halide