https://github.com/halide/Halide
Raw File
Tip revision: 0174bb20116efa28a1b0e46efe4da7f195c0fe0d authored by Aelphy on 28 February 2024, 14:30:09 UTC
[xtensa] Added gather_load for float16
Tip revision: 0174bb2
webgpu.cpp
#include "HalideRuntimeWebGPU.h"
#include "device_buffer_utils.h"
#include "device_interface.h"
#include "gpu_context_common.h"
#include "printer.h"
#include "runtime_atomics.h"
#include "scoped_spin_lock.h"

#include "mini_webgpu.h"

#ifndef HALIDE_RUNTIME_WEBGPU_NATIVE_API
#error "HALIDE_RUNTIME_WEBGPU_NATIVE_API must be defined"
#endif

namespace Halide {
namespace Runtime {
namespace Internal {
namespace WebGPU {

// The device interface for the WebGPU backend. Only declared here; buffers
// allocated by this module point their device_interface at it.
extern WEAK halide_device_interface_t webgpu_device_interface;

// Forward declaration: called by halide_webgpu_acquire_context to lazily
// populate the global instance/adapter/device/staging buffer below.
WEAK int create_webgpu_context(void *user_context);

// A WebGPU instance/adapter/device defined in this module with weak linkage.
WEAK WGPUInstance global_instance = nullptr;
WEAK WGPUAdapter global_adapter = nullptr;
WEAK WGPUDevice global_device = nullptr;
// Lock to synchronize access to the global WebGPU context.
// Taken in halide_webgpu_acquire_context and cleared in
// halide_webgpu_release_context.
volatile ScopedSpinLock::AtomicFlag WEAK context_lock = 0;

// A staging buffer used for host<->device copies.
WEAK WGPUBuffer global_staging_buffer = nullptr;

// A flag to signify that the WebGPU device was lost.
// Checked by halide_webgpu_acquire_context and cleared by
// request_device_callback when a new device is obtained.
bool device_was_lost = false;

}  // namespace WebGPU
}  // namespace Internal
}  // namespace Runtime
}  // namespace Halide

using namespace Halide::Runtime::Internal;
using namespace Halide::Runtime::Internal::WebGPU;

extern "C" {
// TODO: Remove all of this when wgpuInstanceProcessEvents() is supported.
// See https://github.com/halide/Halide/issues/7248
//
// Both configurations end up providing wgpuDeviceTick(), which the rest of
// this file calls inside its wait loops for asynchronous callbacks.
#if HALIDE_RUNTIME_WEBGPU_NATIVE_API
// Defined by Dawn, and used to yield execution to asynchronous commands.
void wgpuDeviceTick(WGPUDevice);
// From <unistd.h>, used to sleep between polls while waiting for device
// initialization in create_webgpu_context.
int usleep(uint32_t);
#else
// Defined by Emscripten, and used to yield execution to asynchronous Javascript
// work in combination with Emscripten's "Asyncify" mechanism.
void emscripten_sleep(unsigned int ms);
// Wrap emscripten_sleep in wgpuDeviceTick() to unify usage with Dawn.
// The device argument is ignored.
void wgpuDeviceTick(WGPUDevice) {
    emscripten_sleep(1);
}
#endif

// The default implementation of halide_webgpu_acquire_context uses the global
// pointers above, and serializes access with a spin lock.
// Overriding implementations of acquire/release must implement the following
// behavior:
// - halide_webgpu_acquire_context should always store a valid
//   instance/adapter/device in instance_ret/adapter_ret/device_ret, or return
//   an error code.
// - A call to halide_webgpu_acquire_context is followed by a matching call to
//   halide_webgpu_release_context. halide_webgpu_acquire_context should block
//   while a previous call (if any) has not yet been released via
//   halide_webgpu_release_context.
//
// Parameters of this default implementation:
// - instance_ret/adapter_ret/device_ret/staging_buffer_ret receive the global
//   handles on success and are left untouched when an error is returned.
// - create: when true and no global device exists yet, the global context is
//   created via create_webgpu_context before returning.
// Returns halide_error_code_success, the status of create_webgpu_context if
// creation failed, or halide_error_code_generic_error if device_was_lost is
// set.
WEAK int halide_webgpu_acquire_context(void *user_context,
                                       WGPUInstance *instance_ret,
                                       WGPUAdapter *adapter_ret,
                                       WGPUDevice *device_ret,
                                       WGPUBuffer *staging_buffer_ret,
                                       bool create = true) {
    debug(user_context)
        << "WGPU: halide_webgpu_acquire_context (user_context: " << user_context
        << ")\n";
    halide_abort_if_false(user_context, &context_lock != nullptr);
    // Spin until the flag was previously clear, i.e. until we own the lock.
    while (__atomic_test_and_set(&context_lock, __ATOMIC_ACQUIRE)) {
    }

    if (create && (global_device == nullptr)) {
        int status = create_webgpu_context(user_context);
        if (status != halide_error_code_success) {
            // This failure path drops the lock before returning.
            __atomic_clear(&context_lock, __ATOMIC_RELEASE);
            return status;
        }
    }

    if (device_was_lost) {
        // NOTE(review): unlike the creation-failure path above, this path
        // returns with context_lock still held, so the caller has to call
        // halide_webgpu_release_context even though acquire failed. Confirm
        // which of the two behaviors is the intended contract.
        return halide_error_code_generic_error;
    }

    *instance_ret = global_instance;
    *adapter_ret = global_adapter;
    *device_ret = global_device;
    *staging_buffer_ret = global_staging_buffer;

    return halide_error_code_success;
}

// Default release: clears the spin lock taken by
// halide_webgpu_acquire_context. user_context is unused and the call always
// reports success.
WEAK int halide_webgpu_release_context(void *user_context) {
    __atomic_clear(&context_lock, __ATOMIC_RELEASE);
    return halide_error_code_success;
}

}  // extern "C" linkage

namespace Halide {
namespace Runtime {
namespace Internal {
namespace WebGPU {

// Helper object to acquire and release the WebGPU context.
//
// The constructor calls halide_webgpu_acquire_context and, on success, fetches
// the device queue. Callers must check `error_code` before using any of the
// handles. The destructor releases the queue (if one was obtained) and then
// calls halide_webgpu_release_context.
class WgpuContext {
    void *user_context;

public:
    WGPUInstance instance = nullptr;
    WGPUAdapter adapter = nullptr;
    WGPUDevice device = nullptr;
    WGPUQueue queue = nullptr;

    // A staging buffer used for host<->device copies.
    WGPUBuffer staging_buffer = nullptr;

    // Result of halide_webgpu_acquire_context; non-zero means the handles
    // above must not be used.
    int error_code = 0;

    ALWAYS_INLINE explicit WgpuContext(void *user_context)
        : user_context(user_context) {
        error_code = halide_webgpu_acquire_context(
            user_context, &instance, &adapter, &device, &staging_buffer);
        if (error_code == halide_error_code_success) {
            queue = wgpuDeviceGetQueue(device);
        }
    }

