CodeGen_Vulkan_Dev.cpp
#include <algorithm>
#include <fstream> // for dump to file
#include <sstream>
#include <unordered_set>
#include "CSE.h"
#include "CodeGen_GPU_Dev.h"
#include "CodeGen_Internal.h"
#include "CodeGen_Vulkan_Dev.h"
#include "Debug.h"
#include "Deinterleave.h"
#include "FindIntrinsics.h"
#include "IROperator.h"
#include "IRPrinter.h"
#include "Scope.h"
#include "Simplify.h"
#include "SpirvIR.h"
#include "Target.h"
#ifdef WITH_SPIRV
namespace Halide {
namespace Internal {
namespace { // anonymous
// --
class CodeGen_Vulkan_Dev : public CodeGen_GPU_Dev {
public:
CodeGen_Vulkan_Dev(Target target);
/** Compile a GPU kernel into the module. This may be called many times
* with different kernels, which will all be accumulated into a single
* source module shared by a given Halide pipeline. */
void add_kernel(Stmt stmt,
const std::string &name,
const std::vector<DeviceArgument> &args) override;
/** (Re)initialize the GPU kernel module. This is separate from compile,
* since a GPU device module will often have many kernels compiled into it
* for a single pipeline. */
void init_module() override;
std::vector<char> compile_to_src() override;
std::string get_current_kernel_name() override;
void dump() override;
std::string print_gpu_name(const std::string &name) override;
std::string api_unique_name() override {
return "vulkan";
}
protected:
class SPIRV_Emitter : public IRVisitor {
public:
SPIRV_Emitter(Target t);
using IRVisitor::visit;
void visit(const IntImm *) override;
void visit(const UIntImm *) override;
void visit(const FloatImm *) override;
void visit(const StringImm *) override;
void visit(const Cast *) override;
void visit(const Reinterpret *) override;
void visit(const Variable *) override;
void visit(const Add *) override;
void visit(const Sub *) override;
void visit(const Mul *) override;
void visit(const Div *) override;
void visit(const Mod *) override;
void visit(const Min *) override;
void visit(const Max *) override;
void visit(const EQ *) override;
void visit(const NE *) override;
void visit(const LT *) override;
void visit(const LE *) override;
void visit(const GT *) override;
void visit(const GE *) override;
void visit(const And *) override;
void visit(const Or *) override;
void visit(const Not *) override;
void visit(const Select *) override;
void visit(const Load *) override;
void visit(const Ramp *) override;
void visit(const Broadcast *) override;
void visit(const Call *) override;
void visit(const Let *) override;
void visit(const LetStmt *) override;
void visit(const AssertStmt *) override;
void visit(const For *) override;
void visit(const Store *) override;
void visit(const Provide *) override;
void visit(const Allocate *) override;
void visit(const Free *) override;
void visit(const Realize *) override;
void visit(const ProducerConsumer *op) override;
void visit(const IfThenElse *) override;
void visit(const Evaluate *) override;
void visit(const Shuffle *) override;
void visit(const VectorReduce *) override;
void visit(const Prefetch *) override;
void visit(const Fork *) override;
void visit(const Acquire *) override;
void visit(const Atomic *) override;
void reset();
// Top-level function for adding kernels
void add_kernel(const Stmt &s, const std::string &name, const std::vector<DeviceArgument> &args);
void init_module();
void compile(std::vector<char> &binary);
void dump() const;
// Encode the descriptor sets into a sidecar which will be added
// as a header to the module prior to the actual SPIR-V binary
void encode_header(SpvBinary &spirv_header);
// Scalarize expressions
void scalarize(const Expr &e);
SpvId map_type_to_pair(const Type &t);
// Workgroup size
void reset_workgroup_size();
void find_workgroup_size(const Stmt &s);
void declare_workgroup_size(SpvId kernel_func_id);
void declare_entry_point(const Stmt &s, SpvId kernel_func_id);
void declare_device_args(const Stmt &s, uint32_t entry_point_index, const std::string &kernel_name, const std::vector<DeviceArgument> &args);
// Common operator visitors
void visit_unary_op(SpvOp op_code, Type t, const Expr &a);
void visit_binary_op(SpvOp op_code, Type t, const Expr &a, const Expr &b);
void visit_glsl_op(SpvId glsl_op_code, Type t, const std::vector<Expr> &args);
void load_from_scalar_index(const Load *op, SpvId index_id, SpvId variable_id, Type value_type, Type storage_type, SpvStorageClass storage_class);
void load_from_vector_index(const Load *op, SpvId variable_id, Type value_type, Type storage_type, SpvStorageClass storage_class);
void store_at_scalar_index(const Store *op, SpvId index_id, SpvId variable_id, Type value_type, Type storage_type, SpvStorageClass storage_class, SpvId value_id);
void store_at_vector_index(const Store *op, SpvId variable_id, Type value_type, Type storage_type, SpvStorageClass storage_class, SpvId value_id);
SpvFactory::Components split_vector(Type type, SpvId value_id);
SpvId join_vector(Type type, const SpvFactory::Components &value_components);
SpvId cast_type(Type target_type, Type value_type, SpvId value_id);
SpvId convert_to_bool(Type target_type, Type value_type, SpvId value_id);
// Returns Phi node inputs.
template<typename StmtOrExpr>
SpvFactory::BlockVariables emit_if_then_else(const Expr &condition, StmtOrExpr then_case, StmtOrExpr else_case);
template<typename T>
SpvId declare_constant_int(Type value_type, int64_t value);
template<typename T>
SpvId declare_constant_uint(Type value_type, uint64_t value);
template<typename T>
SpvId declare_constant_float(Type value_type, float value);
// Map from Halide built-in names to extended GLSL intrinsics for SPIR-V
using BuiltinMap = std::unordered_map<std::string, SpvId>;
const BuiltinMap glsl_builtin = {
{"acos_f16", GLSLstd450Acos},
{"acos_f32", GLSLstd450Acos},
{"acosh_f16", GLSLstd450Acosh},
{"acosh_f32", GLSLstd450Acosh},
{"asin_f16", GLSLstd450Asin},
{"asin_f32", GLSLstd450Asin},
{"asinh_f16", GLSLstd450Asinh},
{"asinh_f32", GLSLstd450Asinh},
{"atan2_f16", GLSLstd450Atan2},
{"atan2_f32", GLSLstd450Atan2},
{"atan_f16", GLSLstd450Atan},
{"atan_f32", GLSLstd450Atan},
{"atanh_f16", GLSLstd450Atanh},
{"atanh_f32", GLSLstd450Atanh},
{"ceil_f16", GLSLstd450Ceil},
{"ceil_f32", GLSLstd450Ceil},
{"cos_f16", GLSLstd450Cos},
{"cos_f32", GLSLstd450Cos},
{"cosh_f16", GLSLstd450Cosh},
{"cosh_f32", GLSLstd450Cosh},
{"exp_f16", GLSLstd450Exp},
{"exp_f32", GLSLstd450Exp},
{"fast_inverse_sqrt_f16", GLSLstd450InverseSqrt},
{"fast_inverse_sqrt_f32", GLSLstd450InverseSqrt},
{"fast_log_f16", GLSLstd450Log},
{"fast_log_f32", GLSLstd450Log},
{"fast_exp_f16", GLSLstd450Exp},
{"fast_exp_f32", GLSLstd450Exp},
{"fast_pow_f16", GLSLstd450Pow},
{"fast_pow_f32", GLSLstd450Pow},
{"floor_f16", GLSLstd450Floor},
{"floor_f32", GLSLstd450Floor},
{"log_f16", GLSLstd450Log},
{"log_f32", GLSLstd450Log},
{"sin_f16", GLSLstd450Sin},
{"sin_f32", GLSLstd450Sin},
{"sinh_f16", GLSLstd450Sinh},
{"sinh_f32", GLSLstd450Sinh},
{"sqrt_f16", GLSLstd450Sqrt},
{"sqrt_f32", GLSLstd450Sqrt},
{"tan_f16", GLSLstd450Tan},
{"tan_f32", GLSLstd450Tan},
{"tanh_f16", GLSLstd450Tanh},
{"tanh_f32", GLSLstd450Tanh},
{"trunc_f16", GLSLstd450Trunc},
{"trunc_f32", GLSLstd450Trunc},
{"mix", GLSLstd450FMix},
};
// The SPIRV-IR builder
SpvBuilder builder;
// The scope contains both the symbol id and its storage class
using SymbolIdStorageClassPair = std::pair<SpvId, SpvStorageClass>;
using SymbolScope = Scope<SymbolIdStorageClassPair>;
using ScopedSymbolBinding = ScopedBinding<SymbolIdStorageClassPair>;
SymbolScope symbol_table;
// Map from a variable ID to its corresponding storage type definition
struct StorageAccess {
SpvStorageClass storage_class = SpvStorageClassMax;
uint32_t storage_array_size = 0; // zero if not an array
SpvId storage_type_id = SpvInvalidId;
Type storage_type;
};
using StorageAccessMap = std::unordered_map<SpvId, StorageAccess>;
StorageAccessMap storage_access_map;
// Defines the binding information for a specialization constant
// that is exported by the module and can be overriden at runtime
struct SpecializationBinding {
SpvId constant_id = 0;
uint32_t type_size = 0;
std::string constant_name;
};
using SpecializationConstants = std::vector<SpecializationBinding>;
// Defines a shared memory allocation
struct SharedMemoryAllocation {
SpvId constant_id = 0; // specialization constant to dynamically adjust array size (zero if not used)
uint32_t array_size = 0;
uint32_t type_size = 0;
std::string variable_name;
};
using SharedMemoryUsage = std::vector<SharedMemoryAllocation>;
// Defines the specialization constants used for dynamically overiding the dispatch size
struct WorkgroupSizeBinding {
SpvId local_size_constant_id[3] = {0, 0, 0}; // zero if unused
};
// Keep track of the descriptor sets so we can add a sidecar to the
// module indicating which descriptor set to use for each entry point
struct DescriptorSet {
std::string entry_point_name;
uint32_t uniform_buffer_count = 0;
uint32_t storage_buffer_count = 0;
SpecializationConstants specialization_constants;
SharedMemoryUsage shared_memory_usage;
WorkgroupSizeBinding workgroup_size_binding;
};
using DescriptorSetTable = std::vector<DescriptorSet>;
DescriptorSetTable descriptor_set_table;
// The workgroup size ... this indicates the extents of the 1-3 dimensional index space
// used as part of the kernel dispatch. It can also be used to adjust the layout for work
// items (aka GPU threads), based on logical groupings. If a zero sized workgroup is
// encountered during CodeGen, it is assumed that the extents are dynamic and specified
// at runtime
uint32_t workgroup_size[3];
// Current index of kernel for module
uint32_t kernel_index = 0;
// Target for codegen
Target target;
} emitter;
std::string current_kernel_name;
};
// Check if all loads and stores to the member 'buffer' are dense, aligned, and
// have the same number of lanes. If this is indeed the case then the 'lanes'
// member stores the number of lanes in those loads and stores.
//
// FIXME: Refactor this and the version in CodeGen_OpenGLCompute_Dev to a common place!
//
class CheckAlignedDenseVectorLoadStore : public IRVisitor {
public:
// True if all loads and stores from the buffer are dense, aligned, and all
// have the same number of lanes, false otherwise.
bool are_all_dense = true;
// The number of lanes in the loads and stores. If the number of lanes is
// variable, then are_all_dense is set to false regardless, and this value
// is undefined. Initially set to -1 before any dense operation is
// discovered.
int lanes = -1;
CheckAlignedDenseVectorLoadStore(std::string name)
: buffer_name(std::move(name)) {
}
private:
// The name of the buffer to check.
std::string buffer_name;
using IRVisitor::visit;
void visit(const Load *op) override {
IRVisitor::visit(op);
if (op->name != buffer_name) {
return;
}
if (op->type.is_scalar()) {
are_all_dense = false;
return;
}
Expr ramp_base = strided_ramp_base(op->index);
if (!ramp_base.defined()) {
are_all_dense = false;
return;
}
if ((op->alignment.modulus % op->type.lanes() != 0) ||
(op->alignment.remainder % op->type.lanes() != 0)) {
are_all_dense = false;
return;
}
if (lanes != -1 && op->type.lanes() != lanes) {
are_all_dense = false;
return;
}
lanes = op->type.lanes();
}
void visit(const Store *op) override {
IRVisitor::visit(op);
if (op->name != buffer_name) {
return;
}
if (op->value.type().is_scalar()) {
are_all_dense = false;
return;
}
Expr ramp_base = strided_ramp_base(op->index);
if (!ramp_base.defined()) {
are_all_dense = false;
return;
}
if ((op->alignment.modulus % op->value.type().lanes() != 0) ||
(op->alignment.remainder % op->value.type().lanes() != 0)) {
are_all_dense = false;
return;
}
if (lanes != -1 && op->value.type().lanes() != lanes) {
are_all_dense = false;
return;
}
lanes = op->value.type().lanes();
}
};
struct FindWorkGroupSize : public IRVisitor {
using IRVisitor::visit;
void visit(const For *loop) override {
if (!CodeGen_GPU_Dev::is_gpu_var(loop->name)) {
return loop->body.accept(this);
}
if ((loop->for_type == ForType::GPUBlock) ||
(loop->for_type == ForType::GPUThread)) {
// This should always be true at this point in codegen
internal_assert(is_const_zero(loop->min));
// Save & validate the workgroup size
int index = thread_loop_workgroup_index(loop->name);
if (index >= 0) {
const IntImm *literal = loop->extent.as<IntImm>();
if (literal != nullptr) {
uint32_t new_wg_size = literal->value;
user_assert(workgroup_size[index] == 0 || workgroup_size[index] == new_wg_size)
<< "Vulkan requires all kernels have the same workgroup size, "
<< "but two different sizes were encountered: "
<< workgroup_size[index] << " and "
<< new_wg_size << " in dimension " << index << "\n";
workgroup_size[index] = new_wg_size;
}
}
debug(4) << "Thread group size for index " << index << " is " << workgroup_size[index] << "\n";
}
loop->body.accept(this);
}
int thread_loop_workgroup_index(const std::string &name) {
std::string ids[] = {".__thread_id_x",
".__thread_id_y",
".__thread_id_z"};
for (size_t i = 0; i < sizeof(ids) / sizeof(std::string); i++) {
if (ends_with(name, ids[i])) {
return i;
}
}
return -1;
}
uint32_t workgroup_size[3] = {0, 0, 0};
};
CodeGen_Vulkan_Dev::SPIRV_Emitter::SPIRV_Emitter(Target t)
: IRVisitor(), target(t) {
// Empty
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::scalarize(const Expr &e) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::scalarize(): " << (Expr)e << "\n";
internal_assert(e.type().is_vector()) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::scalarize must be called with an expression of vector type.\n";
SpvId type_id = builder.declare_type(e.type());
SpvId value_id = builder.declare_null_constant(e.type());
SpvId result_id = value_id;
for (int i = 0; i < e.type().lanes(); i++) {
extract_lane(e, i).accept(this);
SpvId extracted_id = builder.current_id();
SpvId composite_id = builder.reserve_id(SpvResultId);
SpvFactory::Indices indices = {(uint32_t)i};
builder.append(SpvFactory::composite_insert(type_id, composite_id, extracted_id, value_id, indices));
result_id = composite_id;
}
builder.update_id(result_id);
}
SpvId CodeGen_Vulkan_Dev::SPIRV_Emitter::map_type_to_pair(const Type &t) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::map_type_to_pair(): " << t << "\n";
SpvId base_type_id = builder.declare_type(t);
SpvBuilder::StructMemberTypes member_type_ids = {base_type_id, base_type_id};
const std::string struct_name = std::string("_struct_") + type_to_c_type(t, false, false) + std::string("_pair");
SpvId struct_type_id = builder.declare_struct(struct_name, member_type_ids);
return struct_type_id;
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Variable *var) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Variable): " << var->type << " " << var->name << "\n";
SpvId variable_id = symbol_table.get(var->name).first;
user_assert(variable_id != SpvInvalidId) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Variable): Invalid symbol name!\n";
builder.update_id(variable_id);
}
template<typename T>
SpvId CodeGen_Vulkan_Dev::SPIRV_Emitter::declare_constant_int(Type value_type, int64_t value) {
const T typed_value = (T)(value);
SpvId constant_id = builder.declare_constant(value_type, &typed_value);
builder.update_id(constant_id);
return constant_id;
}
template<typename T>
SpvId CodeGen_Vulkan_Dev::SPIRV_Emitter::declare_constant_uint(Type value_type, uint64_t value) {
const T typed_value = (T)(value);
SpvId constant_id = builder.declare_constant(value_type, &typed_value);
builder.update_id(constant_id);
return constant_id;
}
template<typename T>
SpvId CodeGen_Vulkan_Dev::SPIRV_Emitter::declare_constant_float(Type value_type, float value) {
const T typed_value = (T)(value);
SpvId constant_id = builder.declare_constant(value_type, &typed_value);
builder.update_id(constant_id);
return constant_id;
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const IntImm *imm) {
if (imm->type.bits() == 8) {
declare_constant_int<int8_t>(imm->type, imm->value);
} else if (imm->type.bits() == 16) {
declare_constant_int<int16_t>(imm->type, imm->value);
} else if (imm->type.bits() == 32) {
declare_constant_int<int32_t>(imm->type, imm->value);
} else if (imm->type.bits() == 64) {
declare_constant_int<int64_t>(imm->type, imm->value);
} else {
internal_error << "Vulkan backend currently only supports 8-bit, 16-bit, 32-bit or 64-bit signed integers!\n";
}
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const UIntImm *imm) {
if (imm->type.bits() == 8) {
declare_constant_uint<uint8_t>(imm->type, imm->value);
} else if (imm->type.bits() == 16) {
declare_constant_uint<uint16_t>(imm->type, imm->value);
} else if (imm->type.bits() == 32) {
declare_constant_uint<uint32_t>(imm->type, imm->value);
} else if (imm->type.bits() == 64) {
declare_constant_uint<uint64_t>(imm->type, imm->value);
} else {
internal_error << "Vulkan backend currently only supports 8-bit, 16-bit, 32-bit or 64-bit unsigned integers!\n";
}
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const StringImm *imm) {
SpvId constant_id = builder.declare_string_constant(imm->value);
builder.update_id(constant_id);
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const FloatImm *imm) {
if (imm->type.bits() == 16) {
if (imm->type.is_bfloat()) {
declare_constant_float<bfloat16_t>(imm->type, imm->value);
} else {
declare_constant_float<float16_t>(imm->type, imm->value);
}
} else if (imm->type.bits() == 32) {
declare_constant_float<float>(imm->type, imm->value);
} else if (imm->type.bits() == 64) {
declare_constant_float<double>(imm->type, imm->value);
} else {
internal_error << "Vulkan backend currently only supports 16-bit, 32-bit or 64-bit floats\n";
}
}
template<typename T>
void fill_bytes_with_value(uint8_t *bytes, int count, int value) {
T *v = reinterpret_cast<T *>(bytes);
for (int i = 0; i < count; ++i) {
v[i] = (T)value;
}
}
SpvId CodeGen_Vulkan_Dev::SPIRV_Emitter::convert_to_bool(Type target_type, Type value_type, SpvId value_id) {
if (!value_type.is_bool()) {
value_id = cast_type(Bool(), value_type, value_id);
}
const int true_value = 1;
const int false_value = 0;
std::vector<uint8_t> true_data(target_type.bytes(), (uint8_t)0);
std::vector<uint8_t> false_data(target_type.bytes(), (uint8_t)0);
if (target_type.is_int_or_uint() && target_type.bits() == 8) {
fill_bytes_with_value<int8_t>(&true_data[0], target_type.lanes(), true_value);
fill_bytes_with_value<int8_t>(&false_data[0], target_type.lanes(), false_value);
} else if (target_type.is_int_or_uint() && target_type.bits() == 16) {
fill_bytes_with_value<int16_t>(&true_data[0], target_type.lanes(), true_value);
fill_bytes_with_value<int16_t>(&false_data[0], target_type.lanes(), false_value);
} else if (target_type.is_int_or_uint() && target_type.bits() == 32) {
fill_bytes_with_value<int32_t>(&true_data[0], target_type.lanes(), true_value);
fill_bytes_with_value<int32_t>(&false_data[0], target_type.lanes(), false_value);
} else if (target_type.is_int_or_uint() && target_type.bits() == 64) {
fill_bytes_with_value<int64_t>(&true_data[0], target_type.lanes(), true_value);
fill_bytes_with_value<int64_t>(&false_data[0], target_type.lanes(), false_value);
} else if (target_type.is_float() && target_type.bits() == 16) {
if (target_type.is_bfloat()) {
fill_bytes_with_value<bfloat16_t>(&true_data[0], target_type.lanes(), true_value);
fill_bytes_with_value<bfloat16_t>(&false_data[0], target_type.lanes(), false_value);
} else {
fill_bytes_with_value<float16_t>(&true_data[0], target_type.lanes(), true_value);
fill_bytes_with_value<float16_t>(&false_data[0], target_type.lanes(), false_value);
}
} else if (target_type.is_float() && target_type.bits() == 32) {
fill_bytes_with_value<float>(&true_data[0], target_type.lanes(), true_value);
fill_bytes_with_value<float>(&false_data[0], target_type.lanes(), false_value);
} else if (target_type.is_float() && target_type.bits() == 64) {
fill_bytes_with_value<double>(&true_data[0], target_type.lanes(), true_value);
fill_bytes_with_value<double>(&false_data[0], target_type.lanes(), false_value);
} else {
user_error << "Unhandled type cast from value type '" << value_type << "' to target type '" << target_type << "'!";
}
SpvId result_id = builder.reserve_id(SpvResultId);
SpvId target_type_id = builder.declare_type(target_type);
SpvId true_value_id = builder.declare_constant(target_type, &true_data);
SpvId false_value_id = builder.declare_constant(target_type, &false_data);
builder.append(SpvFactory::select(target_type_id, result_id, value_id, true_value_id, false_value_id));
return result_id;
}
SpvId CodeGen_Vulkan_Dev::SPIRV_Emitter::cast_type(Type target_type, Type value_type, SpvId value_id) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::cast_type(): casting from value type '"
<< value_type << "' to target type '" << target_type << "'!\n";
if (value_type == target_type) {
return value_id;
}
SpvOp op_code = SpvOpNop;
if (value_type.is_float()) {
if (target_type.is_float()) {
op_code = SpvOpFConvert;
} else if (target_type.is_bool()) {
op_code = SpvOpSelect;
} else if (target_type.is_uint()) {
op_code = SpvOpConvertFToU;
} else if (target_type.is_int()) {
op_code = SpvOpConvertFToS;
}
} else if (value_type.is_bool()) {
op_code = SpvOpSelect;
} else if (value_type.is_uint()) {
if (target_type.is_float()) {
op_code = SpvOpConvertUToF;
} else if (target_type.is_bool()) {
op_code = SpvOpSelect;
} else if (target_type.is_int_or_uint()) {
op_code = SpvOpUConvert;
}
} else if (value_type.is_int()) {
if (target_type.is_float()) {
op_code = SpvOpConvertSToF;
} else if (target_type.is_bool()) {
op_code = SpvOpSelect;
} else if (target_type.is_int_or_uint()) {
op_code = SpvOpSConvert;
}
}
// If none of the explicit conversions matched, do a direct bitcast if the total
// size of both types is the same
if (op_code == SpvOpNop) {
if (target_type.bytes() == value_type.bytes()) {
op_code = SpvOpBitcast;
}
}
// Error If we still didn't find a suitable cast ...
if (op_code == SpvOpNop) {
user_error << "Unhandled type cast from value type '" << value_type << "' to target type '" << target_type << "'!";
return SpvInvalidId;
}
SpvId result_id = SpvInvalidId;
SpvId target_type_id = builder.declare_type(target_type);
if (op_code == SpvOpBitcast) {
result_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::bitcast(target_type_id, result_id, value_id));
} else if (op_code == SpvOpSelect) {
result_id = convert_to_bool(target_type, value_type, value_id);
} else if (op_code == SpvOpUConvert && target_type.is_int()) {
// SPIR-V requires both value and target types to be unsigned and of
// different component bit widths in order to be compatible with UConvert
// ... so do the conversion to an equivalent unsigned type then bitcast this
// result into the target type
Type unsigned_type = target_type.with_code(halide_type_uint);
if (unsigned_type.bytes() != value_type.bytes()) {
SpvId unsigned_type_id = builder.declare_type(unsigned_type);
SpvId unsigned_value_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::convert(op_code, unsigned_type_id, unsigned_value_id, value_id));
value_id = unsigned_value_id;
}
result_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::bitcast(target_type_id, result_id, value_id));
} else if (op_code == SpvOpSConvert && target_type.is_uint()) {
// Same as above but for SConvert
Type signed_type = target_type.with_code(halide_type_int);
if (signed_type.bytes() != value_type.bytes()) {
SpvId signed_type_id = builder.declare_type(signed_type);
SpvId signed_value_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::convert(op_code, signed_type_id, signed_value_id, value_id));
value_id = signed_value_id;
}
result_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::bitcast(target_type_id, result_id, value_id));
} else {
result_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::convert(op_code, target_type_id, result_id, value_id));
}
return result_id;
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Cast *op) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Cast): " << op->value.type() << " to " << op->type << "\n";
Type value_type = op->value.type();
Type target_type = op->type;
op->value.accept(this);
SpvId value_id = builder.current_id();
if ((value_type.is_vector() && target_type.is_vector())) {
if (value_type.lanes() == target_type.lanes()) {
SpvId result_id = cast_type(target_type, value_type, value_id);
builder.update_id(result_id);
} else {
user_error << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Cast): unhandled case " << op->value.type() << " to " << op->type << " (incompatible lanes)\n";
}
} else if (value_type.is_scalar() && target_type.is_scalar()) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Cast): scalar type (cast)\n";
SpvId result_id = cast_type(target_type, value_type, value_id);
builder.update_id(result_id);
} else if (value_type.bytes() == target_type.bytes()) {
SpvId result_id = cast_type(target_type, value_type, value_id);
builder.update_id(result_id);
} else {
user_error << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Cast): unhandled case " << op->value.type() << " to " << op->type << "\n";
}
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Reinterpret *op) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Reinterpret): " << op->value.type() << " to " << op->type << "\n";
SpvId type_id = builder.declare_type(op->type);
op->value.accept(this);
SpvId src_id = builder.current_id();
SpvId result_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::bitcast(type_id, result_id, src_id));
builder.update_id(result_id);
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Add *op) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Add): " << op->type << " ((" << op->a << ") + (" << op->b << "))\n";
visit_binary_op(op->type.is_float() ? SpvOpFAdd : SpvOpIAdd, op->type, op->a, op->b);
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Sub *op) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Sub): " << op->type << " ((" << op->a << ") - (" << op->b << "))\n";
visit_binary_op(op->type.is_float() ? SpvOpFSub : SpvOpISub, op->type, op->a, op->b);
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Mul *op) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Mul): " << op->type << " ((" << op->a << ") * (" << op->b << "))\n";
visit_binary_op(op->type.is_float() ? SpvOpFMul : SpvOpIMul, op->type, op->a, op->b);
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Div *op) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Div): " << op->type << " ((" << op->a << ") / (" << op->b << "))\n";
user_assert(!is_const_zero(op->b)) << "Division by constant zero in expression: " << Expr(op) << "\n";
if (op->type.is_int()) {
Expr e = lower_euclidean_div(op->a, op->b);
e.accept(this);
} else if (op->type.is_uint()) {
visit_binary_op(SpvOpUDiv, op->type, op->a, op->b);
} else if (op->type.is_float()) {
visit_binary_op(SpvOpFDiv, op->type, op->a, op->b);
} else {
internal_error << "Failed to find a suitable Div operator for type: " << op->type << "\n";
}
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Mod *op) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Mod): " << op->type << " ((" << op->a << ") % (" << op->b << "))\n";
int bits = 0;
if (is_const_power_of_two_integer(op->b, &bits) && op->type.is_int_or_uint()) {
op->a.accept(this);
SpvId src_a_id = builder.current_id();
int bitwise_value = ((1 << bits) - 1);
Expr expr = make_const(op->type, bitwise_value);
expr.accept(this);
SpvId src_b_id = builder.current_id();
SpvId type_id = builder.declare_type(op->type);
SpvId result_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::binary_op(SpvOpBitwiseAnd, type_id, result_id, src_a_id, src_b_id));
builder.update_id(result_id);
} else if (op->type.is_int() || op->type.is_uint()) {
// Just exploit the Euclidean identity
Expr zero = make_zero(op->type);
Expr equiv = select(op->a == zero, zero,
op->a - (op->a / op->b) * op->b);
equiv = common_subexpression_elimination(equiv);
equiv.accept(this);
} else if (op->type.is_float()) {
// SPIR-V FMod is strangely not what we want .. FRem does what we need
visit_binary_op(SpvOpFRem, op->type, op->a, op->b);
} else {
internal_error << "Failed to find a suitable Mod operator for type: " << op->type << "\n";
}
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Max *op) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Max): " << op->type << " Max((" << op->a << "), (" << op->b << "))\n";
SpvId op_code = SpvOpNop;
if (op->type.is_float()) {
op_code = GLSLstd450FMax;
} else if (op->type.is_int()) {
op_code = GLSLstd450SMax;
} else if (op->type.is_uint()) {
op_code = GLSLstd450UMax;
} else {
internal_error << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Max *op): unhandled type: " << op->type << "\n";
}
std::vector<Expr> args;
args.reserve(2);
if (op->type.is_vector()) {
if (op->a.type().is_scalar()) {
Expr a_vector = Broadcast::make(op->a, op->type.lanes());
args.push_back(a_vector);
} else {
args.push_back(op->a);
}
if (op->b.type().is_scalar()) {
Expr b_vector = Broadcast::make(op->b, op->type.lanes());
args.push_back(b_vector);
} else {
args.push_back(op->b);
}
} else {
args.push_back(op->a);
args.push_back(op->b);
}
visit_glsl_op(op_code, op->type, args);
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Min *op) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Min): " << op->type << " Min((" << op->a << "), (" << op->b << "))\n";
SpvId op_code = SpvOpNop;
if (op->type.is_float()) {
op_code = GLSLstd450FMin;
} else if (op->type.is_int()) {
op_code = GLSLstd450SMin;
} else if (op->type.is_uint()) {
op_code = GLSLstd450UMin;
} else {
internal_error << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Min *op): unhandled type: " << op->type << "\n";
}
std::vector<Expr> args;
args.reserve(2);
if (op->type.is_vector()) {
if (op->a.type().is_scalar()) {
Expr a_vector = Broadcast::make(op->a, op->type.lanes());
args.push_back(a_vector);
} else {
args.push_back(op->a);
}
if (op->b.type().is_scalar()) {
Expr b_vector = Broadcast::make(op->b, op->type.lanes());
args.push_back(b_vector);
} else {
args.push_back(op->b);
}
} else {
args.push_back(op->a);
args.push_back(op->b);
}
visit_glsl_op(op_code, op->type, args);
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const EQ *op) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(EQ): " << op->type << " (" << op->a << ") == (" << op->b << ")\n";
if (op->a.type() != op->b.type()) {
internal_error << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const EQ *op): Mismatched operand types: " << op->a.type() << " != " << op->b.type() << "\n";
}
SpvOp op_code = SpvOpNop;
if (op->a.type().is_float()) {
op_code = SpvOpFOrdEqual;
} else {
op_code = SpvOpIEqual;
}
Type bool_type = UInt(1, op->type.lanes());
visit_binary_op(op_code, bool_type, op->a, op->b);
if (!op->type.is_bool()) {
SpvId current_id = builder.current_id();
SpvId result_id = cast_type(op->type, bool_type, current_id);
builder.update_id(result_id);
}
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const NE *op) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(NE): " << op->type << " (" << op->a << ") != (" << op->b << ")\n";
if (op->a.type() != op->b.type()) {
internal_error << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const NE *op): Mismatched operand types: " << op->a.type() << " != " << op->b.type() << "\n";
}
SpvOp op_code = SpvOpNop;
if (op->a.type().is_float()) {
op_code = SpvOpFOrdNotEqual;
} else {
op_code = SpvOpINotEqual;
}
Type bool_type = UInt(1, op->type.lanes());
visit_binary_op(op_code, bool_type, op->a, op->b);
if (!op->type.is_bool()) {
Type bool_type = UInt(1, op->type.lanes());
SpvId current_id = builder.current_id();
SpvId result_id = cast_type(op->type, bool_type, current_id);
builder.update_id(result_id);
}
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const LT *op) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(LT): " << op->type << " (" << op->a << ") < (" << op->b << ")\n";
if (op->a.type() != op->b.type()) {
internal_error << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const LT *op): Mismatched operand types: " << op->a.type() << " != " << op->b.type() << "\n";
}
SpvOp op_code = SpvOpNop;
if (op->a.type().is_float()) {
op_code = SpvOpFOrdLessThan;
} else if (op->a.type().is_int()) {
op_code = SpvOpSLessThan;
} else if (op->a.type().is_uint()) {
op_code = SpvOpULessThan;
} else {
internal_error << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const LT *op): unhandled type: " << op->a.type() << "\n";
}
Type bool_type = UInt(1, op->type.lanes());
visit_binary_op(op_code, bool_type, op->a, op->b);
if (!op->type.is_bool()) {
Type bool_type = UInt(1, op->type.lanes());
SpvId current_id = builder.current_id();
SpvId result_id = cast_type(op->type, bool_type, current_id);
builder.update_id(result_id);
}
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const LE *op) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(LE): " << op->type << " (" << op->a << ") <= (" << op->b << ")\n";
if (op->a.type() != op->b.type()) {
internal_error << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const LE *op): Mismatched operand types: " << op->a.type() << " != " << op->b.type() << "\n";
}
SpvOp op_code = SpvOpNop;
if (op->a.type().is_float()) {
op_code = SpvOpFOrdLessThanEqual;
} else if (op->a.type().is_int()) {
op_code = SpvOpSLessThanEqual;
} else if (op->a.type().is_uint()) {
op_code = SpvOpULessThanEqual;
} else {
internal_error << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const LE *op): unhandled type: " << op->a.type() << "\n";
}
Type bool_type = UInt(1, op->type.lanes());
visit_binary_op(op_code, bool_type, op->a, op->b);
if (!op->type.is_bool()) {
Type bool_type = UInt(1, op->type.lanes());
SpvId current_id = builder.current_id();
SpvId result_id = cast_type(op->type, bool_type, current_id);
builder.update_id(result_id);
}
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const GT *op) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(GT): " << op->type << " (" << op->a << ") > (" << op->b << ")\n";
if (op->a.type() != op->b.type()) {
internal_error << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const GT *op): Mismatched operand types: " << op->a.type() << " != " << op->b.type() << "\n";
}
SpvOp op_code = SpvOpNop;
if (op->a.type().is_float()) {
op_code = SpvOpFOrdGreaterThan;
} else if (op->a.type().is_int()) {
op_code = SpvOpSGreaterThan;
} else if (op->a.type().is_uint()) {
op_code = SpvOpUGreaterThan;
} else {
internal_error << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const GT *op): unhandled type: " << op->a.type() << "\n";
}
Type bool_type = UInt(1, op->type.lanes());
visit_binary_op(op_code, bool_type, op->a, op->b);
if (!op->type.is_bool()) {
Type bool_type = UInt(1, op->type.lanes());
SpvId current_id = builder.current_id();
SpvId result_id = cast_type(op->type, bool_type, current_id);
builder.update_id(result_id);
}
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const GE *op) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(GE): " << op->type << " (" << op->a << ") >= (" << op->b << ")\n";
if (op->a.type() != op->b.type()) {
internal_error << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const LE *op): Mismatched operand types: " << op->a.type() << " != " << op->b.type() << "\n";
}
SpvOp op_code = SpvOpNop;
if (op->a.type().is_float()) {
op_code = SpvOpFOrdGreaterThanEqual;
} else if (op->a.type().is_int()) {
op_code = SpvOpSGreaterThanEqual;
} else if (op->a.type().is_uint()) {
op_code = SpvOpUGreaterThanEqual;
} else {
internal_error << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const GE *op): unhandled type: " << op->a.type() << "\n";
}
Type bool_type = UInt(1, op->type.lanes());
visit_binary_op(op_code, bool_type, op->a, op->b);
if (!op->type.is_bool()) {
Type bool_type = UInt(1, op->type.lanes());
SpvId current_id = builder.current_id();
SpvId result_id = cast_type(op->type, bool_type, current_id);
builder.update_id(result_id);
}
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const And *op) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(And): " << op->type << " (" << op->a << ") && (" << op->b << ")\n";
visit_binary_op(SpvOpLogicalAnd, op->type, op->a, op->b);
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Or *op) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Or): " << op->type << " (" << op->a << ") || (" << op->b << ")\n";
visit_binary_op(SpvOpLogicalOr, op->type, op->a, op->b);
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Not *op) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Not): " << op->type << " !(" << op->a << ")\n";
visit_unary_op(SpvOpLogicalNot, op->type, op->a);
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const ProducerConsumer *op) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(ProducerConsumer): name=" << op->name << " is_producer=" << (op->is_producer ? "true" : "false") << "\n";
op->body.accept(this);
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Call *op) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Call): " << op->type << " " << op->name << " args=" << (uint32_t)op->args.size() << "\n";
if (op->is_intrinsic(Call::gpu_thread_barrier)) {
internal_assert(op->args.size() == 1) << "gpu_thread_barrier() intrinsic must specify memory fence type.\n";
const auto *fence_type_ptr = as_const_int(op->args[0]);
internal_assert(fence_type_ptr) << "gpu_thread_barrier() parameter is not a constant integer.\n";
auto fence_type = *fence_type_ptr;
// Follow GLSL semantics for GLCompute ...
//
// barrier() -> control_barrier(Workgroup, Workgroup, AcquireRelease | WorkgroupMemory)
//
uint32_t execution_scope = SpvWorkgroupScope;
uint32_t memory_scope = SpvWorkgroupScope;
uint32_t control_mask = (SpvMemorySemanticsAcquireReleaseMask | SpvMemorySemanticsWorkgroupMemoryMask);
SpvId exec_scope_id = builder.declare_constant(UInt(32), &execution_scope);
SpvId memory_scope_id = builder.declare_constant(UInt(32), &memory_scope);
SpvId control_mask_id = builder.declare_constant(UInt(32), &control_mask);
builder.append(SpvFactory::control_barrier(exec_scope_id, memory_scope_id, control_mask_id));
if ((fence_type & CodeGen_GPU_Dev::MemoryFenceType::Device) ||
(fence_type & CodeGen_GPU_Dev::MemoryFenceType::Shared)) {
// groupMemoryBarrier() -> memory_barrier(Workgroup, AcquireRelease | UniformMemory | WorkgroupMemory | ImageMemory)
//
uint32_t memory_mask = (SpvMemorySemanticsAcquireReleaseMask |
SpvMemorySemanticsUniformMemoryMask |
SpvMemorySemanticsWorkgroupMemoryMask |
SpvMemorySemanticsImageMemoryMask);
SpvId memory_mask_id = builder.declare_constant(UInt(32), &memory_mask);
builder.append(SpvFactory::memory_barrier(memory_scope_id, memory_mask_id));
}
SpvId result_id = builder.declare_null_constant(op->type);
builder.update_id(result_id);
} else if (op->is_intrinsic(Call::abs)) {
internal_assert(op->args.size() == 1);
SpvId op_code = SpvInvalidId;
if (op->type.is_float()) {
op_code = GLSLstd450FAbs;
} else {
op_code = GLSLstd450SAbs;
}
visit_glsl_op(op_code, op->type, op->args);
} else if (op->is_intrinsic(Call::IntrinsicOp::round)) {
internal_assert(op->args.size() == 1);
// GLSL RoundEven matches Halide's implementation
visit_glsl_op(GLSLstd450RoundEven, op->type, op->args);
} else if (op->is_intrinsic(Call::absd)) {
internal_assert(op->args.size() == 2);
Expr a = op->args[0];
Expr b = op->args[1];
Expr e = cast(op->type, select(a < b, b - a, a - b));
e->accept(this);
} else if (op->is_intrinsic(Call::return_second)) {
internal_assert(op->args.size() == 2);
// Simply discard the first argument, which is generally a call to
// 'halide_printf'.
if (op->args[1].defined()) {
op->args[1]->accept(this);
}
} else if (op->is_intrinsic(Call::bitwise_and)) {
internal_assert(op->args.size() == 2);
visit_binary_op(SpvOpBitwiseAnd, op->type, op->args[0], op->args[1]);
} else if (op->is_intrinsic(Call::bitwise_xor)) {
internal_assert(op->args.size() == 2);
visit_binary_op(SpvOpBitwiseXor, op->type, op->args[0], op->args[1]);
} else if (op->is_intrinsic(Call::bitwise_or)) {
internal_assert(op->args.size() == 2);
visit_binary_op(SpvOpBitwiseOr, op->type, op->args[0], op->args[1]);
} else if (op->is_intrinsic(Call::bitwise_not)) {
internal_assert(op->args.size() == 1);
visit_unary_op(SpvOpNot, op->type, op->args[0]);
} else if (op->is_intrinsic(Call::if_then_else)) {
Expr cond = op->args[0];
if (const Broadcast *b = cond.as<Broadcast>()) {
cond = b->value;
}
if (cond.type().is_vector()) {
scalarize(op);
} else {
// Generate Phi node if used as an expression.
internal_assert(op->args.size() == 2 || op->args.size() == 3);
Expr else_expr;
if (op->args.size() == 3) {
else_expr = op->args[2];
}
SpvFactory::BlockVariables block_vars = emit_if_then_else(op->args[0], op->args[1], else_expr);
SpvId type_id = builder.declare_type(op->type);
SpvId result_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::phi(type_id, result_id, block_vars));
builder.update_id(result_id);
}
} else if (op->is_intrinsic(Call::IntrinsicOp::div_round_to_zero)) {
internal_assert(op->args.size() == 2);
// See if we can rewrite it to something faster (e.g. a shift)
Expr e = lower_int_uint_div(op->args[0], op->args[1], /** round to zero */ true);
if (!e.as<Call>()) {
e.accept(this);
return;
}
SpvOp op_code = SpvOpNop;
if (op->type.is_float()) {
op_code = SpvOpFDiv;
} else if (op->type.is_int()) {
op_code = SpvOpSDiv;
} else if (op->type.is_uint()) {
op_code = SpvOpUDiv;
} else {
internal_error << "div_round_to_zero of unhandled type.\n";
}
visit_binary_op(op_code, op->type, op->args[0], op->args[1]);
} else if (op->is_intrinsic(Call::IntrinsicOp::mod_round_to_zero)) {
internal_assert(op->args.size() == 2);
SpvOp op_code = SpvOpNop;
if (op->type.is_float()) {
op_code = SpvOpFRem; // NOTE: FRem matches the fmod we expect
} else if (op->type.is_int()) {
op_code = SpvOpSMod;
} else if (op->type.is_uint()) {
op_code = SpvOpUMod;
} else {
internal_error << "mod_round_to_zero of unhandled type.\n";
}
visit_binary_op(op_code, op->type, op->args[0], op->args[1]);
} else if (op->is_intrinsic(Call::shift_right)) {
internal_assert(op->args.size() == 2);
if (op->type.is_uint() || (op->args[1].type().is_uint())) {
visit_binary_op(SpvOpShiftRightLogical, op->type, op->args[0], op->args[1]);
} else {
Expr e = lower_signed_shift_right(op->args[0], op->args[1]);
e.accept(this);
}
} else if (op->is_intrinsic(Call::shift_left)) {
internal_assert(op->args.size() == 2);
if (op->type.is_uint() || (op->args[1].type().is_uint())) {
visit_binary_op(SpvOpShiftLeftLogical, op->type, op->args[0], op->args[1]);
} else {
Expr e = lower_signed_shift_left(op->args[0], op->args[1]);
e.accept(this);
}
} else if (op->is_intrinsic(Call::strict_float)) {
// TODO: Enable/Disable RelaxedPrecision flags?
internal_assert(op->args.size() == 1);
op->args[0].accept(this);
} else if (op->is_intrinsic(Call::IntrinsicOp::sorted_avg)) {
internal_assert(op->args.size() == 2);
// b > a, so the following works without widening:
// a + (b - a)/2
Expr e = op->args[0] + (op->args[1] - op->args[0]) / 2;
e.accept(this);
} else if (op->is_intrinsic(Call::lerp)) {
// Implement lerp using GLSL's mix() function, which always uses
// floating point arithmetic.
Expr zero_val = op->args[0];
Expr one_val = op->args[1];
Expr weight = op->args[2];
internal_assert(weight.type().is_uint() || weight.type().is_float());
if (weight.type().is_uint()) {
// Normalize integer weights to [0.0f, 1.0f] range.
internal_assert(weight.type().bits() < 32);
weight = Div::make(Cast::make(Float(32), weight),
Cast::make(Float(32), weight.type().max()));
} else if (op->type.is_uint()) {
// Round float weights down to next multiple of (1/op->type.imax())
// to give same results as lerp based on integer arithmetic.
internal_assert(op->type.bits() < 32);
weight = floor(weight * op->type.max()) / op->type.max();
}
Type result_type = Float(32, op->type.lanes());
Expr e = Call::make(result_type, "mix", {zero_val, one_val, weight}, Call::Extern);
if (!op->type.is_float()) {
// Mirror rounding implementation of Halide's integer lerp.
e = Cast::make(op->type, floor(e + 0.5f));
}
e.accept(this);
} else if (op->is_intrinsic(Call::mux)) {
Expr e = lower_mux(op);
e.accept(this);
} else if (op->is_intrinsic(Call::saturating_cast)) {
Expr e = lower_intrinsic(op);
e.accept(this);
} else if (op->is_intrinsic()) {
Expr lowered = lower_intrinsic(op);
if (lowered.defined()) {
lowered.accept(this);
} else {
internal_error << "Unhandled intrinsic in Vulkan backend: " << op->name << "\n";
}
} else if (op->call_type == Call::PureExtern && starts_with(op->name, "pow_f")) {
internal_assert(op->args.size() == 2);
if (can_prove(op->args[0] > 0)) {
visit_glsl_op(GLSLstd450Pow, op->type, op->args);
} else {
Expr x = op->args[0];
Expr y = op->args[1];
Halide::Expr abs_x_pow_y = Internal::halide_exp(Internal::halide_log(abs(x)) * y);
Halide::Expr nan_expr = Call::make(x.type(), "nan_f32", {}, Call::PureExtern);
Expr iy = floor(y);
Expr one = make_one(x.type());
Expr zero = make_zero(x.type());
Expr e = select(x > 0, abs_x_pow_y, // Strictly positive x
y == 0.0f, one, // x^0 == 1
x == 0.0f, zero, // 0^y == 0
y != iy, nan_expr, // negative x to a non-integer power
iy % 2 == 0, abs_x_pow_y, // negative x to an even power
-abs_x_pow_y); // negative x to an odd power
e = common_subexpression_elimination(e);
e.accept(this);
}
} else if (starts_with(op->name, "fast_inverse_f")) {
internal_assert(op->args.size() == 1);
if (op->type.lanes() > 1) {
user_error << "Vulkan: Expected scalar value for fast_inverse!\n";
}
op->args[0].accept(this);
SpvId arg_value_id = builder.current_id();
SpvId one_constant_id = SpvInvalidId;
SpvId type_id = builder.declare_type(op->type);
if (op->type.is_float() && op->type.bits() == 16) {
if (op->type.is_bfloat()) {
bfloat16_t one_value = bfloat16_t(1.0f);
one_constant_id = builder.declare_constant(op->type, &one_value);
} else {
float16_t one_value = float16_t(1.0f);
one_constant_id = builder.declare_constant(op->type, &one_value);
}
} else if (op->type.is_float() && op->type.bits() == 32) {
float one_value = float(1.0f);
one_constant_id = builder.declare_constant(op->type, &one_value);
} else if (op->type.is_float() && op->type.bits() == 64) {
double one_value = double(1.0);
one_constant_id = builder.declare_constant(op->type, &one_value);
} else {
internal_error << "Vulkan: Unhandled float type in fast_inverse intrinsic!\n";
}
internal_assert(one_constant_id != SpvInvalidId);
SpvId result_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::binary_op(SpvOpFDiv, type_id, result_id, one_constant_id, arg_value_id));
builder.update_id(result_id);
} else if (op->name == "nan_f32") {
float value = NAN;
SpvId result_id = builder.declare_constant(Float(32), &value);
builder.update_id(result_id);
} else if (op->name == "inf_f32") {
float value = INFINITY;
SpvId result_id = builder.declare_constant(Float(32), &value);
builder.update_id(result_id);
} else if (op->name == "neg_inf_f32") {
float value = -INFINITY;
SpvId result_id = builder.declare_constant(Float(32), &value);
builder.update_id(result_id);
} else if (starts_with(op->name, "is_nan_f")) {
internal_assert(op->args.size() == 1);
visit_unary_op((SpvOp)SpvOpIsNan, op->type, op->args[0]);
} else if (starts_with(op->name, "is_inf_f")) {
internal_assert(op->args.size() == 1);
visit_unary_op((SpvOp)SpvOpIsInf, op->type, op->args[0]);
} else if (starts_with(op->name, "is_finite_f")) {
internal_assert(op->args.size() == 1);
visit_unary_op((SpvOp)SpvOpIsInf, op->type, op->args[0]);
SpvId is_inf_id = builder.current_id();
visit_unary_op((SpvOp)SpvOpIsNan, op->type, op->args[0]);
SpvId is_nan_id = builder.current_id();
SpvId type_id = builder.declare_type(op->type);
SpvId not_is_nan_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::logical_not(type_id, not_is_nan_id, is_nan_id));
SpvId not_is_inf_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::logical_not(type_id, not_is_inf_id, is_inf_id));
SpvId result_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::logical_and(type_id, result_id, not_is_inf_id, not_is_nan_id));
builder.update_id(result_id);
} else {
// If its not a standard SPIR-V built-in, see if there's a GLSL extended builtin
BuiltinMap::const_iterator glsl_it = glsl_builtin.find(op->name);
if (glsl_it == glsl_builtin.end()) {
user_error << "Vulkan: unhandled SPIR-V GLSL builtin function '" << op->name << "' encountered.\n";
}
// Call the GLSL extended built-in
SpvId glsl_op_code = glsl_it->second;
visit_glsl_op(glsl_op_code, op->type, op->args);
}
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Select *op) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Select): " << op->type << " (" << op->condition << ") ? (" << op->true_value << ") : (" << op->false_value << ")\n";
SpvId type_id = builder.declare_type(op->type);
op->condition.accept(this);
SpvId cond_id = builder.current_id();
op->true_value.accept(this);
SpvId true_id = builder.current_id();
op->false_value.accept(this);
SpvId false_id = builder.current_id();
SpvId result_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::select(type_id, result_id, cond_id, true_id, false_id));
builder.update_id(result_id);
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::load_from_scalar_index(const Load *op, SpvId index_id, SpvId variable_id, Type value_type, Type storage_type, SpvStorageClass storage_class) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::load_from_scalar_index(): "
<< "index_id=" << index_id << " "
<< "variable_id=" << variable_id << " "
<< "value_type=" << value_type << " "
<< "storage_type=" << storage_type << " "
<< "storage_class=" << storage_class << "\n";
// determine the base type id for the source value
SpvId base_type_id = builder.type_of(variable_id);
if (builder.is_pointer_type(base_type_id)) {
base_type_id = builder.lookup_base_type(base_type_id);
}
SpvId storage_type_id = builder.declare_type(storage_type);
SpvId ptr_type_id = builder.declare_pointer_type(storage_type, storage_class);
uint32_t zero = 0;
SpvId src_id = SpvInvalidId;
SpvId src_index_id = index_id;
if (storage_class == SpvStorageClassUniform) {
if (builder.is_struct_type(base_type_id)) {
SpvId zero_id = builder.declare_constant(UInt(32), &zero);
SpvFactory::Indices access_indices = {zero_id, src_index_id};
src_id = builder.declare_access_chain(ptr_type_id, variable_id, access_indices);
} else {
SpvFactory::Indices access_indices = {src_index_id};
src_id = builder.declare_access_chain(ptr_type_id, variable_id, access_indices);
}
} else if ((storage_class == SpvStorageClassWorkgroup) || (storage_class == SpvStorageClassFunction)) {
if (builder.is_array_type(base_type_id)) {
SpvFactory::Indices access_indices = {src_index_id};
src_id = builder.declare_access_chain(ptr_type_id, variable_id, access_indices);
} else {
src_id = variable_id;
}
} else {
internal_error << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Load): unhandled storage class encountered on op: " << storage_class << "\n";
}
internal_assert(src_id != SpvInvalidId);
SpvId value_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::load(storage_type_id, value_id, src_id));
// if the value type doesn't match the base for the pointer type, cast it accordingly
SpvId result_id = value_id;
if (storage_type != value_type) {
result_id = cast_type(value_type, storage_type, result_id);
}
builder.update_id(result_id);
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::load_from_vector_index(const Load *op, SpvId variable_id, Type value_type, Type storage_type, SpvStorageClass storage_class) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::load_from_vector_index(): "
<< "variable_id=" << variable_id << " "
<< "value_type=" << value_type << " "
<< "storage_type=" << storage_type << " "
<< "storage_class=" << storage_class << "\n";
internal_assert(op->index.type().is_vector());
// If the runtime array is a vector type, then attempt to do a
// dense vector load by using the base of the ramp divided by
// the number of lanes.
StorageAccessMap::const_iterator it = storage_access_map.find(variable_id);
if (it != storage_access_map.end()) {
storage_type = it->second.storage_type; // use the storage type for the runtime array
SpvId storage_type_id = it->second.storage_type_id;
if (builder.is_vector_type(storage_type_id)) {
Expr ramp_base = strided_ramp_base(op->index);
if (ramp_base.defined()) {
Expr ramp_index = (ramp_base / op->type.lanes());
ramp_index.accept(this);
SpvId index_id = builder.current_id();
load_from_scalar_index(op, index_id, variable_id, value_type, storage_type, storage_class);
return;
}
}
}
op->index.accept(this);
SpvId index_id = builder.current_id();
// Gather vector elements.
SpvFactory::Components loaded_values;
Type scalar_value_type = value_type.with_lanes(1);
SpvFactory::Components index_components = split_vector(op->index.type(), index_id);
for (SpvId scalar_index : index_components) {
load_from_scalar_index(op, scalar_index, variable_id, scalar_value_type, storage_type, storage_class);
SpvId value_component_id = builder.current_id();
loaded_values.push_back(value_component_id);
}
// Create a composite vector from the individual loads
if (loaded_values.size() > 1) {
SpvId result_id = join_vector(value_type, loaded_values);
builder.update_id(result_id);
}
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::store_at_scalar_index(const Store *op, SpvId index_id, SpvId variable_id, Type value_type, Type storage_type, SpvStorageClass storage_class, SpvId value_id) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::store_at_scalar_index(): "
<< "index_id=" << index_id << " "
<< "variable_id=" << variable_id << " "
<< "value_type=" << value_type << " "
<< "storage_type=" << storage_type << " "
<< "storage_class=" << storage_class << " "
<< "value_id=" << value_id << "\n";
// determine the base type id for the source value
SpvId base_type_id = builder.type_of(variable_id);
if (builder.is_pointer_type(base_type_id)) {
base_type_id = builder.lookup_base_type(base_type_id);
}
uint32_t zero = 0;
SpvId dst_id = SpvInvalidId;
SpvId dst_index_id = index_id;
SpvId ptr_type_id = builder.declare_pointer_type(storage_type, storage_class);
if (storage_class == SpvStorageClassUniform) {
if (builder.is_struct_type(base_type_id)) {
SpvId zero_id = builder.declare_constant(UInt(32), &zero);
SpvFactory::Indices access_indices = {zero_id, dst_index_id};
dst_id = builder.declare_access_chain(ptr_type_id, variable_id, access_indices);
} else {
SpvFactory::Indices access_indices = {dst_index_id};
dst_id = builder.declare_access_chain(ptr_type_id, variable_id, access_indices);
}
} else if ((storage_class == SpvStorageClassWorkgroup) || (storage_class == SpvStorageClassFunction)) {
if (builder.is_array_type(base_type_id)) {
SpvFactory::Indices access_indices = {dst_index_id};
dst_id = builder.declare_access_chain(ptr_type_id, variable_id, access_indices);
} else {
dst_id = variable_id;
}
} else {
internal_error << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Store): unhandled storage class encountered on op: " << storage_class << "\n";
}
internal_assert(dst_id != SpvInvalidId);
// if the value type doesn't match the base for the pointer type, cast it accordingly
if (storage_type != value_type) {
value_id = cast_type(storage_type, value_type, value_id);
}
builder.append(SpvFactory::store(dst_id, value_id));
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::store_at_vector_index(const Store *op, SpvId variable_id, Type value_type, Type storage_type, SpvStorageClass storage_class, SpvId value_id) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::store_at_vector_index(): "
<< "variable_id=" << variable_id << " "
<< "value_type=" << value_type << " "
<< "storage_type=" << storage_type << " "
<< "storage_class=" << storage_class << "\n";
internal_assert(op->index.type().is_vector());
// If the runtime array is a vector type, then attempt to do a
// dense vector store by using the base of the ramp divided by
// the number of lanes.
StorageAccessMap::const_iterator it = storage_access_map.find(variable_id);
if (it != storage_access_map.end()) {
storage_type = it->second.storage_type;
SpvId storage_type_id = it->second.storage_type_id;
if (builder.is_vector_type(storage_type_id)) {
Expr ramp_base = strided_ramp_base(op->index);
if (ramp_base.defined()) {
Expr ramp_index = (ramp_base / op->value.type().lanes());
ramp_index.accept(this);
SpvId index_id = builder.current_id();
store_at_scalar_index(op, index_id, variable_id, value_type, storage_type, storage_class, value_id);
return;
}
}
}
op->index.accept(this);
SpvId index_id = builder.current_id();
// Split vector value into components
internal_assert(op->index.type().lanes() <= op->value.type().lanes());
SpvFactory::Components value_components = split_vector(op->value.type(), value_id);
SpvFactory::Components index_components = split_vector(op->index.type(), index_id);
// Scatter vector elements.
Type scalar_value_type = op->value.type().with_lanes(1);
for (uint32_t i = 0; i < index_components.size(); i++) {
SpvId index_component_id = index_components[i];
SpvId value_component_id = value_components[i];
store_at_scalar_index(op, index_component_id, variable_id, scalar_value_type, storage_type, storage_class, value_component_id);
}
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Load *op) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Load): " << op->type << " " << op->name << "[" << op->index << "]\n";
user_assert(is_const_one(op->predicate)) << "Predicated loads not supported by SPIR-V codegen\n";
// Construct the pointer to read from
internal_assert(symbol_table.contains(op->name));
SymbolIdStorageClassPair id_and_storage_class = symbol_table.get(op->name);
SpvId variable_id = id_and_storage_class.first;
SpvStorageClass storage_class = id_and_storage_class.second;
internal_assert(variable_id != SpvInvalidId);
internal_assert(((uint32_t)storage_class) < ((uint32_t)SpvStorageClassMax));
// If this is a load from a buffer block (mapped to a halide buffer) or
// GPU shared memory, the pointer type must match the declared storage
// type for the runtime array.
Type value_type = op->type;
Type storage_type = value_type;
StorageAccessMap::const_iterator it = storage_access_map.find(variable_id);
if (it != storage_access_map.end()) {
storage_type = it->second.storage_type;
}
debug(2) << " value_type=" << op->type << " storage_type=" << storage_type << "\n";
debug(2) << " index_type=" << op->index.type() << " index=" << op->index << "\n";
if (op->index.type().is_scalar()) {
op->index.accept(this);
SpvId index_id = builder.current_id();
load_from_scalar_index(op, index_id, variable_id, value_type, storage_type, storage_class);
} else {
load_from_vector_index(op, variable_id, value_type, storage_type, storage_class);
}
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Store *op) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Store): " << op->name << "[" << op->index << "] = (" << op->value << ")\n";
user_assert(is_const_one(op->predicate)) << "Predicated stores not supported by SPIR-V codegen!\n";
debug(2) << " value_type=" << op->value.type() << " value=" << op->value << "\n";
op->value.accept(this);
SpvId value_id = builder.current_id();
internal_assert(symbol_table.contains(op->name));
SymbolIdStorageClassPair id_and_storage_class = symbol_table.get(op->name);
SpvId variable_id = id_and_storage_class.first;
SpvStorageClass storage_class = id_and_storage_class.second;
internal_assert(variable_id != SpvInvalidId);
internal_assert(((uint32_t)storage_class) < ((uint32_t)SpvStorageClassMax));
Type value_type = op->value.type();
Type storage_type = value_type;
// If this is a store to a buffer block (mapped to a halide buffer) or
// GPU shared memory, the pointer type must match the declared storage
// type for the runtime array
StorageAccessMap::const_iterator it = storage_access_map.find(variable_id);
if (it != storage_access_map.end()) {
storage_type = it->second.storage_type;
}
debug(2) << " value_type=" << value_type << " storage_type=" << storage_type << "\n";
debug(2) << " index_type=" << op->index.type() << " index=" << op->index << "\n";
if (op->index.type().is_scalar()) {
op->index.accept(this);
SpvId index_id = builder.current_id();
store_at_scalar_index(op, index_id, variable_id, value_type, storage_type, storage_class, value_id);
} else {
store_at_vector_index(op, variable_id, value_type, storage_type, storage_class, value_id);
}
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Let *let) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Let): " << (Expr)let << "\n";
let->value.accept(this);
SpvId current_id = builder.current_id();
ScopedSymbolBinding binding(symbol_table, let->name, {current_id, SpvStorageClassFunction});
let->body.accept(this);
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const LetStmt *let) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(LetStmt): " << let->name << "\n";
let->value.accept(this);
SpvId current_id = builder.current_id();
ScopedSymbolBinding binding(symbol_table, let->name, {current_id, SpvStorageClassFunction});
let->body.accept(this);
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const AssertStmt *stmt) {
// TODO: Fill this in.
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(AssertStmt): "
<< "condition=" << stmt->condition << " "
<< "message=" << stmt->message << "\n";
}
namespace {
std::pair<std::string, uint32_t> simt_intrinsic(const std::string &name) {
if (ends_with(name, ".__thread_id_x")) {
return {"LocalInvocationId", 0};
} else if (ends_with(name, ".__thread_id_y")) {
return {"LocalInvocationId", 1};
} else if (ends_with(name, ".__thread_id_z")) {
return {"LocalInvocationId", 2};
} else if (ends_with(name, ".__block_id_x")) {
return {"WorkgroupId", 0};
} else if (ends_with(name, ".__block_id_y")) {
return {"WorkgroupId", 1};
} else if (ends_with(name, ".__block_id_z")) {
return {"WorkgroupId", 2};
} else if (ends_with(name, "id_w")) {
user_error << "Vulkan only supports <=3 dimensions for gpu blocks";
}
internal_error << "simt_intrinsic called on bad variable name: " << name << "\n";
return {"", -1};
}
} // anonymous namespace
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const For *op) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(For): name=" << op->name << " min=" << op->min << " extent=" << op->extent << "\n";
if (is_gpu_var(op->name)) {
internal_assert((op->for_type == ForType::GPUBlock) ||
(op->for_type == ForType::GPUThread))
<< "kernel loops must be either gpu block or gpu thread\n";
// This should always be true at this point in codegen
internal_assert(is_const_zero(op->min));
auto intrinsic = simt_intrinsic(op->name);
const std::string intrinsic_var_name = std::string("k") + std::to_string(kernel_index) + std::string("_") + intrinsic.first;
// Intrinsics are inserted when adding the kernel
internal_assert(symbol_table.contains(intrinsic_var_name));
SpvId intrinsic_id = symbol_table.get(intrinsic_var_name).first;
SpvStorageClass storage_class = symbol_table.get(intrinsic_var_name).second;
// extract and cast to the extent type (which is what's expected by Halide's for loops)
Type unsigned_type = UInt(32);
SpvId unsigned_type_id = builder.declare_type(unsigned_type);
SpvId unsigned_value_id = builder.reserve_id(SpvResultId);
SpvFactory::Indices indices = {intrinsic.second};
builder.append(SpvFactory::composite_extract(unsigned_type_id, unsigned_value_id, intrinsic_id, indices));
SpvId intrinsic_value_id = cast_type(op->min.type(), unsigned_type, unsigned_value_id);
{
ScopedSymbolBinding binding(symbol_table, op->name, {intrinsic_value_id, storage_class});
op->body.accept(this);
}
} else {
debug(2) << " (serial for loop): min=" << op->min << " extent=" << op->extent << "\n";
internal_assert(op->for_type == ForType::Serial) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit unhandled For type: " << op->for_type << "\n";
user_assert(op->min.type() == op->extent.type());
user_assert(op->min.type().is_int() || op->min.type().is_uint());
op->min.accept(this);
SpvId min_id = builder.current_id();
op->extent.accept(this);
SpvId extent_id = builder.current_id();
// Compute max.
Type index_type = op->min.type();
SpvId index_type_id = builder.declare_type(index_type);
SpvStorageClass storage_class = SpvStorageClassFunction;
SpvId index_var_type_id = builder.declare_pointer_type(index_type_id, storage_class);
SpvId max_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::integer_add(index_type_id, max_id, min_id, extent_id));
// Declare loop var
const std::string loop_var_name = unique_name(std::string("k") + std::to_string(kernel_index) + "_loop_idx");
debug(2) << " loop_index=" << loop_var_name << " type=" << index_type << "\n";
SpvId loop_var_id = builder.declare_variable(loop_var_name, index_var_type_id, storage_class);
symbol_table.push(loop_var_name, {loop_var_id, storage_class});
SpvId header_block_id = builder.reserve_id(SpvBlockId);
SpvId top_block_id = builder.reserve_id(SpvBlockId);
SpvId body_block_id = builder.reserve_id(SpvBlockId);
SpvId continue_block_id = builder.reserve_id(SpvBlockId);
SpvId merge_block_id = builder.reserve_id(SpvBlockId);
builder.append(SpvFactory::store(loop_var_id, min_id));
SpvBlock header_block = builder.create_block(header_block_id);
builder.enter_block(header_block);
{
builder.append(SpvFactory::loop_merge(merge_block_id, continue_block_id, SpvLoopControlDontUnrollMask));
builder.append(SpvFactory::branch(top_block_id));
}
builder.leave_block();
SpvId loop_index_id = builder.reserve_id(SpvResultId);
SpvBlock top_block = builder.create_block(top_block_id);
builder.enter_block(top_block);
{
SpvId loop_test_type_id = builder.declare_type(Bool());
SpvId loop_test_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::load(index_type_id, loop_index_id, loop_var_id));
builder.append(SpvFactory::integer_less_than(loop_test_type_id, loop_test_id, loop_index_id, max_id, index_type.is_uint()));
builder.append(SpvFactory::conditional_branch(loop_test_id, body_block_id, merge_block_id));
}
builder.leave_block();
SpvBlock body_block = builder.create_block(body_block_id);
builder.enter_block(body_block);
{
ScopedSymbolBinding binding(symbol_table, op->name, {loop_index_id, storage_class});
op->body.accept(this);
builder.append(SpvFactory::branch(continue_block_id));
}
builder.leave_block();
SpvBlock continue_block = builder.create_block(continue_block_id);
builder.enter_block(continue_block);
{
// Update loop variable
int32_t one = 1;
SpvId next_index_id = builder.reserve_id(SpvResultId);
SpvId constant_one_id = builder.declare_constant(index_type, &one);
SpvId current_index_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::load(index_type_id, current_index_id, loop_var_id));
builder.append(SpvFactory::integer_add(index_type_id, next_index_id, current_index_id, constant_one_id));
builder.append(SpvFactory::store(loop_var_id, next_index_id));
builder.append(SpvFactory::branch(header_block_id));
}
builder.leave_block();
symbol_table.pop(loop_var_name);
SpvBlock merge_block = builder.create_block(merge_block_id);
builder.enter_block(merge_block);
}
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Ramp *op) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Ramp): "
<< "base=" << op->base << " "
<< "stride=" << op->stride << " "
<< "lanes=" << (uint32_t)op->lanes << "\n";
// TODO: Is there a way to do this that doesn't require duplicating lane values?
SpvId base_type_id = builder.declare_type(op->base.type());
SpvId type_id = builder.declare_type(op->type);
op->base.accept(this);
SpvId base_id = builder.current_id();
op->stride.accept(this);
SpvId stride_id = builder.current_id();
// Generate adds to make the elements of the ramp.
SpvId prev_id = base_id;
SpvFactory::Components constituents = {base_id};
for (int i = 1; i < op->lanes; i++) {
SpvId this_id = builder.reserve_id(SpvResultId);
if (op->base.type().is_float()) {
builder.append(SpvFactory::float_add(base_type_id, this_id, prev_id, stride_id));
} else if (op->base.type().is_int_or_uint()) {
builder.append(SpvFactory::integer_add(base_type_id, this_id, prev_id, stride_id));
} else {
internal_error << "SPIRV: Unhandled base type encountered in ramp!\n";
}
constituents.push_back(this_id);
prev_id = this_id;
}
SpvId result_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::composite_construct(type_id, result_id, constituents));
builder.update_id(result_id);
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Broadcast *op) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Broadcast): "
<< "type=" << op->type << " "
<< "value=" << op->value << "\n";
// TODO: Is there a way to do this that doesn't require duplicating lane values?
SpvId type_id = builder.declare_type(op->type);
op->value.accept(this);
SpvId value_id = builder.current_id();
SpvId result_id = builder.reserve_id(SpvResultId);
SpvFactory::Components constituents;
constituents.insert(constituents.end(), op->lanes, value_id);
builder.append(SpvFactory::composite_construct(type_id, result_id, constituents));
builder.update_id(result_id);
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Provide *) {
internal_error << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Provide *): Provide encountered during codegen\n";
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Allocate *op) {
SpvId storage_type_id = builder.declare_type(op->type);
SpvId array_type_id = SpvInvalidId;
SpvId variable_id = SpvInvalidId;
uint32_t array_size = 0;
SpvStorageClass storage_class = SpvStorageClassGeneric;
if (op->memory_type == MemoryType::GPUShared) {
// Allocation of shared memory must be declared at global scope
storage_class = SpvStorageClassWorkgroup; // shared across workgroup
std::string variable_name = std::string("k") + std::to_string(kernel_index) + std::string("_") + op->name;
uint32_t type_size = op->type.bytes();
uint32_t constant_id = 0;
// static fixed size allocation
if (op->extents.size() == 1 && is_const(op->extents[0])) {
array_size = op->constant_allocation_size();
array_type_id = builder.declare_type(op->type, array_size);
builder.add_symbol(variable_name + "_array_type", array_type_id, builder.current_module().id());
debug(2) << "Vulkan: Allocate (fixed-size) " << op->name << " type=" << op->type << " array_size=" << (uint32_t)array_size << " in shared memory on device in global scope\n";
} else {
// dynamic allocation with unknown size at compile time ...
// declare the array size as a specialization constant (which will get overridden at runtime)
Type array_size_type = UInt(32);
array_size = std::max(workgroup_size[0], uint32_t(1)); // use one item per workgroup as an initial guess
SpvId array_size_id = builder.declare_specialization_constant(array_size_type, &array_size);
array_type_id = builder.add_array_with_default_size(storage_type_id, array_size_id);
builder.add_symbol(variable_name + "_array_type", array_type_id, builder.current_module().id());
debug(2) << "Vulkan: Allocate (dynamic size) " << op->name << " type=" << op->type << " default_size=" << (uint32_t)array_size << " in shared memory on device in global scope\n";
// bind the specialization constant to the next slot
std::string constant_name = variable_name + "_array_size";
constant_id = (uint32_t)(descriptor_set_table.back().specialization_constants.size() + 1);
SpvBuilder::Literals spec_id = {constant_id};
builder.add_annotation(array_size_id, SpvDecorationSpecId, spec_id);
builder.add_symbol(constant_name, array_size_id, builder.current_module().id());
// update the descriptor set with the specialization binding
SpecializationBinding spec_binding = {constant_id, (uint32_t)array_size_type.bytes(), constant_name};
descriptor_set_table.back().specialization_constants.push_back(spec_binding);
}
// add the shared memory allocation to the descriptor set
SharedMemoryAllocation shared_mem_allocation = {constant_id, array_size, type_size, variable_name};
descriptor_set_table.back().shared_memory_usage.push_back(shared_mem_allocation);
// declare the variable
SpvId ptr_type_id = builder.declare_pointer_type(array_type_id, storage_class);
variable_id = builder.declare_global_variable(variable_name, ptr_type_id, storage_class);
} else {
// Allocation is not a shared memory allocation, just make a local declaration.
array_size = op->constant_allocation_size();
// It must have a constant size.
user_assert(array_size > 0)
<< "Allocation " << op->name << " has a dynamic size. "
<< "Only fixed-size local allocations are supported with Vulkan.";
debug(2) << "Vulkan: Allocate " << op->name << " type=" << op->type << " size=" << (uint32_t)array_size << " on device in function scope\n";
array_type_id = builder.declare_type(op->type, array_size);
storage_class = SpvStorageClassFunction; // function scope
std::string variable_name = std::string("k") + std::to_string(kernel_index) + std::string("_") + op->name;
SpvId ptr_type_id = builder.declare_pointer_type(array_type_id, storage_class);
variable_id = builder.declare_variable(variable_name, ptr_type_id, storage_class);
}
StorageAccess access;
access.storage_class = storage_class;
access.storage_array_size = array_size;
access.storage_type_id = storage_type_id;
access.storage_type = op->type;
storage_access_map[variable_id] = access;
debug(3) << "Vulkan: Pushing allocation called " << op->name << " onto the symbol table\n";
symbol_table.push(op->name, {variable_id, storage_class});
op->body.accept(this);
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Free *op) {
debug(3) << "Vulkan: Popping allocation called " << op->name << " off the symbol table\n";
internal_assert(symbol_table.contains(op->name));
SpvId variable_id = symbol_table.get(op->name).first;
storage_access_map.erase(variable_id);
symbol_table.pop(op->name);
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Realize *) {
internal_error << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Realize *): Realize encountered during codegen\n";
}
template<typename StmtOrExpr>
SpvFactory::BlockVariables
CodeGen_Vulkan_Dev::SPIRV_Emitter::emit_if_then_else(const Expr &condition,
StmtOrExpr then_case, StmtOrExpr else_case) {
SpvId merge_block_id = builder.reserve_id(SpvBlockId);
SpvId if_block_id = builder.reserve_id(SpvBlockId);
SpvId then_block_id = builder.reserve_id(SpvBlockId);
SpvId else_block_id = else_case.defined() ? builder.reserve_id(SpvBlockId) : merge_block_id;
SpvFactory::BlockVariables block_vars;
// If block
debug(2) << "Vulkan: If => (" << condition << " )\n";
SpvBlock if_block = builder.create_block(if_block_id);
builder.enter_block(if_block);
{
condition.accept(this);
SpvId cond_id = builder.current_id();
builder.append(SpvFactory::selection_merge(merge_block_id, SpvSelectionControlMaskNone));
builder.append(SpvFactory::conditional_branch(cond_id, then_block_id, else_block_id));
}
builder.leave_block();
// Then block
debug(2) << "Vulkan: Then =>\n"
<< then_case << "\n";
SpvBlock then_block = builder.create_block(then_block_id);
builder.enter_block(then_block);
{
then_case.accept(this);
SpvId then_id = builder.current_id();
builder.append(SpvFactory::branch(merge_block_id));
block_vars.emplace_back(then_id, then_block_id);
}
builder.leave_block();
// Else block (optional)
if (else_case.defined()) {
debug(2) << "Vulkan: Else =>\n"
<< else_case << "\n";
SpvBlock else_block = builder.create_block(else_block_id);
builder.enter_block(else_block);
{
else_case.accept(this);
SpvId else_id = builder.current_id();
builder.append(SpvFactory::branch(merge_block_id));
block_vars.emplace_back(else_id, else_block_id);
}
builder.leave_block();
}
// Merge block
SpvBlock merge_block = builder.create_block(merge_block_id);
builder.enter_block(merge_block);
return block_vars;
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const IfThenElse *op) {
if (!builder.current_function().is_defined()) {
user_error << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const IfThenElse *op): No active function for building!!\n";
}
emit_if_then_else(op->condition, op->then_case, op->else_case);
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Evaluate *op) {
op->value.accept(this);
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Shuffle *op) {
std::cout << " CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Shuffle): "
<< "type=" << op->type << " "
<< "vectors=" << (uint32_t)op->vectors.size() << " "
<< "is_interleave=" << (op->is_interleave() ? "true" : "false") << " "
<< "is_extract_element=" << (op->is_extract_element() ? "true" : "false") << "\n";
// Traverse all the arg vectors
uint32_t arg_idx = 0;
SpvFactory::Operands arg_ids;
arg_ids.reserve(op->vectors.size());
for (const Expr &e : op->vectors) {
debug(2) << " CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Shuffle): Arg[" << arg_idx++ << "] => " << e << "\n";
e.accept(this);
arg_ids.push_back(builder.current_id());
}
if (op->is_interleave()) {
int op_lanes = op->type.lanes();
internal_assert(!arg_ids.empty());
int arg_lanes = op->vectors[0].type().lanes();
std::cout << " vector interleave x" << (uint32_t)op->vectors.size() << " : ";
for (int idx : op->indices) {
std::cout << idx << " ";
}
std::cout << "\n";
if (arg_ids.size() == 1) {
// 1 argument, just do a simple assignment via a cast
SpvId result_id = cast_type(op->type, op->vectors[0].type(), arg_ids[0]);
builder.update_id(result_id);
} else if (arg_ids.size() == 2) {
// 2 arguments, use a composite insert to update even and odd indices
uint32_t even_idx = 0;
uint32_t odd_idx = 1;
SpvFactory::Indices even_indices;
SpvFactory::Indices odd_indices;
for (int i = 0; i < op_lanes; ++i) {
even_indices.push_back(even_idx);
odd_indices.push_back(odd_idx);
even_idx += 2;
odd_idx += 2;
}
SpvId type_id = builder.declare_type(op->type);
SpvId value_id = builder.declare_null_constant(op->type);
SpvId partial_id = builder.reserve_id(SpvResultId);
SpvId result_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::composite_insert(type_id, partial_id, arg_ids[0], value_id, even_indices));
builder.append(SpvFactory::composite_insert(type_id, result_id, arg_ids[1], partial_id, odd_indices));
builder.update_id(result_id);
} else {
// 3+ arguments, shuffle via a vector literal
// selecting the appropriate elements of the vectors
int num_vectors = (int)op->vectors.size();
std::vector<SpvFactory::Components> vector_component_ids(num_vectors);
for (uint32_t i = 0; i < (uint32_t)arg_ids.size(); ++i) {
if (op->vectors[i].type().is_vector()) {
vector_component_ids[i] = split_vector(op->vectors[i].type(), arg_ids[i]);
} else {
vector_component_ids[i] = {arg_ids[i]};
}
}
SpvFactory::Components result_component_ids(op_lanes);
for (int i = 0; i < op_lanes; i++) {
int arg = i % num_vectors;
int arg_idx = i / num_vectors;
internal_assert(arg_idx <= arg_lanes);
result_component_ids[i] = vector_component_ids[arg][arg_idx];
}
SpvId result_id = join_vector(op->type, result_component_ids);
builder.update_id(result_id);
}
} else if (op->is_extract_element()) {
int idx = op->indices[0];
internal_assert(idx >= 0);
internal_assert(idx <= op->vectors[0].type().lanes());
if (op->vectors[0].type().is_vector()) {
SpvFactory::Indices indices = {(uint32_t)idx};
SpvId type_id = builder.declare_type(op->type);
SpvId result_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::composite_extract(type_id, result_id, arg_ids[0], indices));
builder.update_id(result_id);
} else {
SpvId result_id = cast_type(op->type, op->vectors[0].type(), arg_ids[0]);
builder.update_id(result_id);
}
} else if (op->type.is_scalar()) {
// Deduce which vector we need. Apparently it's not required
// that all vectors have identical lanes, so a loop is required.
// Since idx of -1 means "don't care", we'll treat it as 0 to simplify.
SpvId result_id = SpvInvalidId;
int idx = std::max(0, op->indices[0]);
for (size_t vec_idx = 0; vec_idx < op->vectors.size(); vec_idx++) {
const int vec_lanes = op->vectors[vec_idx].type().lanes();
if (idx < vec_lanes) {
if (op->vectors[vec_idx].type().is_vector()) {
SpvFactory::Indices indices = {(uint32_t)idx};
SpvId type_id = builder.declare_type(op->type);
result_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::composite_extract(type_id, result_id, arg_ids[vec_idx], indices));
} else {
result_id = arg_ids[vec_idx];
}
break;
}
idx -= vec_lanes;
}
} else {
// vector shuffle ... not interleaving
int op_lanes = op->type.lanes();
int num_vectors = (int)op->vectors.size();
std::cout << " vector shuffle x" << num_vectors << " : ";
for (int idx : op->indices) {
std::cout << idx << " ";
}
std::cout << "\n";
if (num_vectors == 1) {
// 1 argument, just do a simple assignment via a cast
SpvId result_id = cast_type(op->type, op->vectors[0].type(), arg_ids[0]);
builder.update_id(result_id);
} else if (num_vectors == 2) {
// 2 arguments, use the builtin vector shuffle that takes a pair of vectors
SpvFactory::Indices indices;
indices.reserve(op->indices.size());
indices.insert(indices.end(), op->indices.begin(), op->indices.end());
SpvId type_id = builder.declare_type(op->type);
SpvId result_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::vector_shuffle(type_id, result_id, arg_ids[0], arg_ids[1], indices));
builder.update_id(result_id);
} else {
std::vector<SpvFactory::Components> vector_component_ids(num_vectors);
for (uint32_t i = 0; i < (uint32_t)arg_ids.size(); ++i) {
if (op->vectors[i].type().is_vector()) {
vector_component_ids[i] = split_vector(op->vectors[i].type(), arg_ids[i]);
} else {
vector_component_ids[i] = {arg_ids[i]};
}
}
SpvFactory::Components result_component_ids(op_lanes);
for (int i = 0; i < op_lanes && i < (int)op->indices.size(); i++) {
int idx = op->indices[i];
int arg = idx % num_vectors;
int arg_idx = idx / num_vectors;
internal_assert(arg_idx <= (int)vector_component_ids[arg].size());
result_component_ids[i] = vector_component_ids[arg][arg_idx];
}
SpvId result_id = join_vector(op->type, result_component_ids);
builder.update_id(result_id);
}
}
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const VectorReduce *) {
internal_error << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const VectorReduce *): VectorReduce not implemented for codegen\n";
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Prefetch *) {
internal_error << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Prefetch *): Prefetch not implemented for codegen\n";
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Fork *) {
internal_error << "void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Fork *): Fork not implemented for codegen";
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Acquire *) {
internal_error << "void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Acquire *): Acquire not implemented for codegen";
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Atomic *) {
internal_error << "void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Atomic *): Atomic not implemented for codegen";
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit_unary_op(SpvOp op_code, Type t, const Expr &a) {
SpvId type_id = builder.declare_type(t);
a.accept(this);
SpvId src_a_id = builder.current_id();
SpvId result_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::unary_op(op_code, type_id, result_id, src_a_id));
builder.update_id(result_id);
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit_binary_op(SpvOp op_code, Type t, const Expr &a, const Expr &b) {
SpvId type_id = builder.declare_type(t);
a.accept(this);
SpvId src_a_id = builder.current_id();
b.accept(this);
SpvId src_b_id = builder.current_id();
SpvId result_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::binary_op(op_code, type_id, result_id, src_a_id, src_b_id));
builder.update_id(result_id);
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit_glsl_op(SpvId glsl_op_code, Type type, const std::vector<Expr> &args) {
SpvId type_id = builder.declare_type(type);
SpvFactory::Operands operands;
operands.reserve(args.size());
for (const Expr &e : args) {
e.accept(this);
SpvId arg_value_id = builder.current_id();
if (builder.type_of(arg_value_id) != type_id) {
SpvId casted_value_id = cast_type(type, e.type(), arg_value_id); // all GLSL args must match return type
operands.push_back(casted_value_id);
} else {
operands.push_back(arg_value_id);
}
}
// sanity check the expected number of operands
internal_assert(glsl_operand_count(glsl_op_code) == operands.size());
SpvId inst_set_id = builder.import_glsl_intrinsics();
SpvId result_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::extended(inst_set_id, glsl_op_code, type_id, result_id, operands));
builder.update_id(result_id);
}
SpvFactory::Components CodeGen_Vulkan_Dev::SPIRV_Emitter::split_vector(Type type, SpvId value_id) {
SpvFactory::Components value_components;
SpvId scalar_value_type_id = builder.declare_type(type.with_lanes(1));
for (uint32_t i = 0; i < (uint32_t)type.lanes(); i++) {
SpvFactory::Indices extract_indices = {i};
SpvId value_component_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::composite_extract(scalar_value_type_id, value_component_id, value_id, extract_indices));
value_components.push_back(value_component_id);
}
return value_components;
}
SpvId CodeGen_Vulkan_Dev::SPIRV_Emitter::join_vector(Type type, const SpvFactory::Components &value_components) {
SpvId type_id = builder.declare_type(type);
SpvId result_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::composite_construct(type_id, result_id, value_components));
return result_id;
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::reset() {
kernel_index = 0;
builder.reset();
SymbolScope empty;
symbol_table.swap(empty);
storage_access_map.clear();
descriptor_set_table.clear();
reset_workgroup_size();
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::init_module() {
reset();
if (target.has_feature(Target::VulkanV13)) {
// Encode to SPIR-V v1.2 to allow dynamic dispatching (if needed)
builder.set_version_format(0x00010200);
} else {
// Encode to SPIR-V v1.0 (which is the only format supported by Vulkan v1.0)
builder.set_version_format(0x00010000);
}
// NOTE: Source language is irrelevant. We encode the binary directly
builder.set_source_language(SpvSourceLanguageUnknown);
// TODO: Should we autodetect and/or force 32bit or 64bit?
builder.set_addressing_model(SpvAddressingModelLogical);
// TODO: Should we autodetect the VulkanMemoryModel extension and use that instead?
builder.set_memory_model(SpvMemoryModelGLSL450);
// NOTE: Execution model for Vulkan must be GLCompute which requires Shader support
builder.require_capability(SpvCapabilityShader);
// NOTE: Extensions are handled in finalize
}
namespace {
std::vector<char> encode_header_string(const std::string &str) {
uint32_t padded_word_count = (str.length() / 4) + 1; // add an extra entry to ensure strings are terminated
uint32_t padded_str_length = padded_word_count * 4;
std::vector<char> encoded_string(padded_str_length, '\0');
for (uint32_t c = 0; c < str.length(); c++) {
encoded_string[c] = str[c];
}
return encoded_string;
}
} // namespace
void CodeGen_Vulkan_Dev::SPIRV_Emitter::encode_header(SpvBinary &spirv_header) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::encode_header\n";
// Encode a sidecar for the module that lists the descriptor sets
// corresponding to each entry point contained in the module.
//
// This metadata will be used at runtime to define the shader bindings
// needed for all buffers, constants, shared memory, and workgroup sizes
// that are required for execution.
//
// Like the SPIR-V code module, each entry is one word (1x uint32_t).
// Variable length sections are prefixed with their length (ie number of entries).
//
// [0] Header word count (total length of header)
// [1] Number of descriptor sets
// ... For each descriptor set ...
// ... [0] Length of entry point name (padded to nearest word size)
// ....... [*] Entry point string data (padded with null chars)
// ... [1] Number of uniform buffers for this descriptor set
// ... [2] Number of storage buffers for this descriptor set
// ... [3] Number of specialization constants for this descriptor set
// ....... For each specialization constant ...
// ....... [0] Length of constant name string (padded to nearest word size)
// ........... [*] Constant name string data (padded with null chars)
// ....... [1] Constant id (as used in VkSpecializationMapEntry for binding)
// ....... [2] Size of data type (in bytes)
// ... [4] Number of shared memory allocations for this descriptor set
// ....... For each allocation ...
// ....... [0] Length of variable name string (padded to nearest word size)
// ........... [*] Variable name string data (padded with null chars)
// ....... [1] Constant id to use for overriding array size (zero if it is not bound to a specialization constant)
// ....... [2] Size of data type (in bytes)
// ....... [3] Size of array (ie element count)
// ... [4] Dynamic workgroup dimensions bound to specialization constants
// ....... [0] Constant id to use for local_size_x (zero if it was statically declared and not bound to a specialization constant)
// ....... [1] Constant id to use for local_size_y
// ....... [2] Constant id ot use for local_size_z
//
// NOTE: Halide's Vulkan runtime consumes this header prior to compiling.
//
// Both vk_decode_shader_bindings() and vk_compile_shader_module() will
// need to be updated if the header encoding ever changes!
//
uint32_t index = 0;
spirv_header.push_back(descriptor_set_table.size());
for (const DescriptorSet &ds : descriptor_set_table) {
// encode the entry point name into an array of chars (padded to the next word entry)
std::vector<char> entry_point_name = encode_header_string(ds.entry_point_name);
uint32_t entry_point_name_entries = (uint32_t)(entry_point_name.size() / sizeof(uint32_t));
debug(2) << " [" << index << "] "
<< "uniform_buffer_count=" << ds.uniform_buffer_count << " "
<< "storage_buffer_count=" << ds.storage_buffer_count << " "
<< "entry_point_name_size=" << entry_point_name.size() << " "
<< "entry_point_name: " << (const char *)entry_point_name.data() << "\n";
// [0] Length of entry point name (padded to nearest word size)
spirv_header.push_back(entry_point_name_entries);
// [*] Entry point string data (padded with null chars)
spirv_header.insert(spirv_header.end(), (const uint32_t *)entry_point_name.data(), (const uint32_t *)(entry_point_name.data() + entry_point_name.size()));
// [1] Number of uniform buffers for this descriptor set
spirv_header.push_back(ds.uniform_buffer_count);
// [2] Number of storage buffers for this descriptor set
spirv_header.push_back(ds.storage_buffer_count);
// [3] Number of specialization constants for this descriptor set
spirv_header.push_back((uint32_t)ds.specialization_constants.size());
debug(2) << " specialization_count=" << (uint32_t)ds.specialization_constants.size() << "\n";
// For each specialization constant ...
for (const SpecializationBinding &spec_binding : ds.specialization_constants) {
// encode the constant name into an array of chars (padded to the next word entry)
std::vector<char> constant_name = encode_header_string(spec_binding.constant_name);
uint32_t constant_name_entries = (uint32_t)(constant_name.size() / sizeof(uint32_t));
debug(2) << " [" << spec_binding.constant_id << "] "
<< "constant_name=" << (const char *)constant_name.data() << " "
<< "type_size=" << spec_binding.type_size << "\n";
// [0] Length of constant name string (padded to nearest word size)
spirv_header.push_back(constant_name_entries);
// [*] Constant name string data (padded with null chars)
spirv_header.insert(spirv_header.end(), (const uint32_t *)constant_name.data(), (const uint32_t *)(constant_name.data() + constant_name.size()));
// [1] Constant id (as used in VkSpecializationMapEntry for binding)
spirv_header.push_back(spec_binding.constant_id);
// [2] Size of data type (in bytes)
spirv_header.push_back(spec_binding.type_size);
}
// [4] Number of shared memory allocations for this descriptor set
spirv_header.push_back((uint32_t)ds.shared_memory_usage.size());
debug(2) << " shared_memory_allocations=" << (uint32_t)ds.shared_memory_usage.size() << "\n";
// For each allocation ...
uint32_t shm_index = 0;
for (const SharedMemoryAllocation &shared_mem_alloc : ds.shared_memory_usage) {
// encode the variable name into an array of chars (padded to the next word entry)
std::vector<char> variable_name = encode_header_string(shared_mem_alloc.variable_name);
uint32_t variable_name_entries = (uint32_t)(variable_name.size() / sizeof(uint32_t));
debug(2) << " [" << shm_index++ << "] "
<< "variable_name=" << (const char *)variable_name.data() << " "
<< "constant_id=" << shared_mem_alloc.constant_id << " "
<< "type_size=" << shared_mem_alloc.type_size << " "
<< "array_size=" << shared_mem_alloc.array_size << "\n";
// [0] Length of variable name string (padded to nearest word size)
spirv_header.push_back(variable_name_entries);
// [*] Variable name string data (padded with null chars)
spirv_header.insert(spirv_header.end(), (const uint32_t *)variable_name.data(), (const uint32_t *)(variable_name.data() + variable_name.size()));
// [1] Constant id to use for overriding array size (zero if it is not bound to a specialization constant)
spirv_header.push_back(shared_mem_alloc.constant_id);
// [2] Size of data type (in bytes)
spirv_header.push_back(shared_mem_alloc.type_size);
// [3] Size of array (ie element count)
spirv_header.push_back(shared_mem_alloc.array_size);
}
// [4] Dynamic workgroup dimensions bound to specialization constants
spirv_header.push_back(ds.workgroup_size_binding.local_size_constant_id[0]);
spirv_header.push_back(ds.workgroup_size_binding.local_size_constant_id[1]);
spirv_header.push_back(ds.workgroup_size_binding.local_size_constant_id[2]);
++index;
}
uint32_t header_word_count = spirv_header.size();
spirv_header.insert(spirv_header.begin(), header_word_count + 1);
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::reset_workgroup_size() {
workgroup_size[0] = 0;
workgroup_size[1] = 0;
workgroup_size[2] = 0;
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::find_workgroup_size(const Stmt &s) {
reset_workgroup_size();
FindWorkGroupSize fwgs;
s.accept(&fwgs);
workgroup_size[0] = fwgs.workgroup_size[0];
workgroup_size[1] = fwgs.workgroup_size[1];
workgroup_size[2] = fwgs.workgroup_size[2];
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::declare_workgroup_size(SpvId kernel_func_id) {
if (workgroup_size[0] == 0) {
// workgroup size is dynamic ...
if (!target.has_feature(Target::VulkanV13)) {
user_error << "Vulkan: Dynamic workgroup sizes require Vulkan v1.3+ support! "
<< "Either enable the target feature, or adjust the pipeline's schedule "
<< "to use static workgroup sizes!";
}
// declare the workgroup local size as a specialization constant (which will get overridden at runtime)
Type local_size_type = UInt(32);
uint32_t local_size_x = std::max(workgroup_size[0], (uint32_t)1); // use a minimum of 1 for the default value
uint32_t local_size_y = std::max(workgroup_size[1], (uint32_t)1);
uint32_t local_size_z = std::max(workgroup_size[2], (uint32_t)1);
SpvId local_size_x_id = builder.declare_specialization_constant(local_size_type, &local_size_x);
SpvId local_size_y_id = builder.declare_specialization_constant(local_size_type, &local_size_y);
SpvId local_size_z_id = builder.declare_specialization_constant(local_size_type, &local_size_z);
SpvId local_size_ids[3] = {
local_size_x_id,
local_size_y_id,
local_size_z_id};
const char *local_size_names[3] = {
"__thread_id_x",
"__thread_id_y",
"__thread_id_z"};
debug(1) << "Vulkan: Using dynamic workgroup local size with default of [" << local_size_x << ", " << local_size_y << ", " << local_size_z << "]...\n";
// annotate each local size with a corresponding specialization constant
for (uint32_t dim = 0; dim < 3; dim++) {
SpvId constant_id = (uint32_t)(descriptor_set_table.back().specialization_constants.size() + 1);
SpvBuilder::Literals spec_id = {constant_id};
builder.add_annotation(local_size_ids[dim], SpvDecorationSpecId, spec_id);
builder.add_symbol(local_size_names[dim], local_size_ids[dim], builder.current_module().id());
SpecializationBinding spec_binding = {constant_id, (uint32_t)sizeof(uint32_t), local_size_names[dim]};
descriptor_set_table.back().specialization_constants.push_back(spec_binding);
descriptor_set_table.back().workgroup_size_binding.local_size_constant_id[dim] = constant_id;
}
// Add workgroup size to execution mode
SpvInstruction exec_mode_inst = SpvFactory::exec_mode_local_size_id(kernel_func_id, local_size_x_id, local_size_y_id, local_size_z_id);
builder.current_module().add_execution_mode(exec_mode_inst);
} else {
// workgroup size is static ...
workgroup_size[0] = std::max(workgroup_size[0], (uint32_t)1);
workgroup_size[1] = std::max(workgroup_size[1], (uint32_t)1);
workgroup_size[2] = std::max(workgroup_size[2], (uint32_t)1);
debug(1) << "Vulkan: Using static workgroup local size [" << workgroup_size[0] << ", " << workgroup_size[1] << ", " << workgroup_size[2] << "]...\n";
// Add workgroup size to execution mode
SpvInstruction exec_mode_inst = SpvFactory::exec_mode_local_size(kernel_func_id, workgroup_size[0], workgroup_size[1], workgroup_size[2]);
builder.current_module().add_execution_mode(exec_mode_inst);
}
}
namespace {
// Locate all the unique GPU variables used as SIMT intrinsics
class FindIntrinsicsUsed : public IRVisitor {
using IRVisitor::visit;
void visit(const For *op) override {
if (CodeGen_GPU_Dev::is_gpu_var(op->name)) {
auto intrinsic = simt_intrinsic(op->name);
intrinsics_used.insert(intrinsic.first);
}
op->body.accept(this);
}
void visit(const Variable *op) override {
if (CodeGen_GPU_Dev::is_gpu_var(op->name)) {
auto intrinsic = simt_intrinsic(op->name);
intrinsics_used.insert(intrinsic.first);
}
}
public:
std::unordered_set<std::string> intrinsics_used;
FindIntrinsicsUsed() = default;
};
// Map the SPIR-V builtin intrinsic name to its corresponding enum value
SpvBuiltIn map_simt_builtin(const std::string &intrinsic_name) {
if (starts_with(intrinsic_name, "Workgroup")) {
return SpvBuiltInWorkgroupId;
} else if (starts_with(intrinsic_name, "Local")) {
return SpvBuiltInLocalInvocationId;
}
internal_error << "map_simt_builtin called on bad variable name: " << intrinsic_name << "\n";
return SpvBuiltInMax;
}
} // namespace
void CodeGen_Vulkan_Dev::SPIRV_Emitter::declare_entry_point(const Stmt &s, SpvId kernel_func_id) {
// Locate all simt intrinsics
FindIntrinsicsUsed find_intrinsics;
s.accept(&find_intrinsics);
SpvFactory::Variables entry_point_variables;
for (const std::string &intrinsic_name : find_intrinsics.intrinsics_used) {
// The builtins are pointers to vec3
SpvStorageClass storage_class = SpvStorageClassInput;
SpvId intrinsic_type_id = builder.declare_type(Type(Type::UInt, 32, 3));
SpvId intrinsic_ptr_type_id = builder.declare_pointer_type(intrinsic_type_id, storage_class);
const std::string intrinsic_var_name = std::string("k") + std::to_string(kernel_index) + std::string("_") + intrinsic_name;
SpvId intrinsic_var_id = builder.declare_global_variable(intrinsic_var_name, intrinsic_ptr_type_id, storage_class);
SpvId intrinsic_loaded_id = builder.reserve_id();
builder.append(SpvFactory::load(intrinsic_type_id, intrinsic_loaded_id, intrinsic_var_id));
symbol_table.push(intrinsic_var_name, {intrinsic_loaded_id, storage_class});
// Annotate that this is the specific builtin
SpvBuiltIn built_in_kind = map_simt_builtin(intrinsic_name);
SpvBuilder::Literals annotation_literals = {(uint32_t)built_in_kind};
builder.add_annotation(intrinsic_var_id, SpvDecorationBuiltIn, annotation_literals);
// Add the builtin to the interface
entry_point_variables.push_back(intrinsic_var_id);
}
// Add the entry point with the appropriate execution model
// NOTE: exec_model must be GLCompute to work with Vulkan ... Kernel is only supported in OpenCL
builder.add_entry_point(kernel_func_id, SpvExecutionModelGLCompute, entry_point_variables);
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::declare_device_args(const Stmt &s, uint32_t entry_point_index,
const std::string &entry_point_name,
const std::vector<DeviceArgument> &args) {
// Keep track of the descriptor set needed to bind this kernel's inputs / outputs
DescriptorSet descriptor_set;
descriptor_set.entry_point_name = entry_point_name;
// Add required extension support for storage types which are necessary to
// use smaller bit-width types for any halide buffer *or* device argument
// (passed as a runtime array)
for (const auto &arg : args) {
if (arg.type.is_int_or_uint()) {
if (arg.type.bits() == 8) {
builder.require_extension("SPV_KHR_8bit_storage");
} else if (arg.type.bits() == 16) {
builder.require_extension("SPV_KHR_16bit_storage");
}
}
}
// GLSL-style: each input buffer is a runtime array in a buffer struct
// All other params get passed in as a single uniform block
// First, need to count scalar parameters to construct the uniform struct
SpvBuilder::StructMemberTypes param_struct_members;
for (const auto &arg : args) {
if (!arg.is_buffer) {
// Add required access capability for smaller bit-width types used as runtime arrays
if (arg.type.bits() == 8) {
builder.require_capability(SpvCapabilityUniformAndStorageBuffer8BitAccess);
} else if (arg.type.bits() == 16) {
builder.require_capability(SpvCapabilityUniformAndStorageBuffer16BitAccess);
}
SpvId arg_type_id = builder.declare_type(arg.type);
param_struct_members.push_back(arg_type_id);
}
}
// Add a binding for a uniform buffer packed with all scalar args
uint32_t binding_counter = 0;
if (!param_struct_members.empty()) {
const std::string struct_name = std::string("k") + std::to_string(kernel_index) + std::string("_args_struct");
SpvId param_struct_type_id = builder.declare_struct(struct_name, param_struct_members);
// Add a decoration describing the offset for each parameter struct member
uint32_t param_member_index = 0;
uint32_t param_member_offset = 0;
for (const auto &arg : args) {
if (!arg.is_buffer) {
SpvBuilder::Literals param_offset_literals = {param_member_offset};
builder.add_struct_annotation(param_struct_type_id, param_member_index, SpvDecorationOffset, param_offset_literals);
param_member_offset += arg.type.bytes();
param_member_index++;
}
}
// Add a Block decoration for the parameter pack itself
builder.add_annotation(param_struct_type_id, SpvDecorationBlock);
// Add a variable for the parameter pack
const std::string param_pack_var_name = std::string("k") + std::to_string(kernel_index) + std::string("_args_var");
SpvId param_pack_ptr_type_id = builder.declare_pointer_type(param_struct_type_id, SpvStorageClassUniform);
SpvId param_pack_var_id = builder.declare_global_variable(param_pack_var_name, param_pack_ptr_type_id, SpvStorageClassUniform);
// We always pass in the parameter pack as the first binding
SpvBuilder::Literals binding_index = {0};
SpvBuilder::Literals dset_index = {entry_point_index};
builder.add_annotation(param_pack_var_id, SpvDecorationDescriptorSet, dset_index);
builder.add_annotation(param_pack_var_id, SpvDecorationBinding, binding_index);
descriptor_set.uniform_buffer_count++;
binding_counter++;
// Declare all the args with appropriate offsets into the parameter struct
uint32_t scalar_index = 0;
for (const auto &arg : args) {
if (!arg.is_buffer) {
SpvId arg_type_id = builder.declare_type(arg.type);
SpvId access_index_id = builder.declare_constant(UInt(32), &scalar_index);
SpvId pointer_type_id = builder.declare_pointer_type(arg_type_id, SpvStorageClassUniform);
SpvFactory::Indices access_indices = {access_index_id};
SpvId access_chain_id = builder.declare_access_chain(pointer_type_id, param_pack_var_id, access_indices);
scalar_index++;
SpvId param_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::load(arg_type_id, param_id, access_chain_id));
symbol_table.push(arg.name, {param_id, SpvStorageClassUniform});
}
}
}
// Add bindings for all device buffers declared as GLSL-style buffer blocks in uniform storage
for (const auto &arg : args) {
if (arg.is_buffer) {
// Check for dense loads & stores to determine the widest vector
// width we can safely index
CheckAlignedDenseVectorLoadStore check_dense(arg.name);
s.accept(&check_dense);
int lanes = check_dense.are_all_dense ? check_dense.lanes : 1;
// Declare the runtime array (which maps directly to the Halide device buffer)
Type array_element_type = arg.type.with_lanes(lanes);
SpvId array_element_type_id = builder.declare_type(array_element_type);
SpvId runtime_arr_type_id = builder.add_runtime_array(array_element_type_id);
// Annotate the array with its stride
SpvBuilder::Literals array_stride = {(uint32_t)(arg.type.bytes())};
builder.add_annotation(runtime_arr_type_id, SpvDecorationArrayStride, array_stride);
// Wrap the runtime array in a struct (required with SPIR-V buffer block semantics)
SpvBuilder::StructMemberTypes struct_member_types = {runtime_arr_type_id};
const std::string struct_name = std::string("k") + std::to_string(kernel_index) + std::string("_buffer_block") + std::to_string(binding_counter);
SpvId struct_type_id = builder.declare_struct(struct_name, struct_member_types);
// Declare a pointer to the struct as a global variable
SpvStorageClass storage_class = SpvStorageClassUniform;
SpvId ptr_struct_type_id = builder.declare_pointer_type(struct_type_id, storage_class);
const std::string buffer_block_var_name = std::string("k") + std::to_string(kernel_index) + std::string("_") + arg.name;
SpvId buffer_block_var_id = builder.declare_global_variable(buffer_block_var_name, ptr_struct_type_id, storage_class);
// Annotate the struct to indicate it's passed in a GLSL-style buffer block
builder.add_annotation(struct_type_id, SpvDecorationBufferBlock);
// Annotate the offset for the array
SpvBuilder::Literals zero_literal = {uint32_t(0)};
builder.add_struct_annotation(struct_type_id, 0, SpvDecorationOffset, zero_literal);
// Set descriptor set and binding indices
SpvBuilder::Literals dset_index = {entry_point_index};
SpvBuilder::Literals binding_index = {uint32_t(binding_counter++)};
builder.add_annotation(buffer_block_var_id, SpvDecorationDescriptorSet, dset_index);
builder.add_annotation(buffer_block_var_id, SpvDecorationBinding, binding_index);
symbol_table.push(arg.name, {buffer_block_var_id, storage_class});
StorageAccess access;
access.storage_type_id = array_element_type_id;
access.storage_type = array_element_type;
access.storage_class = storage_class;
storage_access_map[buffer_block_var_id] = access;
descriptor_set.storage_buffer_count++;
}
}
// Save the descriptor set (so we can output the binding information as a header to the code module)
descriptor_set_table.push_back(descriptor_set);
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::compile(std::vector<char> &module) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::compile\n";
// First encode the descriptor set bindings for each entry point
// as a sidecar which we will add as a preamble header to the actual
// SPIR-V binary so the runtime can know which descriptor set to use
// for each entry point
SpvBinary spirv_header;
encode_header(spirv_header);
// Finalize the SPIR-V module
builder.finalize();
// Validate the SPIR-V for the target
if (builder.is_capability_required(SpvCapabilityInt8) && !target.has_feature(Target::VulkanInt8)) {
user_error << "Vulkan: Code requires 8-bit integer support (which is not enabled in the target features)! "
<< "Either enable the target feature, or adjust the algorithm to avoid using this data type!";
}
if (builder.is_capability_required(SpvCapabilityInt16) && !target.has_feature(Target::VulkanInt16)) {
user_error << "Vulkan: Code requires 16-bit integer support (which is not enabled in the target features)! "
<< "Either enable the target feature, or adjust the algorithm to avoid using this data type!";
}
if (builder.is_capability_required(SpvCapabilityInt64) && !target.has_feature(Target::VulkanInt64)) {
user_error << "Vulkan: Code requires 64-bit integer support (which is not enabled in the target features)! "
<< "Either enable the target feature, or adjust the algorithm to avoid using this data type!";
}
if (builder.is_capability_required(SpvCapabilityFloat16) && !target.has_feature(Target::VulkanFloat16)) {
user_error << "Vulkan: Code requires 16-bit floating-point support (which is not enabled in the target features)! "
<< "Either enable the target feature, or adjust the algorithm to avoid using this data type!";
}
if (builder.is_capability_required(SpvCapabilityFloat64) && !target.has_feature(Target::VulkanFloat64)) {
user_error << "Vulkan: Code requires 64-bit floating-point support (which is not enabled in the target features)! "
<< "Either enable the target feature, or adjust the algorithm to avoid using this data type!";
}
// Encode the SPIR-V into a compliant binary
SpvBinary spirv_binary;
builder.encode(spirv_binary);
size_t header_bytes = spirv_header.size() * sizeof(uint32_t);
size_t binary_bytes = spirv_binary.size() * sizeof(uint32_t);
debug(2) << " encoding module ("
<< "header_size: " << (uint32_t)(header_bytes) << ", "
<< "binary_size: " << (uint32_t)(binary_bytes) << ")\n";
// Combine the header and binary into the module
module.reserve(header_bytes + binary_bytes);
module.insert(module.end(), (const char *)spirv_header.data(), (const char *)(spirv_header.data() + spirv_header.size()));
module.insert(module.end(), (const char *)spirv_binary.data(), (const char *)(spirv_binary.data() + spirv_binary.size()));
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::add_kernel(const Stmt &s,
const std::string &kernel_name,
const std::vector<DeviceArgument> &args) {
debug(2) << "Adding Vulkan kernel " << kernel_name << "\n";
// Add function definition
// TODO: can we use one of the function control annotations?
// https://github.com/halide/Halide/issues/7533
// Discover the workgroup size
find_workgroup_size(s);
// Update the kernel index for the module
kernel_index++;
// Declare the kernel function
SpvId void_type_id = builder.declare_void_type();
SpvId kernel_func_id = builder.add_function(kernel_name, void_type_id);
SpvFunction kernel_func = builder.lookup_function(kernel_func_id);
uint32_t entry_point_index = builder.current_module().entry_point_count();
builder.enter_function(kernel_func);
// Declare the entry point and input intrinsics for the kernel func
declare_entry_point(s, kernel_func_id);
// Declare all parameters -- scalar args and device buffers
declare_device_args(s, entry_point_index, kernel_name, args);
// Traverse
s.accept(this);
// Insert return statement end delimiter
kernel_func.tail_block().add_instruction(SpvFactory::return_stmt());
// Declare the workgroup size for the kernel
declare_workgroup_size(kernel_func_id);
// Pop scope
for (const auto &arg : args) {
symbol_table.pop(arg.name);
}
builder.leave_block();
builder.leave_function();
storage_access_map.clear();
}
void CodeGen_Vulkan_Dev::SPIRV_Emitter::dump() const {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::dump()\n";
std::cerr << builder.current_module();
}
CodeGen_Vulkan_Dev::CodeGen_Vulkan_Dev(Target t)
: emitter(t) {
// Empty
}
void CodeGen_Vulkan_Dev::init_module() {
debug(2) << "CodeGen_Vulkan_Dev::init_module\n";
emitter.init_module();
}
void CodeGen_Vulkan_Dev::add_kernel(Stmt stmt,
const std::string &name,
const std::vector<DeviceArgument> &args) {
debug(2) << "CodeGen_Vulkan_Dev::add_kernel " << name << "\n";
// We need to scalarize/de-predicate any loads/stores, since Vulkan does not support predication.
stmt = scalarize_predicated_loads_stores(stmt);
debug(2) << "CodeGen_Vulkan_Dev: after removing predication: \n"
<< stmt;
current_kernel_name = name;
emitter.add_kernel(stmt, name, args);
// dump the SPIRV file if requested
if (getenv("HL_SPIRV_DUMP_FILE")) {
dump();
}
}
std::vector<char> CodeGen_Vulkan_Dev::compile_to_src() {
debug(2) << "CodeGen_Vulkan_Dev::compile_to_src\n";
std::vector<char> module;
emitter.compile(module);
return module;
}
std::string CodeGen_Vulkan_Dev::get_current_kernel_name() {
return current_kernel_name;
}
std::string CodeGen_Vulkan_Dev::print_gpu_name(const std::string &name) {
return name;
}
void CodeGen_Vulkan_Dev::dump() {
std::vector<char> module = compile_to_src();
// Print the contents of the compiled SPIR-V module
emitter.dump();
// Skip the header and only output the SPIR-V binary
const uint32_t *decode = (const uint32_t *)(module.data());
uint32_t header_word_count = decode[0];
size_t header_size = header_word_count * sizeof(uint32_t);
const uint32_t *binary_ptr = (decode + header_word_count);
size_t binary_size = (module.size() - header_size);
const char *filename = getenv("HL_SPIRV_DUMP_FILE") ? getenv("HL_SPIRV_DUMP_FILE") : "out.spv";
debug(1) << "Vulkan: Dumping SPIRV module to file: '" << filename << "'\n";
std::ofstream f(filename, std::ios::out | std::ios::binary);
f.write((const char *)(binary_ptr), binary_size);
f.close();
}
} // namespace
std::unique_ptr<CodeGen_GPU_Dev> new_CodeGen_Vulkan_Dev(const Target &target) {
return std::make_unique<CodeGen_Vulkan_Dev>(target);
}
} // namespace Internal
} // namespace Halide
#else // WITH_SPIRV
namespace Halide {
namespace Internal {
std::unique_ptr<CodeGen_GPU_Dev> new_CodeGen_Vulkan_Dev(const Target &target) {
return nullptr;
}
} // namespace Internal
} // namespace Halide
#endif // WITH_SPIRV