// CodeGen_WebGPU_Dev.cpp
#include <cmath>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "CodeGen_GPU_Dev.h"
#include "CodeGen_Internal.h"
#include "CodeGen_WebGPU_Dev.h"
#include "IROperator.h"
namespace Halide {
namespace Internal {
using std::ostringstream;
using std::string;
using std::vector;
namespace {
class CodeGen_WebGPU_Dev : public CodeGen_GPU_Dev {
public:
CodeGen_WebGPU_Dev(const Target &target);
/** Compile a GPU kernel into the module. This may be called many times
* with different kernels, which will all be accumulated into a single
* source module shared by a given Halide pipeline. */
void add_kernel(Stmt stmt,
const string &name,
const vector<DeviceArgument> &args) override;
/** (Re)initialize the GPU kernel module. This is separate from compile,
* since a GPU device module will often have many kernels compiled into it
* for a single pipeline. */
void init_module() override;
vector<char> compile_to_src() override;
string get_current_kernel_name() override;
void dump() override;
string print_gpu_name(const string &name) override;
string api_unique_name() override {
return "webgpu";
}
bool kernel_run_takes_types() const override {
return true;
}
protected:
class CodeGen_WGSL : public CodeGen_GPU_C {
public:
CodeGen_WGSL(std::ostream &s, Target t)
: CodeGen_GPU_C(s, t) {
vector_declaration_style = VectorDeclarationStyle::WGSLSyntax;
}
void add_kernel(const Stmt &stmt,
const string &name,
const vector<DeviceArgument> &args);
protected:
using CodeGen_GPU_C::visit;
std::string print_name(const std::string &) override;
std::string print_type(Type type,
AppendSpaceIfNeeded append_space =
DoNotAppendSpace) override;
std::string print_reinterpret(Type type, const Expr &e) override;
std::string print_extern_call(const Call *op) override;
std::string print_assignment(Type t, const std::string &rhs) override;
std::string print_const(Type t, const std::string &rhs);
std::string print_assignment_or_const(Type t, const std::string &rhs,
bool const_expr);
void visit(const Allocate *op) override;
void visit(const And *op) override;
void visit(const AssertStmt *op) override;
void visit(const Broadcast *op) override;
void visit(const Call *op) override;
void visit(const Cast *) override;
void visit(const Div *op) override;
void visit(const Evaluate *op) override;
void visit(const IntImm *) override;
void visit(const UIntImm *) override;
void visit(const FloatImm *) override;
void visit(const Free *op) override;
void visit(const For *) override;
void visit(const Load *op) override;
void visit(const Min *op) override;
void visit(const Max *op) override;
void visit(const Or *op) override;
void visit(const Ramp *op) override;
void visit(const Select *op) override;
void visit(const Store *op) override;
string kernel_name;
std::unordered_set<string> buffers;
std::unordered_set<string> buffers_with_emulated_accesses;
std::unordered_map<string, const Allocate *> workgroup_allocations;
};
std::ostringstream src_stream;
string cur_kernel_name;
CodeGen_WGSL wgsl;
};
CodeGen_WebGPU_Dev::CodeGen_WebGPU_Dev(const Target &t)
: wgsl(src_stream, t) {
}
void CodeGen_WebGPU_Dev::add_kernel(Stmt s,
const string &name,
const vector<DeviceArgument> &args) {
debug(2) << "CodeGen_WebGPU_Dev::add_kernel " << name << "\n";
// We need to scalarize/de-predicate any loads/stores, since WGSL does not
// support predication.
s = scalarize_predicated_loads_stores(s);
debug(2) << "CodeGen_WebGPU_Dev: after removing predication: \n"
<< s;
cur_kernel_name = name;
wgsl.add_kernel(s, name, args);
}
void CodeGen_WebGPU_Dev::init_module() {
debug(2) << "WebGPU device codegen init_module\n";
// Wipe the internal shader source.
src_stream.str("");
src_stream.clear();
// Write out the Halide math functions.
src_stream
<< "fn float_from_bits(x : u32) -> f32 {return bitcast<f32>(x);}\n"
<< "fn nan_f32() -> f32 {return float_from_bits(0x7fc00000);}\n"
<< "fn neg_inf_f32() -> f32 {return float_from_bits(0xff800000);}\n"
<< "fn inf_f32() -> f32 {return float_from_bits(0x7f800000);}\n"
<< "fn acos_f32(x : f32) -> f32 {return acos(x);}\n"
<< "fn acosh_f32(x : f32) -> f32 {return acosh(x);}\n"
<< "fn asin_f32(x : f32) -> f32 {return asin(x);}\n"
<< "fn asinh_f32(x : f32) -> f32 {return asinh(x);}\n"
<< "fn atan_f32(x : f32) -> f32 {return atan(x);}\n"
<< "fn atan2_f32(y : f32, x : f32) -> f32 {return atan2(y, x);}\n"
<< "fn atanh_f32(x : f32) -> f32 {return atanh(x);}\n"
<< "fn ceil_f32(x : f32) -> f32 {return ceil(x);}\n"
<< "fn cos_f32(x : f32) -> f32 {return cos(x);}\n"
<< "fn cosh_f32(x : f32) -> f32 {return cosh(x);}\n"
<< "fn exp_f32(x : f32) -> f32 {return exp(x);}\n"
<< "fn floor_f32(x : f32) -> f32 {return floor(x);}\n"
<< "fn fast_inverse_f32(x : f32) -> f32 {return 1.0 / x;}\n"
<< "fn fast_inverse_sqrt_f32(x : f32) -> f32 {return inverseSqrt(x);}\n"
<< "fn log_f32(x : f32) -> f32 {return log(x);}\n"
// pow() in WGSL has the same semantics as C if x > 0.
// Otherwise, we need to emulate the behavior.
<< "fn pow_f32(x : f32, y : f32) -> f32 { \n"
<< " if (x > 0.0) { \n"
<< " return pow(x, y); \n"
<< " } else if (y == 0.0) { \n"
<< " return 1.0; \n"
<< " } else if (trunc(y) == y) { \n"
<< " if ((y % 2) == 0) { \n"
<< " return pow(abs(x), y); \n"
<< " } else { \n"
<< " return -pow(abs(x), y); \n"
<< " } \n"
<< " } else { \n"
<< " return nan_f32(); \n"
<< " } \n"
<< "} \n"
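// For example, pow_f32(-2.0, 3.0) == -8.0 (odd integer exponent),
// pow_f32(-2.0, 2.0) == 4.0, and pow_f32(-2.0, 0.5) is NaN.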
<< "fn rint(x : f32) -> f32 {return round(x);}\n"
<< "fn round_f32(x : f32) -> f32 {return round(x);}\n"
<< "fn sin_f32(x : f32) -> f32 {return sin(x);}\n"
<< "fn sinh_f32(x : f32) -> f32 {return sinh(x);}\n"
<< "fn sqrt_f32(x : f32) -> f32 {return sqrt(x);}\n"
<< "fn tan_f32(x : f32) -> f32 {return tan(x);}\n"
<< "fn tanh_f32(x : f32) -> f32 {return tanh(x);}\n"
<< "fn trunc_f32(x : f32) -> f32 {return trunc(x);}\n"
// WGSL doesn't provide these by default, but we can exploit the nature
// of comparison ops to construct them. They are of dubious value, since
// the WGSL spec says that "Implementations may assume that NaNs and
// infinities are not present at runtime", but we provide them anyway to
// prevent outright compilation failure, and as a convenience when
// generating code for an implementation that is known to preserve them.
<< "fn is_nan_f32(x : f32) -> bool {return x != x;}\n"
<< "fn is_inf_f32(x : f32) -> bool {return !is_nan_f32(x) && is_nan_f32(x - x);}\n"
<< "fn is_finite_f32(x : f32) -> bool {return !is_nan_f32(x) && !is_inf_f32(x);}\n";
// Create pipeline-overridable constants for the workgroup size and
// workgroup array size.
src_stream << "\n"
<< "override wgsize_x : u32;\n"
<< "override wgsize_y : u32;\n"
<< "override wgsize_z : u32;\n"
<< "override workgroup_mem_bytes : u32;\n\n";
}
vector<char> CodeGen_WebGPU_Dev::compile_to_src() {
string str = src_stream.str();
debug(1) << "WGSL shader:\n"
<< str << "\n";
vector<char> buffer(str.begin(), str.end());
buffer.push_back(0);
return buffer;
}
string CodeGen_WebGPU_Dev::get_current_kernel_name() {
return cur_kernel_name;
}
void CodeGen_WebGPU_Dev::dump() {
std::cerr << src_stream.str() << "\n";
}
string CodeGen_WebGPU_Dev::print_gpu_name(const string &name) {
return name;
}
string CodeGen_WebGPU_Dev::CodeGen_WGSL::print_name(const string &name) {
string new_name = c_print_name(name);
// The double-underscore prefix is reserved in WGSL.
if (new_name.length() > 1 && new_name[0] == '_' && new_name[1] == '_') {
new_name = "v" + new_name;
}
// Prefix storage buffer and workgroup variable names with the kernel name
// to avoid collisions.
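// e.g. a buffer argument "input" used by kernel "foo" is printed as
// "foo_input", since c_print_name prepends an underscore.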
if (buffers.count(name) || workgroup_allocations.count(name)) {
new_name = kernel_name + new_name;
}
return new_name;
}
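// Print a Halide type as WGSL. WGSL has no 8- or 16-bit integer types, so
// narrow integers are widened: e.g. UInt(16) prints as "u32" and Int(8, 4)
// prints as "vec4<i32>".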
string CodeGen_WebGPU_Dev::CodeGen_WGSL::print_type(Type type,
AppendSpaceIfNeeded space) {
ostringstream oss;
if (type.lanes() != 1) {
switch (type.lanes()) {
case 2:
case 3:
case 4:
oss << "vec" << type.lanes() << "<";
break;
default:
user_error << "Unsupported vector width in WGSL: " << type << "\n";
}
}
if (type.is_float()) {
user_assert(type.bits() == 32) << "WGSL only supports 32-bit floats.\n";
oss << "f32";
} else {
switch (type.bits()) {
case 1:
oss << "bool";
break;
case 8:
case 16:
case 32:
oss << (type.is_uint() ? "u" : "i") << "32";
break;
default:
user_error << "Invalid integer bitwidth for WGSL: " << type << "\n";
break;
}
}
if (type.lanes() != 1) {
oss << ">";
}
if (space == AppendSpace) {
oss << " ";
}
return oss.str();
}
string CodeGen_WebGPU_Dev::CodeGen_WGSL::print_reinterpret(Type type,
const Expr &e) {
ostringstream oss;
oss << "bitcast<" << print_type(type) << ">(" << print_expr(e) << ")";
return oss.str();
}
string CodeGen_WebGPU_Dev::CodeGen_WGSL::print_extern_call(const Call *op) {
internal_assert(!function_takes_user_context(op->name)) << op->name;
vector<string> args(op->args.size());
for (size_t i = 0; i < op->args.size(); i++) {
args[i] = print_expr(op->args[i]);
}
ostringstream rhs;
rhs << op->name << "(" << with_commas(args) << ")";
return rhs.str();
}
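// The emitted module for each kernel has this overall shape (illustrative;
// "foo" and "_width" stand in for real kernel and argument names):
//
//   @group(0) @binding(0)
//   var<storage, read_write> foo_input : array<f32>;
//
//   struct ArgsStruct_foo { _width : i32, }
//   @group(1) @binding(0)
//   var<uniform> Args_foo : ArgsStruct_foo;
//
//   @compute @workgroup_size(wgsize_x, wgsize_y, wgsize_z)
//   fn foo(
//     @builtin(local_invocation_id) local_id : vec3<u32>,
//     @builtin(workgroup_id) group_id : vec3<u32>,
//   ) { ... }
//
//   var<workgroup> foo_shared : array<f32, 256>;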
void CodeGen_WebGPU_Dev::CodeGen_WGSL::add_kernel(
const Stmt &s, const string &name, const vector<DeviceArgument> &args) {
debug(2) << "Adding WGSL shader " << name << "\n";
kernel_name = name;
// Look for buffer accesses that will require emulation via atomics.
class FindBufferAccessesRequiringEmulation : public IRVisitor {
using IRVisitor::visit;
void visit(const Load *op) override {
if (op->type.element_of().bits() < 32) {
needs_atomic_accesses.insert(op->name);
}
IRVisitor::visit(op);
}
void visit(const Store *op) override {
if (op->value.type().element_of().bits() < 32) {
needs_atomic_accesses.insert(op->name);
}
IRVisitor::visit(op);
}
public:
std::unordered_set<string> needs_atomic_accesses;
};
FindBufferAccessesRequiringEmulation fbare;
s.accept(&fbare);
// The name of the variable that contains the non-buffer arguments.
string args_var = "Args_" + name;
std::ostringstream uniforms;
uint32_t next_binding = 0;
for (const DeviceArgument &arg : args) {
if (arg.is_buffer) {
// Emit buffer arguments as read_write storage buffers.
buffers.insert(arg.name);
std::string type_decl;
if (fbare.needs_atomic_accesses.count(arg.name)) {
user_warning
<< "Buffers of small integer types are currently emulated "
<< "using atomics in the WebGPU backend, and accesses to "
<< "them will be slow.\n";
buffers_with_emulated_accesses.insert(arg.name);
type_decl = "atomic<u32>";
} else {
type_decl = print_type(arg.type);
}
stream << "@group(0) @binding(" << next_binding << ")\n"
<< "var<storage, read_write> " << print_name(arg.name)
<< " : array<" << type_decl << ">;\n\n";
Allocation alloc;
alloc.type = arg.type;
allocations.push(arg.name, alloc);
next_binding++;
} else {
// Collect non-buffer arguments into a single uniform buffer.
internal_assert(arg.type.bytes() <= 4)
<< "unimplemented: non-buffer args larger than 4 bytes";
uniforms << " " << print_name(arg.name) << " : ";
if (arg.type == Bool()) {
// The bool type cannot appear in a uniform, so use i32 instead.
uniforms << "i32";
} else {
uniforms << print_type(arg.type);
}
uniforms << ",\n";
}
}
if (!uniforms.str().empty()) {
string struct_name = "ArgsStruct_" + name;
stream << "struct " << struct_name << " {\n"
<< uniforms.str()
<< "}\n";
stream << "@group(1) @binding(0)\n"
<< "var<uniform> "
<< args_var << " : " << struct_name << ";\n\n";
}
// Emit the function prototype.
stream << "@compute @workgroup_size(wgsize_x, wgsize_y, wgsize_z)\n";
stream << "fn " << name << "(\n"
<< " @builtin(local_invocation_id) local_id : vec3<u32>,\n"
<< " @builtin(workgroup_id) group_id : vec3<u32>,\n"
<< ")\n";
open_scope();
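// Phony assignment so that the workgroup_mem_bytes override is always
// statically used, even when no workgroup allocation references it.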
stream << get_indent() << "_ = workgroup_mem_bytes;\n";
// Redeclare non-buffer arguments at function scope.
for (const DeviceArgument &arg : args) {
if (!arg.is_buffer) {
stream << get_indent() << "let " << print_name(arg.name)
<< " = " << print_type(arg.type)
<< "(" << args_var << "." << print_name(arg.name) << ");\n";
}
}
// Generate function body.
print(s);
close_scope("shader " + name);
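// WGSL requires var<workgroup> declarations to be at module scope, so
// emit them after the entry point rather than inside it.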
for (const auto &[var_name, alloc] : workgroup_allocations) {
std::stringstream length;
if (is_const(alloc->extents[0])) {
length << alloc->extents[0];
} else {
length << "workgroup_mem_bytes / " << alloc->type.bytes();
}
stream << "var<workgroup> " << print_name(var_name)
<< " : array<" << print_type(alloc->type) << ", "
<< length.str() << ">;\n";
}
workgroup_allocations.clear();
for (const auto &arg : args) {
// Remove buffer arguments from allocation scope and the buffer list.
if (arg.is_buffer) {
buffers.erase(arg.name);
allocations.pop(arg.name);
}
}
}
void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const Allocate *op) {
if (op->memory_type == MemoryType::GPUShared) {
internal_assert(!workgroup_allocations.count(op->name));
workgroup_allocations.insert({op->name, op});
op->body.accept(this);
} else {
open_scope();
debug(2) << "Allocate " << op->name << " on device\n";
// Allocation is not a shared memory allocation, just make a local
// declaration.
// It must have a constant size.
int32_t size = op->constant_allocation_size();
user_assert(size > 0)
<< "Allocation " << op->name << " has a dynamic size. "
<< "Only fixed-size allocations are supported on the gpu. "
<< "Try storing into shared memory instead.";
stream << get_indent() << "var " << print_name(op->name)
<< " : array<" << print_type(op->type) << ", " << size << ">;\n";
Allocation alloc;
alloc.type = op->type;
allocations.push(op->name, alloc);
op->body.accept(this);
// Should have been freed internally
internal_assert(!allocations.contains(op->name));
close_scope("alloc " + print_name(op->name));
}
}
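// Expand a vector And element-wise into a type constructor, e.g. for two
// lanes: vec2<bool>(a[0] & b[0], a[1] & b[1], ); WGSL permits the
// trailing comma. The Or visitor below expands analogously.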
void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const And *op) {
const Expr &a = op->a;
const Expr &b = op->b;
const Type &t = op->type;
if (t.is_scalar()) {
visit_binop(t, a, b, "&");
} else {
internal_assert(a.type() == b.type());
string sa = print_expr(a);
string sb = print_expr(b);
string rhs = print_type(t) + "(";
for (int i = 0; i < t.lanes(); i++) {
const string si = std::to_string(i);
rhs += sa + "[" + si + "] & " + sb + "[" + si + "], ";
}
rhs += ")";
print_assignment(t, rhs);
}
}
void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const AssertStmt *op) {
user_warning << "Ignoring assertion inside WebGPU kernel: " << op->condition << "\n";
}
void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const Broadcast *op) {
const string id_value = print_expr(op->value);
const Type type = op->type.with_lanes(op->lanes);
print_assignment(type, print_type(type) + "(" + id_value + ")");
}
void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const Call *op) {
if (op->is_intrinsic(Call::gpu_thread_barrier)) {
internal_assert(op->args.size() == 1)
<< "gpu_thread_barrier() intrinsic must specify fence type.\n";
const auto *fence_type_ptr = as_const_int(op->args[0]);
internal_assert(fence_type_ptr)
<< "gpu_thread_barrier() parameter is not a constant integer.\n";
auto fence_type = *fence_type_ptr;
stream << get_indent();
if (fence_type & CodeGen_GPU_Dev::MemoryFenceType::Device) {
stream << "storageBarrier();";
}
if (fence_type & CodeGen_GPU_Dev::MemoryFenceType::Shared ||
fence_type == CodeGen_GPU_Dev::MemoryFenceType::None) {
stream << "workgroupBarrier();";
}
stream << "\n";
print_assignment(op->type, "0");
} else if (op->is_intrinsic(Call::if_then_else)) {
internal_assert(op->args.size() == 2 || op->args.size() == 3);
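// Emits (illustrative):
//   var _0 : i32;
//   if (cond) { _0 = true_case; }
//   else { _0 = false_case; }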
string result_id = unique_name('_');
stream << get_indent() << "var " << result_id
<< " : " << print_type(op->args[1].type()) << ";\n";
// TODO: The rest of this is just copied from the C backend, so maybe
// just introduce an overloadable `print_var_decl` instead.
string cond_id = print_expr(op->args[0]);
stream << get_indent() << "if (" << cond_id << ")\n";
open_scope();
string true_case = print_expr(op->args[1]);
stream << get_indent() << result_id << " = " << true_case << ";\n";
close_scope("if " + cond_id);
if (op->args.size() == 3) {
stream << get_indent() << "else\n";
open_scope();
string false_case = print_expr(op->args[2]);
stream << get_indent() << result_id << " = " << false_case << ";\n";
close_scope("if " + cond_id + " else");
}
print_assignment(op->type, result_id);
} else {
CodeGen_GPU_C::visit(op);
}
}
void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const Cast *op) {
print_assignment(op->type,
print_type(op->type) + "(" + print_expr(op->value) + ")");
}
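// Division by a constant power of two lowers to a right shift,
// e.g. x / 16 becomes x >> 4u.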
void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const Div *op) {
int bits;
if (is_const_power_of_two_integer(op->b, &bits)) {
// WGSL requires the RHS of a shift to be unsigned.
Type uint_type = op->a.type().with_code(halide_type_uint);
visit_binop(op->type, op->a, make_const(uint_type, bits), ">>");
} else {
CodeGen_GPU_C::visit(op);
}
}
void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const IntImm *op) {
print_const(op->type, std::to_string(op->value));
}
void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const UIntImm *op) {
if (op->type == Bool()) {
if (op->value == 1) {
id = "true";
} else {
id = "false";
}
} else {
print_const(op->type, std::to_string(op->value) + "u");
}
}
void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const FloatImm *op) {
string rhs;
if (std::isnan(op->value)) {
// Use the nan_f32()/inf_f32() helpers defined in init_module(); a raw
// hex literal here would be parsed as an integer constant.
rhs = "nan_f32()";
} else if (std::isinf(op->value)) {
if (op->value > 0) {
rhs = "inf_f32()";
} else {
rhs = "neg_inf_f32()";
}
} else {
// Write the constant as reinterpreted uint to avoid any bits lost in
// conversion.
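// e.g. 1.5f is emitted as "float_from_bits(1069547520u /* 1.5 */)".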
union {
uint32_t as_uint;
float as_float;
} u;
u.as_float = op->value;
ostringstream oss;
oss << "float_from_bits("
<< u.as_uint << "u /* " << u.as_float << " */)";
rhs = oss.str();
}
print_assignment(op->type, rhs);
}
namespace {
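// Map a Halide GPU loop variable name to the corresponding WGSL builtin,
// e.g. "foo.__thread_id_x" maps to "local_id.x".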
string simt_intrinsic(const string &name) {
if (ends_with(name, ".__thread_id_x")) {
return "local_id.x";
} else if (ends_with(name, ".__thread_id_y")) {
return "local_id.y";
} else if (ends_with(name, ".__thread_id_z")) {
return "local_id.z";
} else if (ends_with(name, ".__thread_id_w")) {
user_error << "WebGPU does not support more than three dimensions.\n";
} else if (ends_with(name, ".__block_id_x")) {
return "group_id.x";
} else if (ends_with(name, ".__block_id_y")) {
return "group_id.y";
} else if (ends_with(name, ".__block_id_z")) {
return "group_id.z";
} else if (ends_with(name, ".__block_id_w")) {
user_error << "WebGPU does not support more than three dimensions.\n";
}
internal_error << "invalid simt_intrinsic name: " << name << "\n";
return "";
}
} // namespace
void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const Evaluate *op) {
if (is_const(op->value)) {
return;
}
print_expr(op->value);
}
void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const Free *op) {
if (workgroup_allocations.count(op->name)) {
return;
} else {
// Should have been freed internally
internal_assert(allocations.contains(op->name));
allocations.pop(op->name);
}
}
void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const For *loop) {
user_assert(loop->for_type != ForType::GPULane)
<< "The WebGPU backend does not support the gpu_lanes() directive.\n";
if (is_gpu_var(loop->name)) {
internal_assert((loop->for_type == ForType::GPUBlock) ||
(loop->for_type == ForType::GPUThread))
<< "kernel loop must be either gpu block or gpu thread\n";
internal_assert(is_const_zero(loop->min));
stream << get_indent()
<< "let " << print_name(loop->name)
<< " = i32(" << simt_intrinsic(loop->name) << ");\n";
loop->body.accept(this);
} else {
user_assert(loop->for_type == ForType::Serial)
<< "Can only use serial loops inside WebGPU shaders\n";
string id_min = print_expr(loop->min);
string id_extent = print_expr(loop->extent);
string id_counter = print_name(loop->name);
stream << get_indent() << "for (var "
<< id_counter << " = " << id_min << "; "
<< id_counter << " < " << id_min << " + " << id_extent << "; "
// TODO: Use increment statement when supported by Chromium.
<< id_counter << " = " << id_counter << " + 1)\n";
open_scope();
loop->body.accept(this);
close_scope("for " + print_name(loop->name));
}
}
void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const Load *op) {
user_assert(is_const_one(op->predicate))
<< "Predicated loads are not supported for WebGPU.\n";
Type result_type = op->type.element_of();
// Get the allocation type, which may be different from the result type.
Type alloc_type = result_type;
if (allocations.contains(op->name)) {
alloc_type = allocations.get(op->name).type;
} else if (workgroup_allocations.count(op->name)) {
alloc_type = workgroup_allocations.at(op->name)->type;
}
const int bits = result_type.bits();
const string name = print_name(op->name);
const string bits_str = std::to_string(bits);
const string elements = std::to_string(32 / bits);
// Cast a loaded value to the result type if necessary.
auto cast_if_needed = [&](const string &value) {
if (result_type != alloc_type) {
return print_type(result_type) + "(" + value + ")";
} else {
return value;
}
};
// Load an 8- or 16-bit value from an array<atomic<u32>>.
auto emulate_narrow_load = [&](const string &idx) {
internal_assert(bits == 8 || bits == 16);
internal_assert(!op->type.is_float());
// Generated code (16-bit):
// (atomicLoad(&in.data[i/2]) >> u32((i%2)*16)) & 0xFFFFu;
string load;
load = "atomicLoad(&" + name + "[" + idx + " / " + elements + "])";
load += " >> u32((" + idx + " % " + elements + ") * " + bits_str + ")";
load = "(" + load + ") & " + std::to_string((1 << bits) - 1) + "u";
if (op->type.is_int()) {
// Convert to i32 and sign-extend.
const string shift = std::to_string(32 - bits);
load = "i32((" + load + ") << " + shift + "u) >> " + shift + "u";
}
return load;
};
// TODO: Use the cache to avoid re-loading the same value multiple times.
const string idx = print_expr(op->index);
if (op->type.is_scalar()) {
string rhs;
if (buffers_with_emulated_accesses.count(op->name)) {
if (bits == 32) {
rhs = "bitcast<" + print_type(result_type) +
">(atomicLoad(&" + name + "[" + idx + "]))";
} else {
rhs = emulate_narrow_load(idx);
}
} else {
rhs = name + "[" + idx + "]";
}
print_assignment(op->type, cast_if_needed(rhs));
return;
} else if (op->type.is_vector()) {
id = "_" + unique_name('V');
// TODO: Could be smarter about this for a dense ramp.
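// Emits (illustrative, 4 lanes):
//   var _V0 : vec4<f32>;
//   _V0[0] = buf[idx[0]];
//   ...
//   _V0[3] = buf[idx[3]];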
stream << get_indent()
<< "var " << id << " : " << print_type(op->type) << ";\n";
for (int i = 0; i < op->type.lanes(); ++i) {
stream << get_indent() << id << "[" << i << "] = ";
const string idx_i = idx + "[" + std::to_string(i) + "]";
string rhs;
if (buffers_with_emulated_accesses.count(op->name)) {
if (bits == 32) {
rhs = "bitcast<" + print_type(result_type) +
">(atomicLoad(&" + name + "[" + idx_i + "]))";
} else {
rhs = emulate_narrow_load(idx_i);
}
} else {
rhs = name + "[" + idx_i + "]";
}
stream << cast_if_needed(rhs) << ";\n";
}
return;
}
internal_error << "unhandled type of load for WGSL\n";
}
void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const Max *op) {
print_expr(Call::make(op->type, "max", {op->a, op->b}, Call::Extern));
}
void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const Min *op) {
print_expr(Call::make(op->type, "min", {op->a, op->b}, Call::Extern));
}
void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const Or *op) {
const Expr &a = op->a;
const Expr &b = op->b;
const Type &t = op->type;
if (t.is_scalar()) {
visit_binop(t, a, b, "|");
} else {
internal_assert(a.type() == b.type());
string sa = print_expr(a);
string sb = print_expr(b);
string rhs = print_type(t) + "(";
for (int i = 0; i < t.lanes(); i++) {
const string si = std::to_string(i);
rhs += sa + "[" + si + "] | " + sb + "[" + si + "], ";
}
rhs += ")";
print_assignment(t, rhs);
}
}
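// A Ramp prints as base + stride * vecN(0, 1, ..., N-1), e.g. for four
// lanes: _0 + _1 * vec4<i32>(0, 1, 2, 3).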
void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const Ramp *op) {
string id_base = print_expr(op->base);
string id_stride = print_expr(op->stride);
ostringstream rhs;
rhs << id_base << " + " << id_stride << " * "
<< print_type(op->type.with_lanes(op->lanes)) << "(0";
// Lane 0 was already written above; emit the remaining lanes.
for (int i = 1; i < op->lanes; ++i) {
rhs << ", " << i;
}
rhs << ")";
print_assignment(op->type.with_lanes(op->lanes), rhs.str());
}
void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const Select *op) {
string true_val = print_expr(op->true_value);
string false_val = print_expr(op->false_value);
string cond = print_expr(op->condition);
string select = "select(" + false_val + ", " + true_val + ", " + cond + ")";
print_assignment(op->type, select);
}
void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const Store *op) {
user_assert(is_const_one(op->predicate))
<< "Predicated stores are not supported for WebGPU.\n";
Type value_type = op->value.type().element_of();
// Get the allocation type, which may be different from the value type.
Type alloc_type = value_type;
if (allocations.contains(op->name)) {
alloc_type = allocations.get(op->name).type;
} else if (workgroup_allocations.count(op->name)) {
alloc_type = workgroup_allocations.at(op->name)->type;
}
// Cast a value to the allocation type if necessary.
auto cast_if_needed = [&](const string &value) {
if (alloc_type != value_type) {
return print_type(alloc_type) + "(" + value + ")";
} else {
return value;
}
};
const int bits = value_type.bits();
const string name = print_name(op->name);
const string bits_str = std::to_string(bits);
const string elements = std::to_string(32 / bits);
// Store an 8- or 16-bit value to an array<atomic<u32>>.
auto emulate_narrow_store = [&](const string &idx, const string &value) {
internal_assert(bits == 8 || bits == 16);
internal_assert(!op->value.type().is_float());
// Generated code (16-bit):
// let shift = u32(i % 2) * 16u;
// var old = atomicLoad(&out[i / 2u]);
// while (true) {
// let mask = ((old >> shift) ^ bitcast<u32>(value)) & 0xFFFFu;
// let newval = old ^ (mask << shift);
// let result = atomicCompareExchangeWeak(&out[i / 2u], old, newval);
// if (result.exchanged) {
// break;
// }
// old = result.old_value;
// }
const string shift = "_" + unique_name('S');
const string old = "_" + unique_name('O');
stream << get_indent() << "let " << shift << " = u32(" << idx << " % "
<< elements << ") * " << bits_str << "u;\n";
stream << get_indent() << "var " << old << " = atomicLoad(&"
<< name << "[" << idx << " / " << elements << "]);\n";
stream << get_indent() << "for (;;) {\n";
stream << get_indent() << " let mask = ((" << old << " >> "
<< shift << ") ^ bitcast<u32>(" << value << ")) & "
<< std::to_string((1 << bits) - 1) << "u;\n";
stream << get_indent() << " let newval = " << old << " ^ (mask << "
<< shift << ");\n";
stream << get_indent() << " let result = atomicCompareExchangeWeak(&"
<< name << "[" << idx << " / " << elements << "], "
<< old << ", newval);\n";
stream << get_indent() << " if (result.exchanged) { break; }\n";
stream << get_indent() << " " << old << " = result.old_value;\n";
stream << get_indent() << "}\n";
};
const string idx = print_expr(op->index);
const string rhs = print_expr(op->value);
if (op->value.type().is_scalar()) {
if (buffers_with_emulated_accesses.count(op->name)) {
if (bits == 32) {
stream << get_indent() << "atomicStore(&"
<< name << "[" << idx << "], "
<< "bitcast<u32>(" << rhs << "));\n";
} else {
emulate_narrow_store(idx, rhs);
}
} else {
stream << get_indent() << name << "[" << idx << "] = ";
stream << cast_if_needed(rhs) << ";\n";
}
} else if (op->value.type().is_vector()) {
// TODO: Could be smarter about this for a dense ramp.
for (int i = 0; i < op->value.type().lanes(); ++i) {
const string idx_i = idx + "[" + std::to_string(i) + "]";
string value_i = rhs + "[" + std::to_string(i) + "]";
if (buffers_with_emulated_accesses.count(op->name)) {
if (bits == 32) {
stream << get_indent() << "atomicStore(&"
<< name << "[" << idx_i << "], "
<< "bitcast<u32>(" << value_i << "));\n";
} else {
emulate_narrow_store(idx_i, value_i);
}
} else {
stream << get_indent() << name << "[" << idx_i << "] = ";
stream << cast_if_needed(value_i) << ";\n";
}
}
}
// Clear the value cache on stores so that later loads do not reuse
// values cached before the store.
cache.clear();
}
string CodeGen_WebGPU_Dev::CodeGen_WGSL::print_assignment(
Type t, const std::string &rhs) {
return print_assignment_or_const(t, rhs, false);
}
string CodeGen_WebGPU_Dev::CodeGen_WGSL::print_const(
Type t, const std::string &rhs) {
return print_assignment_or_const(t, rhs, true);
}
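// Emit "let <id> : <type> = <rhs>;" (or "const" for constant expressions),
// memoizing the rhs so that repeated subexpressions reuse the same id.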
string CodeGen_WebGPU_Dev::CodeGen_WGSL::print_assignment_or_const(
Type t, const std::string &rhs, bool const_expr) {
auto cached = cache.find(rhs);
if (cached == cache.end()) {
id = unique_name('_');
stream << get_indent() << (const_expr ? "const" : "let") << " " << id
<< " : " << print_type(t) << " = " << rhs << ";\n";
cache[rhs] = id;
} else {
id = cached->second;
}
return id;
}
} // namespace
std::unique_ptr<CodeGen_GPU_Dev> new_CodeGen_WebGPU_Dev(const Target &target) {
return std::make_unique<CodeGen_WebGPU_Dev>(target);
}
} // namespace Internal
} // namespace Halide