    // The destructor releases the queue and the context; a copy would do both
    // a second time, so copying is disallowed.
    WgpuContext(const WgpuContext &) = delete;
    WgpuContext &operator=(const WgpuContext &) = delete;

    ALWAYS_INLINE ~WgpuContext() {
        if (queue) {
            wgpuQueueRelease(queue);
        }
        (void)halide_webgpu_release_context(user_context);  // ignore errors
    }
};

// Helper class for handling asynchronous errors for a set of WebGPU API calls
// within a particular scope.
//
// Construction pushes a validation and an out-of-memory error scope onto the
// device. wait() pops both scopes and blocks (ticking the device) until both
// error callbacks have fired. If wait() was never called, the destructor calls
// it and discards the result.
class ErrorScope {
public:
    ALWAYS_INLINE ErrorScope(void *user_context, WGPUDevice device)
        : user_context(user_context), device(device) {
        // Capture validation and OOM errors.
        wgpuDevicePushErrorScope(device, WGPUErrorFilter_Validation);
        wgpuDevicePushErrorScope(device, WGPUErrorFilter_OutOfMemory);
        callbacks_remaining = 2;
    }

    // `this` is passed to the error callbacks as userdata and the destructor
    // pops the pushed scopes, so a copy would pop scopes it never pushed.
    ErrorScope(const ErrorScope &) = delete;
    ErrorScope &operator=(const ErrorScope &) = delete;

    ALWAYS_INLINE ~ErrorScope() {
        if (callbacks_remaining > 0) {
            // Pop the error scopes to flush any pending errors.
            wait();
        }
    }

    // Wait for all error callbacks in this scope to fire.
    // Returns the error code (or success). Calling it a second time, when no
    // callbacks are outstanding, reports halide_error_code_internal_error.
    halide_error_code_t wait() {
        using namespace Halide::Runtime::Internal::Synchronization;

        if (callbacks_remaining == 0) {
            error(user_context) << "no outstanding error scopes\n";
            return halide_error_code_internal_error;
        }

        error_code = halide_error_code_success;
        // One pop per scope pushed in the constructor.
        wgpuDevicePopErrorScope(device, error_callback, this);
        wgpuDevicePopErrorScope(device, error_callback, this);

        // Wait for the error callbacks to fire.
        // fetch_or with 0 is used purely as an atomic read of the counter.
        while (atomic_fetch_or_sequentially_consistent(&callbacks_remaining, 0) > 0) {
            wgpuDeviceTick(device);
        }

        return error_code;
    }

private:
    void *user_context;
    WGPUDevice device;

    // The error code reported by the callback functions.
    volatile halide_error_code_t error_code;

    // Used to track outstanding error callbacks.
    volatile int callbacks_remaining = 0;

    // The error callback function.
    // Logs any errors, and decrements the remaining callback count.
    // `userdata` is the ErrorScope that popped the scope.
    static void error_callback(WGPUErrorType type,
                               char const *message,
                               void *userdata) {
        using namespace Halide::Runtime::Internal::Synchronization;

        ErrorScope *context = (ErrorScope *)userdata;
        switch (type) {
        case WGPUErrorType_NoError:
            // Do not overwrite the error_code to avoid masking earlier errors.
            break;
        case WGPUErrorType_Validation:
            error(context->user_context) << "WGPU: validation error: "
                                         << message << "\n";
            context->error_code = halide_error_code_generic_error;
            break;
        case WGPUErrorType_OutOfMemory:
            error(context->user_context) << "WGPU: out-of-memory error: "
                                         << message << "\n";
            context->error_code = halide_error_code_out_of_memory;
            break;
        default:
            error(context->user_context) << "WGPU: unknown error (" << type
                                         << "): " << message << "\n";
            context->error_code = halide_error_code_generic_error;
        }

        // Signal wait() that this callback has completed.
        atomic_fetch_add_sequentially_consistent(&context->callbacks_remaining, -1);
    }
};

// WgpuBufferHandle represents a device buffer with an offset.
// A heap-allocated handle is what this module stores in
// halide_buffer_t::device (see halide_webgpu_device_malloc and
// webgpu_device_crop_from_offset).
struct WgpuBufferHandle {
    // Byte offset into `buffer`. Set to 0 by device_malloc; crops and slices
    // add their byte offset to the source handle's offset.
    uint64_t offset;
    // The underlying WebGPU buffer object.
    WGPUBuffer buffer;
};

// A cache for compiled WGSL shader modules.
// Entries for a device are dropped in halide_webgpu_device_release.
WEAK Halide::Internal::GPUCompilationCache<WGPUDevice, WGPUShaderModule>
    shader_cache;

namespace {

// Status of asynchronous context initialization. Set by the adapter/device
// request callbacks on failure and polled by create_webgpu_context.
halide_error_code_t init_error_code = halide_error_code_success;

// Installed as the device-lost callback on the device descriptor.
// Reports the loss through error(), except when the device was lost because
// it was destroyed. `user_context` is the userdata pointer handed to the
// callback and is forwarded to error().
// NOTE(review): nothing here sets device_was_lost; confirm whether the flag
// is meant to be raised from this callback.
void device_lost_callback(WGPUDeviceLostReason reason,
                          char const *message,
                          void *user_context) {
    // Apparently a destroyed device should not be treated as a fatal error.
    if (reason != WGPUDeviceLostReason_Destroyed) {
        error(user_context) << "WGPU device lost (" << reason << "): "
                            << message << "\n";
    }
}

// Completion callback for wgpuAdapterRequestDevice.
// On success publishes the new device in global_device and clears
// device_was_lost; otherwise reports the failure and records it in
// init_error_code so create_webgpu_context stops waiting.
void request_device_callback(WGPURequestDeviceStatus status,
                             WGPUDevice device,
                             char const *message,
                             void *user_context) {
    if (status == WGPURequestDeviceStatus_Success) {
        device_was_lost = false;
        global_device = device;
        return;
    }

    error(user_context) << "wgpuAdapterRequestDevice failed ("
                        << status << "): " << message << "\n";
    init_error_code = halide_error_code_generic_error;
}

// Completion callback for wgpuInstanceRequestAdapter.
// On failure records a generic error in init_error_code. On success stores
// the adapter in global_adapter and issues the follow-up device request,
// whose result arrives in request_device_callback.
void request_adapter_callback(WGPURequestAdapterStatus status,
                              WGPUAdapter adapter,
                              char const *message,
                              void *user_context) {
    const bool adapter_ok = (status == WGPURequestAdapterStatus_Success);
    if (!adapter_ok) {
        debug(user_context) << "wgpuInstanceRequestAdapter failed: ("
                            << status << "): " << message << "\n";
        init_error_code = halide_error_code_generic_error;
        return;
    }
    global_adapter = adapter;

    // Use the defaults for most limits: every byte of the limits struct is
    // filled with 0xFF before selected fields are overridden below.
    WGPURequiredLimits required_limits{};
    required_limits.nextInChain = nullptr;
    memset(&required_limits.limits, 0xFF, sizeof(WGPULimits));

    // TODO: Enable for Emscripten when wgpuAdapterGetLimits is supported.
    // See https://github.com/halide/Halide/issues/7248
#if HALIDE_RUNTIME_WEBGPU_NATIVE_API
    WGPUSupportedLimits adapter_limits{};
    adapter_limits.nextInChain = nullptr;
    if (wgpuAdapterGetLimits(adapter, &adapter_limits)) {
        // Raise the limits on buffer size and workgroup storage size to
        // whatever the adapter reports.
        required_limits.limits.maxBufferSize =
            adapter_limits.limits.maxBufferSize;
        required_limits.limits.maxStorageBufferBindingSize =
            adapter_limits.limits.maxStorageBufferBindingSize;
        required_limits.limits.maxComputeWorkgroupStorageSize =
            adapter_limits.limits.maxComputeWorkgroupStorageSize;
    } else {
        debug(user_context) << "wgpuAdapterGetLimits failed\n";
    }
#endif

    WGPUDeviceDescriptor device_desc{};
    device_desc.nextInChain = nullptr;
    device_desc.label = nullptr;
    device_desc.requiredFeatures = nullptr;
    device_desc.requiredFeatureCount = 0;
    device_desc.requiredLimits = &required_limits;
    device_desc.deviceLostCallback = device_lost_callback;

    wgpuAdapterRequestDevice(adapter, &device_desc, request_device_callback,
                             user_context);
}

// Returns the smallest multiple of 4 that is >= x (modulo unsigned
// wrap-around for values within 3 of the maximum size_t).
size_t round_up_to_multiple_of_4(size_t x) {
    const size_t bumped = x + 3;
    return bumped - (bumped % 4);
}

}  // namespace

// Creates the global WebGPU instance, adapter, device and staging buffer.
//
// The adapter and device are requested asynchronously; this function polls
// until request_device_callback has stored a device or one of the callbacks
// has recorded a failure in init_error_code. It then allocates the staging
// buffer used for device->host copies.
//
// Returns halide_error_code_success, or the error recorded during
// initialization / staging buffer creation.
WEAK int create_webgpu_context(void *user_context) {
    debug(user_context)
        << "WGPU: create_webgpu_context (user_context: " << user_context
        << ")\n";

    // Start every attempt from a clean slate. Without this, a failure left
    // over from an earlier attempt would make the wait loop below exit
    // immediately and report the stale error while the new request is still
    // in flight.
    init_error_code = halide_error_code_success;

    global_instance = wgpuCreateInstance(nullptr);
    debug(user_context)
        << "WGPU: wgpuCreateInstance produces: " << global_instance
        << ")\n";

    debug(user_context)
        << "WGPU: global_instance is: (" << global_instance
        << ")\n";
    wgpuInstanceRequestAdapter(
        global_instance, nullptr, request_adapter_callback, user_context);

    // Wait for device initialization to complete.
    while (!global_device && init_error_code == halide_error_code_success) {
        // TODO: Use wgpuInstanceProcessEvents() when it is supported.
        // See https://github.com/halide/Halide/issues/7248
#if HALIDE_RUNTIME_WEBGPU_NATIVE_API
        usleep(1000);
#else
        emscripten_sleep(10);
#endif
    }
    if (init_error_code != halide_error_code_success) {
        return init_error_code;
    }

    // Create a staging buffer for transfers.
    constexpr int kStagingBufferSize = 4 * 1024 * 1024;
    WGPUBufferDescriptor buffer_desc{};
    buffer_desc.nextInChain = nullptr;
    buffer_desc.label = nullptr;
    buffer_desc.usage = WGPUBufferUsage_CopyDst | WGPUBufferUsage_MapRead;
    buffer_desc.size = kStagingBufferSize;
    buffer_desc.mappedAtCreation = false;

    ErrorScope error_scope(user_context, global_device);
    global_staging_buffer = wgpuDeviceCreateBuffer(global_device, &buffer_desc);

    halide_error_code_t error_code = error_scope.wait();
    if (error_code != halide_error_code_success) {
        // NOTE(review): global_device stays set on this path, so a later
        // acquire will not retry creation and will hand out a null staging
        // buffer. Confirm whether the device should be torn down here.
        global_staging_buffer = nullptr;
        init_error_code = error_code;
    }

    return init_error_code;
}

}  // namespace WebGPU
}  // namespace Internal
}  // namespace Runtime
}  // namespace Halide

using namespace Halide::Runtime::Internal::WebGPU;

extern "C" {

// Allocates a WebGPU storage buffer for `buf` and stores a heap-allocated
// WgpuBufferHandle for it in buf->device.
//
// Does nothing if buf already has a device allocation. The buffer size is
// buf->size_in_bytes() rounded up to a multiple of 4.
//
// Returns halide_error_code_success, the context acquisition error,
// halide_error_code_out_of_memory if the handle cannot be allocated, or the
// error reported by the WebGPU error scope. On failure buf is left unchanged
// and nothing allocated here is leaked.
WEAK int halide_webgpu_device_malloc(void *user_context, halide_buffer_t *buf) {
    debug(user_context)
        << "WGPU: halide_webgpu_device_malloc (user_context: " << user_context
        << ", buf: " << buf << ")\n";

    if (buf->device) {
        return halide_error_code_success;
    }

    WgpuContext context(user_context);
    if (context.error_code) {
        return context.error_code;
    }

    // Allocate the host-side handle first so that a failure here does not
    // require undoing any device work.
    WgpuBufferHandle *device_handle =
        (WgpuBufferHandle *)malloc(sizeof(WgpuBufferHandle));
    if (device_handle == nullptr) {
        return halide_error_code_out_of_memory;
    }

    ErrorScope error_scope(user_context, context.device);

    WGPUBufferDescriptor desc{};
    desc.nextInChain = nullptr;
    desc.label = nullptr;
    desc.usage = WGPUBufferUsage_Storage |
                 WGPUBufferUsage_CopyDst |
                 WGPUBufferUsage_CopySrc;
    desc.size = round_up_to_multiple_of_4(buf->size_in_bytes());
    desc.mappedAtCreation = false;

    device_handle->buffer = wgpuDeviceCreateBuffer(context.device, &desc);
    device_handle->offset = 0;

    int error_code = error_scope.wait();
    if (error_code != halide_error_code_success) {
        // Do not leak the buffer object or the handle on failure.
        if (device_handle->buffer) {
            wgpuBufferRelease(device_handle->buffer);
        }
        free(device_handle);
        return error_code;
    }

    buf->device = (uint64_t)device_handle;
    buf->device_interface = &webgpu_device_interface;
    buf->device_interface->impl->use_module();

    debug(user_context)
        << "      Allocated device buffer " << (void *)buf->device << "\n";

    return halide_error_code_success;
}

// Releases the WebGPU buffer behind buf->device, frees its handle and
// detaches buf from the WebGPU device interface.
// A buffer without a device allocation is a no-op. If the context cannot be
// acquired the error is returned and buf is left untouched.
WEAK int halide_webgpu_device_free(void *user_context, halide_buffer_t *buf) {
    if (!buf->device) {
        return halide_error_code_success;
    }

    WgpuBufferHandle *wrapped = (WgpuBufferHandle *)buf->device;

    debug(user_context)
        << "WGPU: halide_webgpu_device_free (user_context: " << user_context
        << ", buf: " << buf << ") WGPUBuffer: " << wrapped->buffer << "\n";

    WgpuContext context(user_context);
    if (context.error_code != 0) {
        return context.error_code;
    }

    // Drop the device-side object, then the host-side wrapper.
    wgpuBufferRelease(wrapped->buffer);
    free(wrapped);
    buf->device = 0;

    // Balance the use_module() performed when the buffer was allocated.
    buf->device_interface->impl->release_module();
    buf->device_interface = nullptr;

    return halide_error_code_success;
}

// Blocks until all work submitted to the device queue has completed.
//
// Registers a work-done callback on the queue, waits for the error scope, and
// then ticks the device until the callback has fired. The buffer argument is
// unused.
//
// Returns halide_error_code_success, the context acquisition error, the error
// reported by the error scope, or halide_error_code_device_sync_failed if the
// queue reports a non-success work-done status.
WEAK int halide_webgpu_device_sync(void *user_context, halide_buffer_t *) {
    WgpuContext context(user_context);
    if (context.error_code) {
        return context.error_code;
    }

    ErrorScope error_scope(user_context, context.device);

    // Wait for all work on the queue to finish.
    struct WorkDoneResult {
        volatile ScopedSpinLock::AtomicFlag complete = false;
        volatile WGPUQueueWorkDoneStatus status;
    };
    WorkDoneResult result;

    // The flag is set here and cleared by the callback; the loop below spins
    // until it observes the cleared state.
    __atomic_test_and_set(&result.complete, __ATOMIC_RELAXED);
    wgpuQueueOnSubmittedWorkDone(
        context.queue,
        [](WGPUQueueWorkDoneStatus status, void *userdata) {
            WorkDoneResult *result = (WorkDoneResult *)userdata;
            result->status = status;
            __atomic_clear(&result->complete, __ATOMIC_RELEASE);
        },
        &result);

    const int error_code = error_scope.wait();

    // Wait for the callback even when the error scope reported a failure:
    // the callback writes into `result`, which lives on this stack frame, so
    // returning before it fires would leave it with a dangling pointer.
    while (__atomic_test_and_set(&result.complete, __ATOMIC_ACQUIRE)) {
        wgpuDeviceTick(context.device);
    }

    if (error_code != halide_error_code_success) {
        return error_code;
    }

    if (result.status != WGPUQueueWorkDoneStatus_Success) {
        halide_error(user_context, "wgpuQueueOnSubmittedWorkDone failed");
        return halide_error_code_device_sync_failed;
    }
    return halide_error_code_success;
}

// Tears down the WebGPU state associated with the current context.
//
// Drops the cached shader modules for the device and, if the device is the
// one created by this module, releases the staging buffer, device, adapter
// and instance and resets the corresponding globals. The context is acquired
// with create == false, so no context is created just to be released.
//
// Returns the acquire error if the context could not be acquired, otherwise
// the result of halide_webgpu_release_context.
WEAK int halide_webgpu_device_release(void *user_context) {
    debug(user_context)
        << "WGPU: halide_webgpu_device_release (user_context: " << user_context
        << ")\n";

    // The WgpuContext object does not allow the context storage to be modified,
    // so we use halide_acquire_context directly.
    int err;
    WGPUInstance instance = nullptr;
    WGPUAdapter adapter = nullptr;
    WGPUDevice device = nullptr;
    WGPUBuffer staging_buffer = nullptr;
    err = halide_webgpu_acquire_context(user_context,
                                        &instance, &adapter, &device, &staging_buffer, false);
    if (err != halide_error_code_success) {
        // With create == false the default acquire can only fail on its
        // device-lost path, which returns with the context lock still held.
        // Release it here (as WgpuContext's destructor does after a failed
        // acquire) so the next acquire does not spin forever.
        (void)halide_webgpu_release_context(user_context);  // ignore errors
        return err;
    }

    if (device) {
        shader_cache.delete_context(user_context, device,
                                    wgpuShaderModuleRelease);

        // Release the device/adapter/instance/staging_buffer, if we created them.
        if (device == global_device) {
            if (staging_buffer) {
                wgpuBufferRelease(staging_buffer);
                global_staging_buffer = nullptr;
            }

            wgpuDeviceRelease(device);
            global_device = nullptr;

            wgpuAdapterRelease(adapter);
            global_adapter = nullptr;

            wgpuInstanceRelease(instance);
            global_instance = nullptr;
        }
    }

    return halide_webgpu_release_context(user_context);
}

// Allocates both host and device memory for `buf` by delegating to the
// default implementation with the WebGPU device interface.
WEAK int halide_webgpu_device_and_host_malloc(void *user_context,
                                              struct halide_buffer_t *buf) {
    return halide_default_device_and_host_malloc(user_context, buf,
                                                 &webgpu_device_interface);
}

// Frees both host and device memory for `buf` by delegating to the default
// implementation with the WebGPU device interface.
WEAK int halide_webgpu_device_and_host_free(void *user_context,
                                            struct halide_buffer_t *buf) {
    return halide_default_device_and_host_free(user_context, buf,
                                               &webgpu_device_interface);
}

namespace {

// Copy `size` bytes of data from buffer `src` to a host pointer `dst`.
int do_copy_to_host(void *user_context, WgpuContext *context, uint8_t *dst,
                    WGPUBuffer src, int64_t src_offset, int64_t size) {
    // Copy chunks via the staging buffer.
    int64_t staging_buffer_size = wgpuBufferGetSize(context->staging_buffer);
    for (int64_t offset = 0; offset < size; offset += staging_buffer_size) {
        int64_t num_bytes = staging_buffer_size;
        if (offset + num_bytes > size) {
            num_bytes = size - offset;
        }

        // Copy this chunk to the staging buffer.
        WGPUCommandEncoder encoder =
            wgpuDeviceCreateCommandEncoder(context->device, nullptr);
        wgpuCommandEncoderCopyBufferToBuffer(encoder, src, src_offset + offset,
                                             context->staging_buffer,
                                             0, num_bytes);
        WGPUCommandBuffer command_buffer =
            wgpuCommandEncoderFinish(encoder, nullptr);
        wgpuQueueSubmit(context->queue, 1, &command_buffer);

        struct BufferMapResult {
            volatile ScopedSpinLock::AtomicFlag map_complete;
            volatile WGPUBufferMapAsyncStatus map_status;
        };
        BufferMapResult result;

        // Map the staging buffer for reading.
        __atomic_test_and_set(&result.map_complete, __ATOMIC_RELAXED);
        wgpuBufferMapAsync(
            context->staging_buffer, WGPUMapMode_Read, 0, num_bytes,
            [](WGPUBufferMapAsyncStatus status, void *userdata) {
                BufferMapResult *result = (BufferMapResult *)userdata;
                result->map_status = status;
                __atomic_clear(&result->map_complete, __ATOMIC_RELEASE);
            },
            &result);

        while (__atomic_test_and_set(&result.map_complete, __ATOMIC_ACQUIRE)) {
            wgpuDeviceTick(context->device);
        }
        if (result.map_status != WGPUBufferMapAsyncStatus_Success) {
            error(user_context) << "wgpuBufferMapAsync failed: "
                                << result.map_status << "\n";
            return halide_error_code_copy_to_host_failed;
        }

        // Copy the data from the mapped staging buffer to the host allocation.
        const void *src = wgpuBufferGetConstMappedRange(context->staging_buffer,
                                                        0, num_bytes);
        memcpy(dst + offset, src, num_bytes);
        wgpuBufferUnmap(context->staging_buffer);
    }

    return halide_error_code_success;
}

int do_multidimensional_copy(void *user_context, WgpuContext *context,
                             const device_copy &c,
                             int64_t src_idx, int64_t dst_idx,
                             int d, bool from_host, bool to_host) {
    if (d > MAX_COPY_DIMS) {
        error(user_context)
            << "Buffer has too many dimensions to copy to/from GPU\n";
        return halide_error_code_bad_dimensions;
    } else if (d == 0) {
        int err = 0;

        WgpuBufferHandle *src = (WgpuBufferHandle *)(c.src);
        WgpuBufferHandle *dst = (WgpuBufferHandle *)(c.dst);

        debug(user_context) << "    from " << (from_host ? "host" : "device")
                            << " to " << (to_host ? "host" : "device") << ", "
                            << (void *)c.src << " + " << src_idx
                            << " -> " << (void *)c.dst << " + " << dst_idx
                            << ", " << c.chunk_size << " bytes\n";
        uint64_t copy_size = round_up_to_multiple_of_4(c.chunk_size);
        if (!from_host && to_host) {
            err = do_copy_to_host(user_context, context,
                                  (uint8_t *)(c.dst + dst_idx),
                                  src->buffer, src_idx + src->offset,
                                  copy_size);
        } else if (from_host && !to_host) {
            wgpuQueueWriteBuffer(context->queue, dst->buffer,
                                 dst_idx + dst->offset,
                                 (void *)(c.src + src_idx), copy_size);
        } else if (!from_host && !to_host) {
            // Create a command encoder and encode a copy command.
            WGPUCommandEncoder encoder =
                wgpuDeviceCreateCommandEncoder(context->device, nullptr);
            wgpuCommandEncoderCopyBufferToBuffer(encoder,
                                                 src->buffer,
                                                 src_idx + src->offset,
                                                 dst->buffer,
                                                 dst_idx + dst->offset,
                                                 c.chunk_size);

            // Submit the copy command.
            WGPUCommandBuffer cmd = wgpuCommandEncoderFinish(encoder, nullptr);
            wgpuQueueSubmit(context->queue, 1, &cmd);
            wgpuCommandEncoderRelease(encoder);
        } else if ((c.dst + dst_idx) != (c.src + src_idx)) {
            // Could reach here if a user called directly into the
            // WebGPU API for a device->host copy on a source buffer
            // with device_dirty = false.
            halide_debug_assert(user_context, false && "unimplemented");
        }

        return err;
    } else {
        ssize_t src_off = 0, dst_off = 0;
        for (int i = 0; i < (int)c.extent[d - 1]; i++) {
            int err = do_multidimensional_copy(user_context, context, c,
                                               src_idx + src_off,
                                               dst_idx + dst_off,
                                               d - 1, from_host, to_host);
            dst_off += c.dst_stride_bytes[d - 1];
            src_off += c.src_stride_bytes[d - 1];
            if (err) {
                return err;
            }
        }
    }
    return halide_error_code_success;
}

}  // namespace

// Copies `src` into `dst`, where each side is either host memory or a WebGPU
// device allocation.
//
// dst_device_interface selects the destination: nullptr means the host,
// &webgpu_device_interface means the device; anything else aborts. If the
// source data only lives on a different device interface,
// halide_error_code_incompatible_device_interface is returned so that the
// caller can handle the copy.
//
// Returns halide_error_code_success, the context acquisition error, the error
// from the copy itself, or the error reported by the WebGPU error scope.
WEAK int halide_webgpu_buffer_copy(void *user_context,
                                   struct halide_buffer_t *src,
                                   const struct halide_device_interface_t *dst_device_interface,
                                   struct halide_buffer_t *dst) {
    debug(user_context)
        << "WGPU: halide_webgpu_buffer_copy (user_context: " << user_context
        << ", src: " << src << ", dst: " << dst << ")\n";

    // We only handle copies between WebGPU devices or to/from the host.
    halide_abort_if_false(user_context,
                          dst_device_interface == nullptr ||
                              dst_device_interface == &webgpu_device_interface);

    const bool src_is_webgpu =
        (src->device_interface == &webgpu_device_interface);

    if (!src_is_webgpu && (src->device_dirty() || src->host == nullptr)) {
        halide_abort_if_false(user_context,
                              dst_device_interface == &webgpu_device_interface);
        // This is handled at the higher level.
        return halide_error_code_incompatible_device_interface;
    }

    // The host copy is the source when there is no WebGPU allocation to read
    // from, or when the host holds newer data.
    const bool host_is_newer = src->host_dirty() && src->host != nullptr;
    bool from_host = !src_is_webgpu || (src->device == 0) || host_is_newer;
    bool to_host = (dst_device_interface == nullptr);

    halide_abort_if_false(user_context, from_host || src->device);
    halide_abort_if_false(user_context, to_host || dst->device);

    device_copy copy_desc = make_buffer_copy(src, from_host, dst, to_host);

    WgpuContext context(user_context);
    if (context.error_code) {
        return context.error_code;
    }

    ErrorScope error_scope(user_context, context.device);

    int status = do_multidimensional_copy(user_context, &context, copy_desc,
                                          copy_desc.src_begin, 0,
                                          dst->dimensions,
                                          from_host, to_host);
    // Only consult the error scope when the copy itself reported no error;
    // otherwise the scope is flushed by its destructor.
    if (status == halide_error_code_success) {
        status = error_scope.wait();
    }

    return status;
}

// Copies `buf` to its device allocation by calling halide_webgpu_buffer_copy
// with the same buffer as source and destination and the WebGPU interface as
// the destination interface.
WEAK int halide_webgpu_copy_to_device(void *user_context,
                                      halide_buffer_t *buf) {
    return halide_webgpu_buffer_copy(user_context, buf,
                                     &webgpu_device_interface, buf);
}

// Copies `buf` to its host allocation by calling halide_webgpu_buffer_copy
// with the same buffer as source and destination and a null destination
// interface (which selects the host).
WEAK int halide_webgpu_copy_to_host(void *user_context, halide_buffer_t *buf) {
    return halide_webgpu_buffer_copy(user_context, buf,
                                     nullptr, buf);
}

namespace {

// Shared implementation of device_crop and device_slice.
//
// Makes `dst` alias the WGPUBuffer that backs `src`, starting `offset` bytes
// past the start of `src`'s own view of that buffer. A new WgpuBufferHandle
// is allocated for `dst`, and an extra reference is taken on the underlying
// WGPUBuffer; halide_webgpu_device_release_crop() drops that reference and
// frees the handle.
//
// Returns halide_error_code_success on success, the context's error code if
// the WebGPU context could not be acquired, or
// halide_error_code_out_of_memory if the handle could not be allocated. On
// failure `dst` is left untouched and no buffer reference is taken.
WEAK int webgpu_device_crop_from_offset(void *user_context,
                                        const struct halide_buffer_t *src,
                                        int64_t offset,
                                        struct halide_buffer_t *dst) {
    WgpuContext context(user_context);
    if (context.error_code) {
        return context.error_code;
    }

    WgpuBufferHandle *src_handle = (WgpuBufferHandle *)src->device;

    // Allocate first: this is the only step that can fail, so nothing needs
    // to be undone if it does.
    WgpuBufferHandle *dst_handle =
        (WgpuBufferHandle *)malloc(sizeof(WgpuBufferHandle));
    if (dst_handle == nullptr) {
        return halide_error_code_out_of_memory;
    }

    // The crop shares the source's WGPUBuffer, so keep it alive for as long
    // as the crop exists.
    wgpuBufferReference(src_handle->buffer);

    dst_handle->buffer = src_handle->buffer;
    dst_handle->offset = src_handle->offset + offset;

    dst->device_interface = src->device_interface;
    dst->device = (uint64_t)dst_handle;

    return halide_error_code_success;
}

}  // namespace

// Creates a device-side crop of `src` in `dst`.
//
// The byte offset of the crop is computed by calc_device_crop_byte_offset()
// from the two buffers, and the aliasing itself is done by
// webgpu_device_crop_from_offset(), whose error code is returned.
WEAK int halide_webgpu_device_crop(void *user_context,
                                   const struct halide_buffer_t *src,
                                   struct halide_buffer_t *dst) {
    return webgpu_device_crop_from_offset(
        user_context, src, calc_device_crop_byte_offset(src, dst), dst);
}

// Creates a device-side slice of `src` in `dst`, at position `slice_pos`
// along dimension `slice_dim`.
//
// The byte offset of the slice is computed by
// calc_device_slice_byte_offset(), and the aliasing itself is done by
// webgpu_device_crop_from_offset(), whose error code is returned.
WEAK int halide_webgpu_device_slice(void *user_context,
                                    const struct halide_buffer_t *src,
                                    int slice_dim,
                                    int slice_pos,
                                    struct halide_buffer_t *dst) {
    return webgpu_device_crop_from_offset(
        user_context, src,
        calc_device_slice_byte_offset(src, slice_dim, slice_pos),
        dst);
}

// Releases a crop/slice previously attached to `buf`.
//
// Drops the crop's reference on the shared WGPUBuffer, frees the crop's
// WgpuBufferHandle and clears `buf->device`. If the WebGPU context cannot be
// acquired, its error code is returned and `buf` is left unchanged.
WEAK int halide_webgpu_device_release_crop(void *user_context,
                                           struct halide_buffer_t *buf) {
    WgpuBufferHandle *crop_handle =
        reinterpret_cast<WgpuBufferHandle *>(buf->device);

    debug(user_context)
        << "WGPU: halide_webgpu_device_release_crop (user_context: "
        << user_context << ", buf: " << buf << ") WGPUBuffer: "
        << crop_handle->buffer << " offset: " << crop_handle->offset << "\n";

    WgpuContext context(user_context);
    if (context.error_code != halide_error_code_success) {
        return context.error_code;
    }

    // Give back the reference this crop held on the underlying buffer, then
    // dispose of the handle itself.
    wgpuBufferRelease(crop_handle->buffer);
    free(crop_handle);
    buf->device = 0;

    return halide_error_code_success;
}

// Placeholder for wrapping a native handle (`mem`) in `buf`.
//
// Not supported yet: trips halide_debug_assert() and reports
// halide_error_code_unimplemented. None of the arguments other than
// `user_context` are used.
// TODO: Implement this.
// See https://github.com/halide/Halide/issues/7250
WEAK int halide_webgpu_wrap_native(void *user_context,
                                   struct halide_buffer_t *buf,
                                   uint64_t mem) {
    halide_debug_assert(user_context, false && "unimplemented");
    return halide_error_code_unimplemented;
}

// Placeholder for detaching a previously wrapped native handle from `buf`.
//
// Not supported yet: trips halide_debug_assert() and reports
// halide_error_code_unimplemented. `buf` is not touched.
// TODO: Implement this.
// See https://github.com/halide/Halide/issues/7250
WEAK int halide_webgpu_detach_native(void *user_context,
                                     halide_buffer_t *buf) {
    halide_debug_assert(user_context, false && "unimplemented");
    return halide_error_code_unimplemented;
}

// Ensures a compiled shader module for the WGSL program `src` is available in
// the shader cache under `state_ptr`, compiling it on a cache miss.
//
// `src` is handed to WebGPU as the WGSL source string; `size` is only logged.
//
// Returns halide_error_code_success on success, the context's error code if
// the WebGPU context could not be acquired, or
// halide_error_code_generic_error if the cache setup (including shader
// compilation) fails.
WEAK int halide_webgpu_initialize_kernels(void *user_context, void **state_ptr, const char *src, int size) {
    debug(user_context)
        << "WGPU: halide_webgpu_initialize_kernels (user_context: " << user_context
        << ", state_ptr: " << state_ptr
        << ", program: " << (void *)src
        << ", size: " << size << ")\n";

    WgpuContext context(user_context);
    if (context.error_code) {
        return context.error_code;
    }

    // Get the shader module from the cache, compiling it if necessary.
    WGPUShaderModule shader_module = nullptr;
    if (!shader_cache.kernel_state_setup(
            user_context, state_ptr, context.device, shader_module,
            [&]() -> WGPUShaderModule {
                ErrorScope error_scope(user_context, context.device);

                WGPUShaderModuleWGSLDescriptor wgsl_desc{};
                wgsl_desc.chain.next = nullptr;
                wgsl_desc.chain.sType = WGPUSType_ShaderModuleWGSLDescriptor;
                wgsl_desc.code = src;
                WGPUShaderModuleDescriptor desc{};
                desc.nextInChain = (const WGPUChainedStruct *)(&wgsl_desc);
                desc.label = nullptr;
                WGPUShaderModule compiled_module =
                    wgpuDeviceCreateShaderModule(context.device, &desc);

                int error_code = error_scope.wait();
                if (error_code != halide_error_code_success) {
                    // The handle returned by the create call is still owned
                    // by us even when compilation reported an error, so drop
                    // it here rather than leaking it.
                    if (compiled_module != nullptr) {
                        wgpuShaderModuleRelease(compiled_module);
                    }
                    return nullptr;  // from the lambda
                }

                return compiled_module;
            })) {
        return halide_error_code_generic_error;
    }
    halide_abort_if_false(user_context, shader_module != nullptr);

    return halide_error_code_success;
}

// Releases the hold that `state_ptr` has on its entry in the shader cache.
//
// If the WebGPU context cannot be acquired, nothing is released; the
// function has no way to report that failure to its caller.
WEAK void halide_webgpu_finalize_kernels(void *user_context, void *state_ptr) {
    // The closing parenthesis keeps this message balanced and consistent with
    // the other "WGPU: ..." debug messages in this file.
    debug(user_context)
        << "WGPU: halide_webgpu_finalize_kernels (user_context: "
        << user_context << ", state_ptr: " << state_ptr << ")\n";

    WgpuContext context(user_context);
    if (context.error_code == halide_error_code_success) {
        shader_cache.release_hold(user_context, context.device, state_ptr);
    }
}

// Dispatches the compute kernel `entry_name` from the shader module that is
// cached under `state_ptr`.
//
// Parameters:
//   groupsX/Y/Z          - number of workgroups to dispatch in each dimension.
//   threadsX/Y/Z         - workgroup size, passed to the shader through the
//                          "wgsize_x/y/z" pipeline-overridable constants.
//   workgroup_mem_bytes  - passed to the shader through the
//                          "workgroup_mem_bytes" overridable constant.
//   arg_types, args, arg_is_buffer
//                        - parallel arrays describing the kernel arguments.
//                          `args` is terminated by a null entry. Buffer
//                          arguments point at a halide_buffer_t; all other
//                          arguments point at a scalar of at most 32 bits.
//
// Buffer arguments are bound in bind group 0, using the argument's index in
// `args` as the binding number. Scalar arguments are each widened to 32 bits
// and packed, in argument order, into a uniform buffer bound at binding 0 of
// bind group 1.
//
// Returns the context's error code if the WebGPU context could not be
// acquired; otherwise the result of waiting on the error scope that covers
// every WebGPU call made here.
WEAK int halide_webgpu_run(void *user_context,
                           void *state_ptr,
                           const char *entry_name,
                           int groupsX, int groupsY, int groupsZ,
                           int threadsX, int threadsY, int threadsZ,
                           int workgroup_mem_bytes,
                           halide_type_t arg_types[],
                           void *args[],
                           int8_t arg_is_buffer[]) {
    debug(user_context)
        << "WGPU: halide_webgpu_run (user_context: " << user_context << ", "
        << "entry: " << entry_name << ", "
        << "groups: " << groupsX << "x" << groupsY << "x" << groupsZ << ", "
        << "threads: " << threadsX << "x" << threadsY << "x" << threadsZ << ", "
        << "workgroup_mem: " << workgroup_mem_bytes << ")\n";

    WgpuContext context(user_context);
    if (context.error_code) {
        return context.error_code;
    }

    ErrorScope error_scope(user_context, context.device);

    WGPUShaderModule shader_module = nullptr;
    bool found = shader_cache.lookup(context.device, state_ptr, shader_module);
    halide_abort_if_false(user_context, found && shader_module != nullptr);

    // Create the compute pipeline.
    WGPUConstantEntry overrides[4] = {
        {nullptr, "wgsize_x", (double)threadsX},
        {nullptr, "wgsize_y", (double)threadsY},
        {nullptr, "wgsize_z", (double)threadsZ},
        {nullptr, "workgroup_mem_bytes", (double)workgroup_mem_bytes},
    };
    WGPUProgrammableStageDescriptor stage_desc{};
    stage_desc.nextInChain = nullptr;
    stage_desc.module = shader_module;
    stage_desc.entryPoint = entry_name;
    stage_desc.constantCount = 4;
    stage_desc.constants = overrides;

    WGPUComputePipelineDescriptor pipeline_desc{};
    pipeline_desc.nextInChain = nullptr;
    pipeline_desc.label = nullptr;
    pipeline_desc.layout = nullptr;
    pipeline_desc.compute = stage_desc;

    WGPUComputePipeline pipeline =
        wgpuDeviceCreateComputePipeline(context.device, &pipeline_desc);

    // Set up a compute shader dispatch command.
    WGPUCommandEncoder encoder =
        wgpuDeviceCreateCommandEncoder(context.device, nullptr);
    WGPUComputePassEncoder pass =
        wgpuCommandEncoderBeginComputePass(encoder, nullptr);
    wgpuComputePassEncoderSetPipeline(pass, pipeline);

    // Process function arguments: count them, and work out how many are
    // buffers and how large the uniform buffer for the scalars must be.
    uint32_t num_args = 0;
    uint32_t num_buffers = 0;
    uint32_t uniform_size = 0;
    while (args[num_args] != nullptr) {
        if (arg_is_buffer[num_args]) {
            num_buffers++;
        } else {
            uint32_t arg_size = arg_types[num_args].bytes();
            halide_debug_assert(user_context, arg_size <= 4);

            // Round up to 4 bytes.
            arg_size = round_up_to_multiple_of_4(arg_size);

            uniform_size += arg_size;
        }
        num_args++;
    }
    if (num_buffers > 0) {
        // Set up a bind group entry for each buffer argument.
        WGPUBindGroupEntry *bind_group_entries =
            (WGPUBindGroupEntry *)malloc(
                num_buffers * sizeof(WGPUBindGroupEntry));
        for (uint32_t i = 0, b = 0; i < num_args; i++) {
            if (arg_is_buffer[i]) {
                halide_buffer_t *buffer = (halide_buffer_t *)args[i];
                WgpuBufferHandle *handle = (WgpuBufferHandle *)(buffer->device);
                WGPUBindGroupEntry entry{};
                entry.nextInChain = nullptr;
                // The binding number is the argument index, not the index
                // among buffer arguments.
                entry.binding = i;
                entry.buffer = handle->buffer;
                entry.offset = handle->offset;
                entry.size = round_up_to_multiple_of_4(buffer->size_in_bytes());
                entry.sampler = nullptr;
                entry.textureView = nullptr;
                bind_group_entries[b] = entry;
                b++;
            }
        }

        // Create a bind group for the buffer arguments.
        WGPUBindGroupLayout layout =
            wgpuComputePipelineGetBindGroupLayout(pipeline, 0);
        WGPUBindGroupDescriptor bindgroup_desc{};
        bindgroup_desc.nextInChain = nullptr;
        bindgroup_desc.label = nullptr;
        bindgroup_desc.layout = layout;
        bindgroup_desc.entryCount = num_buffers;
        bindgroup_desc.entries = bind_group_entries;
        WGPUBindGroup bind_group =
            wgpuDeviceCreateBindGroup(context.device, &bindgroup_desc);
        wgpuComputePassEncoderSetBindGroup(pass, 0, bind_group, 0, nullptr);
        wgpuBindGroupRelease(bind_group);
        wgpuBindGroupLayoutRelease(layout);

        free(bind_group_entries);
    }
    if (num_args > num_buffers) {
        // Create a uniform buffer for the non-buffer arguments.
        WGPUBufferDescriptor desc{};
        desc.nextInChain = nullptr;
        desc.label = nullptr;
        desc.usage = WGPUBufferUsage_Uniform;
        desc.size = uniform_size;
        desc.mappedAtCreation = true;
        WGPUBuffer arg_buffer = wgpuDeviceCreateBuffer(context.device, &desc);

        // Write the argument values to the uniform buffer.
        uint32_t *arg_values =
            (uint32_t *)wgpuBufferGetMappedRange(arg_buffer, 0, uniform_size);
        for (uint32_t a = 0, i = 0; a < num_args; a++) {
            if (arg_is_buffer[a]) {
                continue;
            }

            halide_type_t arg_type = arg_types[a];
            halide_debug_assert(user_context, arg_type.lanes == 1);
            halide_debug_assert(user_context, arg_type.bits > 0);
            halide_debug_assert(user_context, arg_type.bits <= 32);

            void *arg_in = args[a];
            void *arg_out = &arg_values[i++];

            // Copy the argument value, expanding it to 32-bits.
            switch (arg_type.code) {
            case halide_type_float: {
                halide_debug_assert(user_context, arg_type.bits == 32);
                *(float *)arg_out = *(float *)arg_in;
                break;
            }
            case halide_type_int: {
                switch (arg_type.bits) {
                // 1-bit values are read as a single byte, same as 8-bit.
                case 1:
                case 8: {
                    *(int32_t *)arg_out = *((int8_t *)arg_in);
                    break;
                }
                case 16: {
                    *(int32_t *)arg_out = *((int16_t *)arg_in);
                    break;
                }
                case 32: {
                    *(int32_t *)arg_out = *((int32_t *)arg_in);
                    break;
                }
                default: {
                    halide_debug_assert(user_context, false);
                    break;
                }
                }
                break;
            }
            case halide_type_uint: {
                switch (arg_type.bits) {
                // 1-bit values are read as a single byte, same as 8-bit.
                case 1:
                case 8: {
                    *(uint32_t *)arg_out = *((uint8_t *)arg_in);
                    break;
                }
                case 16: {
                    *(uint32_t *)arg_out = *((uint16_t *)arg_in);
                    break;
                }
                case 32: {
                    *(uint32_t *)arg_out = *((uint32_t *)arg_in);
                    break;
                }
                default: {
                    halide_debug_assert(user_context, false);
                    break;
                }
                }
                break;
            }
            default: {
                halide_debug_assert(user_context, false && "unhandled type");
                break;
            }
            }
        }
        wgpuBufferUnmap(arg_buffer);

        // Create a bind group for the uniform buffer.
        WGPUBindGroupLayout layout =
            wgpuComputePipelineGetBindGroupLayout(pipeline, 1);
        WGPUBindGroupEntry entry{};
        entry.nextInChain = nullptr;
        entry.binding = 0;
        entry.buffer = arg_buffer;
        entry.offset = 0;
        entry.size = uniform_size;
        entry.sampler = nullptr;
        entry.textureView = nullptr;
        WGPUBindGroupDescriptor bindgroup_desc{};
        bindgroup_desc.nextInChain = nullptr;
        bindgroup_desc.label = nullptr;
        bindgroup_desc.layout = layout;
        bindgroup_desc.entryCount = 1;
        bindgroup_desc.entries = &entry;
        WGPUBindGroup bind_group =
            wgpuDeviceCreateBindGroup(context.device, &bindgroup_desc);
        wgpuComputePassEncoderSetBindGroup(pass, 1, bind_group, 0, nullptr);
        wgpuBindGroupRelease(bind_group);
        wgpuBindGroupLayoutRelease(layout);

        wgpuBufferRelease(arg_buffer);
    }

    wgpuComputePassEncoderDispatchWorkgroups(pass, groupsX, groupsY, groupsZ);
    wgpuComputePassEncoderEnd(pass);
    // The pass has been ended, so our reference to its encoder is no longer
    // needed.
    wgpuComputePassEncoderRelease(pass);

    // Submit the compute command.
    WGPUCommandBuffer commands = wgpuCommandEncoderFinish(encoder, nullptr);
    wgpuQueueSubmit(context.queue, 1, &commands);
    // The command buffer has been handed to the queue; drop our reference.
    wgpuCommandBufferRelease(commands);

    wgpuCommandEncoderRelease(encoder);
    wgpuComputePipelineRelease(pipeline);

    return error_scope.wait();
}

// Accessor for the WebGPU backend's device interface table. The returned
// pointer refers to the single `webgpu_device_interface` object defined in
// this file.
WEAK const struct halide_device_interface_t *halide_webgpu_device_interface() {
    const struct halide_device_interface_t *const iface =
        &webgpu_device_interface;
    return iface;
}

namespace {
// Teardown hook for the WebGPU runtime. The `destructor` attribute asks the
// toolchain to call this automatically when the program (or the shared
// object containing this code) is being torn down; nothing in this file calls
// it directly.
WEAK __attribute__((destructor)) void halide_webgpu_cleanup() {
    // Drop every cached shader module, using wgpuShaderModuleRelease as the
    // per-module release function. The first argument is presumably the
    // user_context (none is available here) - confirm against the cache API.
    shader_cache.release_all(nullptr, wgpuShaderModuleRelease);
    // Then release the device itself, again with no user_context.
    halide_webgpu_device_release(nullptr);
}
}  // namespace

}  // extern "C" linkage

namespace Halide {
namespace Runtime {
namespace Internal {
namespace WebGPU {

// Backend-specific function table for WebGPU. It is referenced by the last
// entry of `webgpu_device_interface` below.
//
// This is a positional aggregate initializer: each entry is matched to a
// member of halide_device_interface_impl_t purely by position, so the order
// here must match the member order in that struct's declaration (which is
// not in this file). The first two entries are the generic JIT-module
// helpers; every other entry is this backend's halide_webgpu_* function.
WEAK halide_device_interface_impl_t webgpu_device_interface_impl = {
    halide_use_jit_module,
    halide_release_jit_module,
    halide_webgpu_device_malloc,
    halide_webgpu_device_free,
    halide_webgpu_device_sync,
    halide_webgpu_device_release,
    halide_webgpu_copy_to_host,
    halide_webgpu_copy_to_device,
    halide_webgpu_device_and_host_malloc,
    halide_webgpu_device_and_host_free,
    halide_webgpu_buffer_copy,
    halide_webgpu_device_crop,
    halide_webgpu_device_slice,
    halide_webgpu_device_release_crop,
    halide_webgpu_wrap_native,
    halide_webgpu_detach_native,
};

// The device interface object for the WebGPU backend. Its address is what
// halide_webgpu_device_interface() returns, and what the copy code in this
// file compares buffer interfaces against.
//
// This is a positional aggregate initializer, so the entry order must match
// the member order of halide_device_interface_t (declared outside this file).
// The function entries are the backend-independent halide_* entry points; the
// final entry links to the WebGPU-specific table above.
WEAK halide_device_interface_t webgpu_device_interface = {
    halide_device_malloc,
    halide_device_free,
    halide_device_sync,
    halide_device_release,
    halide_copy_to_host,
    halide_copy_to_device,
    halide_device_and_host_malloc,
    halide_device_and_host_free,
    halide_buffer_copy,
    halide_device_crop,
    halide_device_slice,
    halide_device_release_crop,
    halide_device_wrap_native,
    halide_device_detach_native,
    // NOTE(review): presumably an optional hook this backend does not
    // provide - confirm which member this slot corresponds to.
    nullptr,
    &webgpu_device_interface_impl};

}  // namespace WebGPU
}  // namespace Internal
}  // namespace Runtime
}  // namespace Halide
back to top