Raw File
WasmExecutor.cpp
#include "WasmExecutor.h"

#include "CodeGen_WebAssembly.h"
#include "Error.h"
#include "Float16.h"
#include "Func.h"
#include "ImageParam.h"
#include "JITModule.h"
#ifdef WITH_V8
#include "LLVM_Headers.h"
#endif
#include "LLVM_Output.h"
#include "LLVM_Runtime_Linker.h"
#include "Target.h"

#include <cmath>
#include <mutex>
#include <sstream>
#include <vector>

// clang-format off
// These includes are order-dependent, don't let clang-format reorder them
#ifdef WITH_V8
#include "v8.h"
#include "libplatform/libplatform.h"
#endif  // WITH_V8
// clang-format on

// ---------------------

// Debugging the WebAssembly JIT support is usually disconnected from the rest of HL_DEBUG_CODEGEN
#define WASM_DEBUG_LEVEL 0

namespace {
// A stream-like object that throws away everything written to it.
// It stands in for a real debug stream when debug output is compiled out,
// so logging statements still type-check but produce nothing.
struct debug_sink {
    debug_sink() {
    }

    // Accept a value of any type, ignore it, and hand back this sink so
    // that chained "<<" expressions continue to work.
    template<typename Ignored>
    debug_sink &operator<<(Ignored &&) {
        return *this;
    }
};
}  // namespace

// wdebug(x) yields a stream for WebAssembly-JIT debug messages of verbosity x.
// When WASM_DEBUG_LEVEL is nonzero, messages with x <= WASM_DEBUG_LEVEL are
// sent to Halide::Internal::debug at level 0 and all others at level 255;
// when it is zero, the macro expands to a debug_sink that discards its input.
#if WASM_DEBUG_LEVEL
#define wdebug(x) Halide::Internal::debug(((x) <= WASM_DEBUG_LEVEL) ? 0 : 255)
#else
#define wdebug(x) debug_sink()
#endif

// ---------------------

namespace {

// Rounds p up to the next multiple of `alignment` (values that are already
// a multiple are returned unchanged). The bit-masking used here is only
// meaningful when `alignment` is a power of two.
template<typename T>
inline T align_up(T p, int alignment = 32) {
    const T low_bits = static_cast<T>(alignment - 1);
    return (p + low_bits) & ~low_bits;
}

// Debugging our Malloc is extremely noisy and usually undesired

// BDMALLOC_DEBUG_LEVEL controls logging from BDMalloc independently of
// WASM_DEBUG_LEVEL. bddebug(x) yields a stream for a message of verbosity x:
// when the level is nonzero, messages with x <= BDMALLOC_DEBUG_LEVEL go to
// Halide::Internal::debug at level 0 and all others at level 255; when it
// is zero, the macro expands to a debug_sink that discards its input.
#define BDMALLOC_DEBUG_LEVEL 0
#if BDMALLOC_DEBUG_LEVEL
#define bddebug(x) Halide::Internal::debug(((x) <= BDMALLOC_DEBUG_LEVEL) ? 0 : 255)
#else
#define bddebug(x) debug_sink()
#endif

// BDMalloc aka BrainDeadMalloc. This is an *extremely* simple-minded implementation
// of malloc/free on top of a WasmMemoryObject, and is intended to be just barely adequate
// to allow Halide's JIT-based tests to pass. It is neither memory-efficient nor performant,
// nor has it been particularly well-vetted for potential buffer overruns and such.
class BDMalloc {
    // Bookkeeping for one contiguous span of the managed address range.
    // A span's start offset is the key it is stored under in `regions`.
    struct Region {
        uint32_t size : 31;  // length of the span, in bytes
        uint32_t used : 1;   // nonzero if the span is allocated or reserved
    };

    // Number of bytes under management; zero until init() is called
    // (and again after reset()).
    uint32_t total_size = 0;
    // All spans, keyed and ordered by start offset. Between calls the spans
    // tile [0, total_size) with no gaps (see validate()).
    std::map<uint32_t, Region> regions;

public:
    BDMalloc() = default;

    // Returns the number of bytes currently under management.
    uint32_t get_total_size() const {
        return total_size;
    }

    // Returns true once init() has been called with a nonzero size.
    bool inited() const {
        return total_size > 0;
    }

    // (Re)initializes the allocator to manage offsets [0, total_size).
    // Offsets below heap_start are marked permanently used; the remainder
    // becomes a single free region. heap_start must be < total_size.
    void init(uint32_t total_size, uint32_t heap_start = 1) {
        this->total_size = total_size;
        regions.clear();

        internal_assert(heap_start < total_size);
        // Area before heap_start is permanently off-limits
        regions[0] = {heap_start, true};
        // Everything else is free
        regions[heap_start] = {total_size - heap_start, false};
    }

    // Forgets all regions and returns to the un-inited state.
    void reset() {
        this->total_size = 0;
        regions.clear();
    }

    // Allocates at least requested_size bytes (rounded up to a multiple of
    // 32) and returns the start offset of the block, or 0 if no free region
    // is large enough. requested_size must be nonzero and at most
    // 0x7fffffff.
    uint32_t alloc_region(uint32_t requested_size) {
        internal_assert(requested_size > 0);

        // Region::size is a 31-bit field, so nothing larger can be tracked.
        // The raw request is checked *before* rounding: align_up() on a value
        // close to UINT32_MAX wraps around to a tiny number, which would
        // otherwise yield a block far smaller than what was asked for.
        constexpr uint32_t kMaxAllocSize = 0x7fffffff;
        internal_assert(requested_size <= kMaxAllocSize);

        bddebug(1) << "begin alloc_region " << requested_size << "\n";
        validate();

        // TODO: this would be faster with a basic free list,
        // but for most Halide test code, there aren't enough allocations
        // for this to be worthwhile, or at least that's my observation from
        // a first run-thru. Consider adding it if any of the JIT-based tests
        // seem unreasonably slow; as it is, a linear search for the first free
        // block of adequate size is (apparently) performant enough.

        // alignment and min-block-size are the same for our purposes here.
        constexpr uint32_t kAlignment = 32;
        const uint32_t size = std::max(align_up((uint32_t)requested_size, kAlignment), kAlignment);

        // Rounding up may have pushed the size past the 31-bit limit.
        internal_assert(size <= kMaxAllocSize);
        bddebug(2) << "size -> " << size << "\n";

        // First-fit search. Inserting into a std::map does not invalidate
        // existing iterators or references, and we return right after the
        // insertion anyway.
        for (auto &entry : regions) {
            const uint32_t start = entry.first;
            Region &r = entry.second;
            if (r.used || r.size < size) {
                continue;
            }
            bddebug(2) << "alloc @ " << start << "," << (uint32_t)r.size << "\n";
            if (r.size > size + kAlignment) {
                // Split the block: the tail stays free.
                const uint32_t r2_start = start + size;
                const uint32_t r2_size = r.size - size;
                regions[r2_start] = {r2_size, false};
                r.size = size;
                bddebug(2) << "split: r-> " << start << "," << (uint32_t)r.size << "," << (start + r.size) << "\n";
                bddebug(2) << "split: r2-> " << r2_start << "," << r2_size << "," << (r2_start + r2_size) << "\n";
            }
            // Otherwise the leftover is too small to be useful; hand out
            // the whole block.
            r.used = true;
            bddebug(1) << "end alloc_region " << requested_size << "\n";
            validate();
            return start;
        }
        bddebug(1) << "fail alloc_region " << requested_size << "\n";
        validate();
        return 0;
    }

    // Releases the block that starts at offset `start`, merging it with a
    // free neighbor on either side. Passing 0 is a no-op; any other value
    // must be the start of a region that is currently in use.
    void free_region(uint32_t start) {
        bddebug(1) << "begin free_region " << start << "\n";
        validate();

        // Can't free region at zero
        if (!start) {
            return;
        }

        internal_assert(start > 0);
        auto it = regions.find(start);
        internal_assert(it != regions.end());
        internal_assert(it->second.used);
        it->second.used = false;
        // If prev region is free, combine with it
        if (it != regions.begin()) {
            auto prev = std::prev(it);
            if (!prev->second.used) {
                bddebug(2) << "combine prev: " << prev->first << " w/ " << it->first << "\n";
                prev->second.size += it->second.size;
                regions.erase(it);
                it = prev;
            }
        }
        // If next region is free, combine with it
        auto next = std::next(it);
        if (next != regions.end() && !next->second.used) {
            bddebug(2) << "combine next: " << next->first << " w/ " << it->first << " "
                       << "\n";
            it->second.size += next->second.size;
            regions.erase(next);
        }

        bddebug(1) << "end free_region " << start << "\n";
        validate();
    }

    // Extends the managed range so that it ends at new_total_size, which
    // must be larger than the current total. The added bytes become free:
    // either by extending a free final region, or by appending a new free
    // region after a used one.
    void grow_total_size(uint32_t new_total_size) {
        bddebug(1) << "begin grow_total_size " << new_total_size << "\n";
        validate();

        internal_assert(new_total_size > total_size);
        internal_assert(!regions.empty());
        auto it = regions.rbegin();
        const uint32_t start = it->first;
        Region &r = it->second;
        const uint32_t r_end = start + r.size;
        internal_assert(r_end == total_size);
        const uint32_t delta = new_total_size - r_end;
        if (r.used) {
            // Add a free region after the last one
            regions[r_end] = {delta, false};
        } else {
            // Just extend the last (free) region
            r.size += delta;
        }
        // Check against whichever region is last *now*. When a new region
        // was appended above, `r` still refers to the previous last region
        // (which ends at the old total), so it must not be used here.
        // This also catches a size that overflowed the 31-bit field.
        const auto &last = *regions.rbegin();
        internal_assert(last.first + last.second.size == new_total_size);
        // bookkeeping
        total_size = new_total_size;

        bddebug(1) << "end grow_total_size " << new_total_size << "\n";
        validate();
    }

    // Asserts that the allocator has been inited. When either debug level
    // is >= 1, additionally checks that the regions tile [0, total_size)
    // exactly and that no two adjacent regions are both free.
    void validate() const {
        internal_assert(total_size > 0);
#if (BDMALLOC_DEBUG_LEVEL >= 1) || (WASM_DEBUG_LEVEL >= 1)
        uint32_t prev_end = 0;
        bool prev_used = false;
        for (const auto &it : regions) {
            const uint32_t start = it.first;
            const Region &r = it.second;
            bddebug(2) << "R: " << start << ".." << (start + r.size - 1) << "," << r.used << "\n";
            internal_assert(start == prev_end) << "start " << start << " prev_end " << prev_end << "\n";
            // it's OK to have two used regions in a row, but not two free ones
            internal_assert(!(!prev_used && !r.used));
            prev_end = start + r.size;
            prev_used = r.used;
        }
        internal_assert(prev_end == total_size) << "prev_end " << prev_end << " total_size " << total_size << "\n";
        bddebug(2) << "\n";
#endif
    }
};

// ---------------------

// Suffix used for trampoline names. Trampolines do not use "_argv" as the
// suffix because that name may already exist and if so, will return an int
// instead of taking a pointer at the end of the args list to
// receive the result value.
static const char kTrampolineSuffix[] = "_trampoline";

// Halide::JITExtern entries keyed by string (presumably the extern's name —
// confirm against the code that populates the map).
using JITExternMap = std::map<std::string, Halide::JITExtern>;

}  // namespace

#ifdef WITH_V8

// V8_API_VERSION packs V8's major and minor version numbers into a single
// integer (major * 10 + minor) so it can be compared in preprocessor
// conditionals and in the static_assert below.
// NOTE(review): this encoding presumes the minor version stays below 10 — confirm.
#define V8_API_VERSION ((V8_MAJOR_VERSION * 10) + V8_MINOR_VERSION)

static_assert(V8_API_VERSION >= 75,
              "Halide requires V8 v7.5 or later when compiling WITH_V8.");

namespace Halide {
namespace Internal {
namespace {

// Builds a V8 string handle from the C string `s`.
// From V8_API_VERSION 76 onward the result of NewFromUtf8() has to be
// unwrapped with ToLocalChecked(); before that it is returned as-is.
v8::Local<v8::String> NewLocalString(v8::Isolate *isolate, const char *s) {
    auto created = v8::String::NewFromUtf8(isolate, s);
#if V8_API_VERSION >= 76
    return created.ToLocalChecked();
#else
    return created;
#endif
}

// Everything below refers to V8 types (Local, Context, Value, ...) unqualified.
using namespace v8;

// A "pointer" into wasm memory, stored as a 32-bit value. Code in this file
// adds it to the pointer returned by get_wasm_memory_base() to obtain a host
// address.
using wasm32_ptr_t = int32_t;

// Sentinel that get_jit_user_context() expects as the user-context argument
// whenever that argument is nonzero.
const wasm32_ptr_t kMagicJitUserContextValue = -1;

#if WASM_DEBUG_LEVEL
// Debug helper: prints the string form of `v` and, if it is an object, a
// numbered list of its property names. Only compiled when WASM_DEBUG_LEVEL
// is nonzero.
void print_object_properties(Isolate *isolate, const Local<Value> &v) {
    Local<Context> context = isolate->GetCurrentContext();
    String::Utf8Value objascii(isolate, v);
    wdebug(0) << *objascii << "\n";

    if (v->IsObject()) {
        Local<Object> obj = v.As<Object>();
        Local<Array> properties = obj->GetPropertyNames(context).ToLocalChecked();
        const uint32_t len = properties->Length();
        wdebug(0) << "Number of properties = " << len << ":\n";
        for (uint32_t i = 0; i < len; ++i) {
            // Use the context-taking Get(), as every other property lookup
            // in this file does, rather than the context-less overload.
            const v8::Local<v8::Value> key = properties->Get(context, i).ToLocalChecked();
            String::Utf8Value str(isolate, key);
            wdebug(0) << "\t" << i + 1 << ". " << *str << "\n";
        }
    }
}
#endif

// Combines a type code and a bit width into one int (the bit width is
// shifted left by 8 and OR'ed with the code) so that the pair can be used
// as a single switch/case label.
inline constexpr int halide_type_code(halide_type_code_t code, int bits) {
    return (bits << 8) | static_cast<int>(code);
}

// dynamic_type_dispatch is a utility for functors that want to be able
// to dynamically dispatch a halide_type_t to type-specialized code.
// To use it, a functor must be a *templated* class, e.g.
//
//     template<typename T> class MyFunctor { int operator()(arg1, arg2...); };
//
// dynamic_type_dispatch() is called with a halide_type_t as the first argument,
// followed by the arguments to the Functor's operator():
//
//     auto result = dynamic_type_dispatch<MyFunctor>(some_halide_type, arg1, arg2);
//
// Note that this means that the functor must be able to instantiate its
// operator() for all the Halide scalar types; it also means that all those
// variants *will* be instantiated (increasing code size), so this approach
// should only be used when strictly necessary.
//
// The return type is whatever Functor<uint8_t>::operator() returns for the
// given arguments. Only type.code and type.bits are examined; a combination
// that is not listed below reaches internal_error.
template<template<typename> class Functor, typename... Args>
auto dynamic_type_dispatch(const halide_type_t &type, Args &&... args) -> decltype(std::declval<Functor<uint8_t>>()(std::forward<Args>(args)...)) {

// Each HANDLE_CASE expands to one case label that default-constructs
// Functor<TYPE> and forwards the arguments to it.
#define HANDLE_CASE(CODE, BITS, TYPE)  \
    case halide_type_code(CODE, BITS): \
        return Functor<TYPE>()(std::forward<Args>(args)...);
    switch (halide_type_code((halide_type_code_t)type.code, type.bits)) {
        HANDLE_CASE(halide_type_bfloat, 16, bfloat16_t)
        HANDLE_CASE(halide_type_float, 16, float16_t)
        HANDLE_CASE(halide_type_float, 32, float)
        HANDLE_CASE(halide_type_float, 64, double)
        HANDLE_CASE(halide_type_int, 8, int8_t)
        HANDLE_CASE(halide_type_int, 16, int16_t)
        HANDLE_CASE(halide_type_int, 32, int32_t)
        HANDLE_CASE(halide_type_int, 64, int64_t)
        HANDLE_CASE(halide_type_uint, 1, bool)
        HANDLE_CASE(halide_type_uint, 8, uint8_t)
        HANDLE_CASE(halide_type_uint, 16, uint16_t)
        HANDLE_CASE(halide_type_uint, 32, uint32_t)
        HANDLE_CASE(halide_type_uint, 64, uint64_t)
        HANDLE_CASE(halide_type_handle, 64, void *)
    default:
        // Unsupported code/bits combination.
        internal_error;
        // A value-initialized result is returned so that every path of the
        // function has a return statement.
        using ReturnType = decltype(std::declval<Functor<uint8_t>>()(std::forward<Args>(args)...));
        return ReturnType();
    }
#undef HANDLE_CASE
}

// ------------------------------

// Functor for dynamic_type_dispatch: reads the JS value `val` as a number
// and writes it, converted to T, into the memory that `slot` points at.
template<typename T>
struct ExtractAndStoreScalar {
    void operator()(const Local<Context> &context, const Local<Value> &val, void *slot) {
        const double number = val->NumberValue(context).ToChecked();
        T *typed_slot = static_cast<T *>(slot);
        *typed_slot = static_cast<T>(number);
    }
};

// float16_t: convert the number to a float16_t, then store its 16-bit pattern.
template<>
inline void ExtractAndStoreScalar<float16_t>::operator()(const Local<Context> &context, const Local<Value> &val, void *slot) {
    const double number = val->NumberValue(context).ToChecked();
    float16_t half(number);
    *static_cast<uint16_t *>(slot) = half.to_bits();
}

// bfloat16_t: convert the number to a bfloat16_t, then store its 16-bit pattern.
template<>
inline void ExtractAndStoreScalar<bfloat16_t>::operator()(const Local<Context> &context, const Local<Value> &val, void *slot) {
    const double number = val->NumberValue(context).ToChecked();
    bfloat16_t brain(number);
    *static_cast<uint16_t *>(slot) = brain.to_bits();
}

// The 64-bit cases (handles, uint64_t, int64_t) are not implemented and
// report an internal error if reached.
template<>
inline void ExtractAndStoreScalar<void *>::operator()(const Local<Context> &, const Local<Value> &, void *) {
    internal_error << "TODO: 64-bit slots aren't yet supported";
}

template<>
inline void ExtractAndStoreScalar<uint64_t>::operator()(const Local<Context> &, const Local<Value> &, void *) {
    internal_error << "TODO: 64-bit slots aren't yet supported";
}

template<>
inline void ExtractAndStoreScalar<int64_t>::operator()(const Local<Context> &, const Local<Value> &, void *) {
    internal_error << "TODO: 64-bit slots aren't yet supported";
}

// ------------------------------

// Functor for dynamic_type_dispatch: reads a T from the memory that `slot`
// points at and sets it as the V8 return value `val`.
template<typename T>
struct LoadAndReturnScalar {
    void operator()(const Local<Context> &context, const void *slot, ReturnValue<Value> val) {
        const T *typed_slot = static_cast<const T *>(slot);
        val.Set(*typed_slot);
    }
};

// float16_t: rebuild the value from its stored 16-bit pattern and return it
// as a double.
template<>
inline void LoadAndReturnScalar<float16_t>::operator()(const Local<Context> &context, const void *slot, ReturnValue<Value> val) {
    const uint16_t bits = *static_cast<const uint16_t *>(slot);
    float16_t half = float16_t::make_from_bits(bits);
    val.Set((double)half);
}

// bfloat16_t: rebuild the value from its stored 16-bit pattern and return it
// as a double.
template<>
inline void LoadAndReturnScalar<bfloat16_t>::operator()(const Local<Context> &context, const void *slot, ReturnValue<Value> val) {
    const uint16_t bits = *static_cast<const uint16_t *>(slot);
    bfloat16_t brain = bfloat16_t::make_from_bits(bits);
    val.Set((double)brain);
}

// The 64-bit cases (handles, uint64_t, int64_t) are not implemented and
// report an internal error if reached.
template<>
inline void LoadAndReturnScalar<void *>::operator()(const Local<Context> &, const void *, ReturnValue<Value>) {
    internal_error << "TODO: 64-bit slots aren't yet supported";
}

template<>
inline void LoadAndReturnScalar<uint64_t>::operator()(const Local<Context> &, const void *, ReturnValue<Value>) {
    internal_error << "TODO: 64-bit slots aren't yet supported";
}

template<>
inline void LoadAndReturnScalar<int64_t>::operator()(const Local<Context> &, const void *, ReturnValue<Value>) {
    internal_error << "TODO: 64-bit slots aren't yet supported";
}

// ------------------------------

// Functor for dynamic_type_dispatch: reads a T from `val_ptr`, converts it
// to double, and wraps it in a new V8 Number.
template<typename T>
struct WrapScalar {
    Local<Value> operator()(const Local<Context> &context, const void *val_ptr) {
        const double as_double = *static_cast<const T *>(val_ptr);
        return Number::New(context->GetIsolate(), as_double);
    }
};

// float16_t: the stored 16 bits are read as a uint16_t and that integer is
// what gets wrapped.
// NOTE(review): this wraps the raw bit pattern rather than the numeric
// value (contrast LoadAndReturnScalar<float16_t>) — confirm this is intended.
template<>
inline Local<Value> WrapScalar<float16_t>::operator()(const Local<Context> &context, const void *val_ptr) {
    const uint16_t bits = *static_cast<const uint16_t *>(val_ptr);
    return Number::New(context->GetIsolate(), (double)bits);
}

// bfloat16_t: same treatment as float16_t above (raw 16-bit pattern).
// NOTE(review): same question as for float16_t — confirm this is intended.
template<>
inline Local<Value> WrapScalar<bfloat16_t>::operator()(const Local<Context> &context, const void *val_ptr) {
    const uint16_t bits = *static_cast<const uint16_t *>(val_ptr);
    return Number::New(context->GetIsolate(), (double)bits);
}

// The 64-bit cases (handles, uint64_t, int64_t) are not implemented: they
// report an internal error and return an empty handle.
template<>
inline Local<Value> WrapScalar<void *>::operator()(const Local<Context> &, const void *) {
    internal_error << "TODO: 64-bit slots aren't yet supported";
    return Local<Value>();
}

template<>
inline Local<Value> WrapScalar<uint64_t>::operator()(const Local<Context> &, const void *) {
    internal_error << "TODO: 64-bit slots aren't yet supported";
    return Local<Value>();
}

template<>
inline Local<Value> WrapScalar<int64_t>::operator()(const Local<Context> &, const void *) {
    internal_error << "TODO: 64-bit slots aren't yet supported";
    return Local<Value>();
}

// ------------------------------

// Wraps the scalar stored at `val_ptr` as a V8 value, choosing the
// WrapScalar specialization that matches the runtime type `t`.
Local<Value> wrap_scalar(const Local<Context> &context, const Type &t, const void *val_ptr) {
    Local<Value> wrapped = dynamic_type_dispatch<WrapScalar>(t, context, val_ptr);
    return wrapped;
}

// Wraps `val` as a V8 value using the WrapScalar specialization for its
// static type T.
template<typename T>
Local<Value> wrap_scalar(const Local<Context> &context, const T &val) {
    WrapScalar<T> wrapper;
    return wrapper(context, &val);
}

// ---------------------------------

// Indices of the v8::Context embedder-data slots that this file reads.
enum EmbedderDataSlots {
    // don't use slot 0
    kWasmMemoryObject = 1,  // JS object holding the wasm memory ("really a WasmMemoryObject")
    kBDMallocPtr,           // aligned pointer to the BDMalloc used by v8_WasmMemoryObject_malloc/free
    kHeapBase,              // Int32 passed to BDMalloc::init() as heap_start on first allocation
    kJitUserContext,        // aligned pointer to the JITUserContext
    kString_buffer,         // property key used to fetch the memory object's ArrayBuffer
    kString_grow,           // property key used to fetch the memory object's grow function
};

// Allocates `size` bytes inside the wasm memory associated with `context`
// and returns the offset of the block within that memory.
//
// The BDMalloc stored in the context is lazily initialized on first use
// from the memory's current byte length and the stored heap base. If the
// allocation cannot be satisfied, the wasm memory is grown by enough pages
// to hold `size` bytes and the allocation is retried once; the result of
// that retry (possibly 0) is returned.
//
// `size` must not exceed 0x7fffffff.
wasm32_ptr_t v8_WasmMemoryObject_malloc(const Local<Context> &context, size_t size) {
    Isolate *isolate = context->GetIsolate();

    // BDMalloc takes a uint32_t and tracks at most 31 bits of size per
    // region; reject anything larger here instead of letting the implicit
    // size_t -> uint32_t conversion silently truncate the request.
    constexpr size_t kMaxAllocSize = 0x7fffffff;
    internal_assert(size <= kMaxAllocSize) << "allocation too large: " << size << "\n";

    BDMalloc *bdmalloc = (BDMalloc *)context->GetAlignedPointerFromEmbedderData(kBDMallocPtr);
    if (!bdmalloc->inited()) {
        int32_t heap_base = context->GetEmbedderData(kHeapBase)->Int32Value(context).ToChecked();

        Local<Object> memory_value = context->GetEmbedderData(kWasmMemoryObject).As<Object>();  // really a WasmMemoryObject
        Local<Object> buffer_string = context->GetEmbedderData(kString_buffer).As<Object>();
        Local<ArrayBuffer> wasm_memory = Local<ArrayBuffer>::Cast(memory_value->Get(context, buffer_string).ToLocalChecked());

        wdebug(0) << "heap_base is: " << heap_base << "\n";
        wdebug(0) << "initial memory size is: " << wasm_memory->ByteLength() << "\n";
        bdmalloc->init(wasm_memory->ByteLength(), heap_base);
    }

    wasm32_ptr_t p = bdmalloc->alloc_region((uint32_t)size);
    if (!p) {
        Local<Object> memory_value = context->GetEmbedderData(kWasmMemoryObject).As<Object>();  // really a WasmMemoryObject

        // Grow by whole pages, rounding the request up.
        constexpr int kWasmPageSize = 65536;
        const int32_t pages_needed = (size + kWasmPageSize - 1) / kWasmPageSize;
        wdebug(0) << "attempting to grow by pages: " << pages_needed << "\n";

        Local<Value> args[1] = {Integer::New(isolate, pages_needed)};
        int32_t result = memory_value
                             ->Get(context, context->GetEmbedderData(kString_grow))
                             .ToLocalChecked()
                             .As<Object>()
                             ->CallAsFunction(context, memory_value, 1, args)
                             .ToLocalChecked()
                             ->Int32Value(context)
                             .ToChecked();
        wdebug(0) << "grow result: " << result << "\n";
        // The value returned by the grow call is expected to equal the
        // page count that BDMalloc was tracking before the growth.
        internal_assert(result == (int)(bdmalloc->get_total_size() / kWasmPageSize));

        // Fetch the buffer again after growing so the new byte length is seen.
        Local<Object> buffer_string = context->GetEmbedderData(kString_buffer).As<Object>();
        Local<ArrayBuffer> wasm_memory = Local<ArrayBuffer>::Cast(memory_value->Get(context, buffer_string).ToLocalChecked());
        wdebug(0) << "New ArrayBuffer size is: " << wasm_memory->ByteLength() << "\n";

        bdmalloc->grow_total_size(wasm_memory->ByteLength());
        p = bdmalloc->alloc_region((uint32_t)size);
    }

    wdebug(2) << "allocation of " << size << " at: " << p << "\n";
    return p;
}

void v8_WasmMemoryObject_free(const Local<Context> &context, wasm32_ptr_t ptr) {
    wdebug(2) << "freeing ptr at: " << ptr << "\n";
    BDMalloc *bdmalloc = (BDMalloc *)context->GetAlignedPointerFromEmbedderData(kBDMallocPtr);
    bdmalloc->free_region(ptr);
}

uint8_t *get_wasm_memory_base(const Local<Context> &context) {
    Local<Object> memory_value = context->GetEmbedderData(kWasmMemoryObject).As<Object>();  // really a WasmMemoryObject
    Local<ArrayBuffer> wasm_memory = Local<ArrayBuffer>::Cast(memory_value->Get(context, context->GetEmbedderData(kString_buffer)).ToLocalChecked());
    uint8_t *p = (uint8_t *)wasm_memory->GetContents().Data();
    return p;
}

// The buffer descriptor as it is laid out inside wasm memory: the same
// fields as halide_buffer_t, but with every pointer stored as a 32-bit
// wasm32_ptr_t. Code in this file turns those into host addresses by adding
// them to get_wasm_memory_base(). A static_assert in this file pins the
// struct's size at 40 bytes.
struct wasm_halide_buffer_t {
    uint64_t device;
    wasm32_ptr_t device_interface;  // halide_device_interface_t*
    wasm32_ptr_t host;              // uint8_t*
    uint64_t flags;
    halide_type_t type;
    int32_t dimensions;
    wasm32_ptr_t dim;      // halide_dimension_t*
    wasm32_ptr_t padding;  // always zero
};

// Debug helper: prints every field of the host-side buffer `buf`, prefixed
// by `label`, plus the first four bytes of its host data when present.
// The body is compiled only when WASM_DEBUG_LEVEL >= 2; otherwise this is
// a no-op. `context` is not used.
void dump_hostbuf(const Local<Context> &context, const halide_buffer_t *buf, const std::string &label) {
#if WASM_DEBUG_LEVEL >= 2
    const halide_dimension_t *dim = buf->dim;
    const uint8_t *host = buf->host;

    wdebug(0) << label << " = " << (void *)buf << " = {\n";
    wdebug(0) << "  device = " << buf->device << "\n";
    wdebug(0) << "  device_interface = " << buf->device_interface << "\n";
    wdebug(0) << "  host = " << (void *)host << " = {\n";
    if (host) {
        // NOTE(review): reads four bytes unconditionally — assumes the host
        // allocation is at least that large.
        wdebug(0) << "    " << (int)host[0] << ", " << (int)host[1] << ", " << (int)host[2] << ", " << (int)host[3] << "...\n";
    }
    wdebug(0) << "  }\n";
    wdebug(0) << "  flags = " << buf->flags << "\n";
    wdebug(0) << "  type = " << (int)buf->type.code << "," << (int)buf->type.bits << "," << buf->type.lanes << "\n";
    wdebug(0) << "  dimensions = " << buf->dimensions << "\n";
    wdebug(0) << "  dim = " << (void *)buf->dim << " = {\n";
    for (int i = 0; i < buf->dimensions; i++) {
        const auto &d = dim[i];
        wdebug(0) << "    {" << d.min << "," << d.extent << "," << d.stride << "," << d.flags << "},\n";
    }
    wdebug(0) << "  }\n";
    wdebug(0) << "  padding = " << buf->padding << "\n";
    wdebug(0) << "}\n";
#endif
}

// Debug helper: prints every field of the wasm-side buffer located at wasm
// offset `buf_ptr`, prefixed by `label`, showing each wasm pointer next to
// the host address it maps to. The body (including the non-null assertion
// on buf_ptr) is compiled only when WASM_DEBUG_LEVEL >= 2; otherwise this
// is a no-op.
void dump_wasmbuf(const Local<Context> &context, wasm32_ptr_t buf_ptr, const std::string &label) {
#if WASM_DEBUG_LEVEL >= 2
    internal_assert(buf_ptr);

    // Translate wasm offsets into host addresses; a zero offset stays null.
    uint8_t *base = get_wasm_memory_base(context);
    wasm_halide_buffer_t *buf = (wasm_halide_buffer_t *)(base + buf_ptr);
    halide_dimension_t *dim = buf->dim ? (halide_dimension_t *)(base + buf->dim) : nullptr;
    uint8_t *host = buf->host ? (base + buf->host) : nullptr;

    wdebug(0) << label << " = " << buf_ptr << " -> " << (void *)buf << " = {\n";
    wdebug(0) << "  device = " << buf->device << "\n";
    wdebug(0) << "  device_interface = " << buf->device_interface << "\n";
    wdebug(0) << "  host = " << buf->host << " -> " << (void *)host << " = {\n";
    if (host) {
        // NOTE(review): reads four bytes unconditionally — assumes the host
        // allocation is at least that large.
        wdebug(0) << "    " << (int)host[0] << ", " << (int)host[1] << ", " << (int)host[2] << ", " << (int)host[3] << "...\n";
    }
    wdebug(0) << "  }\n";
    wdebug(0) << "  flags = " << buf->flags << "\n";
    wdebug(0) << "  type = " << (int)buf->type.code << "," << (int)buf->type.bits << "," << buf->type.lanes << "\n";
    wdebug(0) << "  dimensions = " << buf->dimensions << "\n";
    wdebug(0) << "  dim = " << buf->dim << " -> " << (void *)dim << " = {\n";
    for (int i = 0; i < buf->dimensions; i++) {
        const auto &d = dim[i];
        wdebug(0) << "    {" << d.min << "," << d.extent << "," << d.stride << "," << d.flags << "},\n";
    }
    wdebug(0) << "  }\n";
    wdebug(0) << "  padding = " << buf->padding << "\n";
    wdebug(0) << "}\n";
#endif
}

// The buffer-marshalling functions below memcpy halide_dimension_t arrays
// and assign halide_type_t values directly between host and wasm memory,
// and overlay wasm_halide_buffer_t on wasm memory, so the sizes of these
// structs are pinned here.
static_assert(sizeof(halide_type_t) == 4, "halide_type_t");
static_assert(sizeof(halide_dimension_t) == 16, "halide_dimension_t");
static_assert(sizeof(wasm_halide_buffer_t) == 40, "wasm_halide_buffer_t");

// Creates a copy of the host-side buffer `src` inside wasm memory and
// returns the wasm offset of the new wasm_halide_buffer_t.
//
// Everything lives in one allocation, laid out as
//     [wasm_halide_buffer_t][dimension array][padding][host data]
// so the whole thing can be released with a single free(). `src` must not
// have a device allocation or device interface.
wasm32_ptr_t hostbuf_to_wasmbuf(const Local<Context> &context, const halide_buffer_t *src) {
    static_assert(sizeof(halide_type_t) == 4, "halide_type_t");
    static_assert(sizeof(halide_dimension_t) == 16, "halide_dimension_t");
    static_assert(sizeof(wasm_halide_buffer_t) == 40, "wasm_halide_buffer_t");

    wdebug(0) << "\nhostbuf_to_wasmbuf:\n";
    dump_hostbuf(context, src, "src");

    internal_assert(src->device == 0);
    internal_assert(src->device_interface == nullptr);

    // The allocator is assumed to hand out 32-byte-aligned blocks; the host
    // data is placed at a 32-byte-aligned offset within the block so that
    // it ends up 32-byte aligned as well.
    const size_t header_bytes = sizeof(wasm_halide_buffer_t);
    const size_t dims_bytes = sizeof(halide_dimension_t) * src->dimensions;
    const size_t host_offset = align_up(header_bytes + dims_bytes);
    const size_t host_bytes = src->size_in_bytes();
    const size_t total_bytes = host_offset + host_bytes;

    const wasm32_ptr_t dst_ptr = v8_WasmMemoryObject_malloc(context, total_bytes);
    internal_assert(dst_ptr);

    // Look up the base address only after allocating.
    uint8_t *wasm_mem = get_wasm_memory_base(context);

    // Fill in the descriptor; pointers are wasm offsets, zero meaning null.
    wasm_halide_buffer_t *dst = (wasm_halide_buffer_t *)(wasm_mem + dst_ptr);
    dst->device = 0;
    dst->device_interface = 0;
    dst->host = src->host ? (dst_ptr + host_offset) : 0;
    dst->flags = src->flags;
    dst->type = src->type;
    dst->dimensions = src->dimensions;
    dst->dim = src->dimensions ? (dst_ptr + header_bytes) : 0;
    dst->padding = 0;

    // Copy the payloads that exist on the host side.
    if (src->dim) {
        memcpy(wasm_mem + dst->dim, src->dim, dims_bytes);
    }
    if (src->host) {
        memcpy(wasm_mem + dst->host, src->host, host_bytes);
    }

    dump_wasmbuf(context, dst_ptr, "dst");

    return dst_ptr;
}

// Given a pointer to a wasm_halide_buffer_t in wasm memory space,
// allocate a Buffer<> on the host and copy all relevant data.
void wasmbuf_to_hostbuf(const Local<Context> &context, wasm32_ptr_t src_ptr, Halide::Runtime::Buffer<> &dst) {
    wdebug(0) << "\nwasmbuf_to_hostbuf:\n";
    dump_wasmbuf(context, src_ptr, "src");

    internal_assert(src_ptr);

    uint8_t *base = get_wasm_memory_base(context);

    wasm_halide_buffer_t *src = (wasm_halide_buffer_t *)(base + src_ptr);

    internal_assert(src->device == 0);
    internal_assert(src->device_interface == 0);

    halide_buffer_t dst_tmp;
    dst_tmp.device = 0;
    dst_tmp.device_interface = 0;
    dst_tmp.host = nullptr;  // src->host ? (base + src->host) : nullptr;
    dst_tmp.flags = src->flags;
    dst_tmp.type = src->type;
    dst_tmp.dimensions = src->dimensions;
    dst_tmp.dim = src->dim ? (halide_dimension_t *)(base + src->dim) : nullptr;
    dst_tmp.padding = 0;

    dump_hostbuf(context, &dst_tmp, "dst_tmp");

    dst = Halide::Runtime::Buffer<>(dst_tmp);
    if (src->host) {
        // Don't use dst.copy(); it can tweak strides in ways that matter.
        dst.allocate();
        const size_t host_size_in_bytes = dst.raw_buffer()->size_in_bytes();
        memcpy(dst.raw_buffer()->host, base + src->host, host_size_in_bytes);
    }
    dump_hostbuf(context, dst.raw_buffer(), "dst");
}

// Refreshes the existing host buffer `dst` from the wasm buffer at wasm
// offset `src_ptr`: the dimension array, the host data (if the wasm buffer
// has any) and the flags are copied; device fields are cleared.
// The two buffers must already agree in type and number of dimensions.
void copy_wasmbuf_to_existing_hostbuf(const Local<Context> &context, wasm32_ptr_t src_ptr, halide_buffer_t *dst) {
    internal_assert(src_ptr && dst);

    wdebug(0) << "\ncopy_wasmbuf_to_existing_hostbuf:\n";
    dump_wasmbuf(context, src_ptr, "src");

    uint8_t *wasm_mem = get_wasm_memory_base(context);
    wasm_halide_buffer_t *src = (wasm_halide_buffer_t *)(wasm_mem + src_ptr);

    internal_assert(src->device == 0);
    internal_assert(src->device_interface == 0);
    internal_assert(src->dimensions == dst->dimensions);
    internal_assert(src->type == dst->type);

    dump_hostbuf(context, dst, "dst_pre");

    // Dimensions first: the byte count below is derived from dst's
    // (just-updated) dimensions.
    if (src->dimensions) {
        const size_t dims_bytes = sizeof(halide_dimension_t) * src->dimensions;
        memcpy(dst->dim, wasm_mem + src->dim, dims_bytes);
    }
    if (src->host) {
        const size_t data_bytes = dst->size_in_bytes();
        memcpy(dst->host, wasm_mem + src->host, data_bytes);
    }

    dst->device = 0;
    dst->device_interface = 0;
    dst->flags = src->flags;

    dump_hostbuf(context, dst, "dst_post");
}

// Refreshes the existing wasm buffer at wasm offset `dst_ptr` from the host
// buffer `src`: the dimension array, the host data (if `src` has any) and
// the flags are copied; device fields are cleared.
// The two buffers must already agree in type and number of dimensions.
void copy_hostbuf_to_existing_wasmbuf(const Local<Context> &context, const halide_buffer_t *src, wasm32_ptr_t dst_ptr) {
    internal_assert(src && dst_ptr);

    wdebug(0) << "\ncopy_hostbuf_to_existing_wasmbuf:\n";
    dump_hostbuf(context, src, "src");

    uint8_t *wasm_mem = get_wasm_memory_base(context);
    wasm_halide_buffer_t *dst = (wasm_halide_buffer_t *)(wasm_mem + dst_ptr);

    internal_assert(src->device == 0);
    internal_assert(src->device_interface == 0);
    internal_assert(src->dimensions == dst->dimensions);
    internal_assert(src->type == dst->type);

    dump_wasmbuf(context, dst_ptr, "dst_pre");

    if (src->dimensions) {
        const size_t dims_bytes = sizeof(halide_dimension_t) * src->dimensions;
        memcpy(wasm_mem + dst->dim, src->dim, dims_bytes);
    }
    if (src->host) {
        const size_t data_bytes = src->size_in_bytes();
        memcpy(wasm_mem + dst->host, src->host, data_bytes);
    }

    dst->device = 0;
    dst->device_interface = 0;
    dst->flags = src->flags;

    dump_wasmbuf(context, dst_ptr, "dst_post");
}

// Verifies that `jit_user_context` does not install any handler that the
// WebAssembly JIT cannot honor (custom allocator, custom do_task/do_par_for,
// or custom symbol/library lookup), raising a user error otherwise.
// Returns the same pointer that was passed in.
// NOTE(review): the pointer is dereferenced without a null check — callers
// presumably never pass null; confirm.
JITUserContext *check_jit_user_context(JITUserContext *jit_user_context) {
    user_assert(!jit_user_context->handlers.custom_malloc &&
                !jit_user_context->handlers.custom_free)
        << "The WebAssembly JIT cannot support set_custom_allocator()";
    user_assert(!jit_user_context->handlers.custom_do_task)
        << "The WebAssembly JIT cannot support set_custom_do_task()";
    user_assert(!jit_user_context->handlers.custom_do_par_for)
        << "The WebAssembly JIT cannot support set_custom_do_par_for()";
    user_assert(!jit_user_context->handlers.custom_get_symbol &&
                !jit_user_context->handlers.custom_load_library &&
                !jit_user_context->handlers.custom_get_library_symbol)
        << "The WebAssembly JIT cannot support custom_get_symbol, custom_load_library, or custom_get_library_symbol.";
    return jit_user_context;
}

// Maps the user-context argument that wasm code passed to a callback back
// to the host-side JITUserContext stored in `context`.
//
// Some internal code can call halide_error(null, ...), so a zero argument
// is accepted and yields nullptr; callers must expect null and not crash.
// Any nonzero argument must be kMagicJitUserContextValue.
JITUserContext *get_jit_user_context(const Local<Context> &context, const Local<Value> &arg) {
    const int32_t ucon_magic = arg->Int32Value(context).ToChecked();
    if (ucon_magic != 0) {
        internal_assert(ucon_magic == kMagicJitUserContextValue);
        void *embedder_ptr = context->GetAlignedPointerFromEmbedderData(kJitUserContext);
        JITUserContext *jit_user_context = static_cast<JITUserContext *>(embedder_ptr);
        internal_assert(jit_user_context);
        return jit_user_context;
    }
    return nullptr;
}

// V8 callback invoked by wasm code to print a message.
// args[0] is the user-context argument (see get_jit_user_context) and
// args[1] is the wasm offset of a C string.
// The string is passed to the user's custom_print handler when one is
// installed, and written to std::cout otherwise.
void wasm_jit_halide_print_callback(const v8::FunctionCallbackInfo<v8::Value> &args) {
    internal_assert(args.Length() == 2);

    Isolate *isolate = args.GetIsolate();
    Local<Context> context = isolate->GetCurrentContext();
    HandleScope scope(isolate);

    JITUserContext *jit_user_context = get_jit_user_context(context, args[0]);
    const int32_t str_address = args[1]->Int32Value(context).ToChecked();

    // Translate the wasm offset into a host pointer.
    uint8_t *p = get_wasm_memory_base(context);
    const char *str = (const char *)p + str_address;

    if (jit_user_context && jit_user_context->handlers.custom_print != nullptr) {
        // A custom handler takes over the message entirely; it is not
        // additionally echoed to the debug stream, so each message is
        // emitted exactly once.
        (*jit_user_context->handlers.custom_print)(jit_user_context, str);
    } else {
        std::cout << str;
    }
}

// V8 callback invoked by wasm code to report an error.
// args[0] is the user-context argument (see get_jit_user_context) and
// args[1] is the wasm offset of a C string holding the message.
// The message goes to the user's custom_error handler when one is
// installed, and to halide_runtime_error otherwise.
void wasm_jit_halide_error_callback(const v8::FunctionCallbackInfo<v8::Value> &args) {
    internal_assert(args.Length() == 2);

    Isolate *isolate = args.GetIsolate();
    Local<Context> context = isolate->GetCurrentContext();
    HandleScope scope(isolate);

    JITUserContext *jit_user_context = get_jit_user_context(context, args[0]);
    const int32_t str_address = args[1]->Int32Value(context).ToChecked();

    // Translate the wasm offset into a host pointer.
    const char *str = (const char *)get_wasm_memory_base(context) + str_address;

    const bool has_custom_handler =
        jit_user_context && jit_user_context->handlers.custom_error != nullptr;
    if (!has_custom_handler) {
        halide_runtime_error << str;
    } else {
        (*jit_user_context->handlers.custom_error)(jit_user_context, str);
    }
}

// Host-side halide_trace_helper for wasm code. Unpacks the twelve scalar
// arguments into a halide_trace_event_t (translating wasm addresses into host
// pointers) and forwards it to the user's custom_trace handler, if any.
// The handler's result (or 0 when there is no handler) is returned to wasm.
void wasm_jit_halide_trace_helper_callback(const v8::FunctionCallbackInfo<v8::Value> &args) {
    internal_assert(args.Length() == 12);
    Isolate *isolate = args.GetIsolate();
    Local<Context> context = isolate->GetCurrentContext();
    HandleScope scope(isolate);

    uint8_t *mem = get_wasm_memory_base(context);

    JITUserContext *ucon = get_jit_user_context(context, args[0]);

    // Every remaining argument is read as a 32-bit integer.
    const auto int_arg = [&](int index) -> int32_t {
        return args[index]->Int32Value(context).ToChecked();
    };

    const wasm32_ptr_t func_name_ptr = int_arg(1);
    const wasm32_ptr_t value_ptr = int_arg(2);
    const wasm32_ptr_t coordinates_ptr = int_arg(3);
    const int type_code = int_arg(4);
    const int type_bits = int_arg(5);
    const int type_lanes = int_arg(6);
    const int trace_code = int_arg(7);
    const int parent_id = int_arg(8);
    const int value_index = int_arg(9);
    const int dimensions = int_arg(10);
    const wasm32_ptr_t trace_tag_ptr = int_arg(11);

    internal_assert(dimensions >= 0 && dimensions < 1024);  // not a hard limit, just a sanity check

    halide_trace_event_t event;
    event.func = (const char *)(mem + func_name_ptr);
    // A zero wasm address for value/coordinates means "not provided".
    if (value_ptr) {
        event.value = (void *)(mem + value_ptr);
    } else {
        event.value = nullptr;
    }
    if (coordinates_ptr) {
        event.coordinates = (int32_t *)(mem + coordinates_ptr);
    } else {
        event.coordinates = nullptr;
    }
    event.trace_tag = (const char *)(mem + trace_tag_ptr);
    event.type.code = (halide_type_code_t)type_code;
    event.type.bits = (uint8_t)type_bits;
    event.type.lanes = (uint16_t)type_lanes;
    event.event = (halide_trace_event_code_t)trace_code;
    event.parent_id = parent_id;
    event.value_index = value_index;
    event.dimensions = dimensions;

    int result = 0;
    const bool has_custom_trace = ucon != nullptr && ucon->handlers.custom_trace != nullptr;
    if (has_custom_trace) {
        result = (*ucon->handlers.custom_trace)(ucon, &event);
    } else {
        debug(0) << "Dropping trace event due to lack of trace handler.\n";
    }

    args.GetReturnValue().Set(wrap_scalar(context, result));
}

// TODO: vector codegen can under-read allocated buffers (i.e. read before the
// start of the allocation); we need to deliberately allocate extra and return
// a pointer partway in to avoid out-of-bounds access failures.
// https://github.com/halide/Halide/issues/3738
//
// Number of extra bytes the malloc callback places in front of every
// allocation; the free callback subtracts the same amount.
constexpr size_t kExtraMallocSlop = 32;

// Host-side malloc for wasm code. Over-allocates by kExtraMallocSlop bytes and
// hands back an address that many bytes past the start of the real region;
// a failed allocation (0) is returned unchanged.
void wasm_jit_malloc_callback(const v8::FunctionCallbackInfo<v8::Value> &args) {
    Isolate *isolate = args.GetIsolate();
    HandleScope scope(isolate);
    Local<Context> context = isolate->GetCurrentContext();

    const int32_t requested = args[0]->Int32Value(context).ToChecked();
    const size_t padded_size = requested + kExtraMallocSlop;

    wasm32_ptr_t result = v8_WasmMemoryObject_malloc(context, padded_size);
    if (result) {
        result += kExtraMallocSlop;
    }
    args.GetReturnValue().Set(wrap_scalar(context, result));
}

// Host-side free for wasm code. Undoes the kExtraMallocSlop offset applied by
// the malloc callback before releasing the region; a null (0) address is
// passed through unchanged.
void wasm_jit_free_callback(const v8::FunctionCallbackInfo<v8::Value> &args) {
    Isolate *isolate = args.GetIsolate();
    HandleScope scope(isolate);
    Local<Context> context = isolate->GetCurrentContext();

    wasm32_ptr_t region = args[0]->Int32Value(context).ToChecked();
    if (region) {
        region -= kExtraMallocSlop;
    }
    v8_WasmMemoryObject_free(context, region);
}

// Host-side abort() for wasm code: terminates the host process via abort().
// The callback arguments are ignored.
void wasm_jit_abort_callback(const v8::FunctionCallbackInfo<v8::Value> &args) {
    abort();
}

// Host-side strlen for wasm code: args[0] is the wasm address of the string;
// the length is returned as a 32-bit integer.
void wasm_jit_strlen_callback(const v8::FunctionCallbackInfo<v8::Value> &args) {
    Isolate *isolate = args.GetIsolate();
    Local<Context> context = isolate->GetCurrentContext();
    HandleScope scope(isolate);

    const int32_t str_offset = args[0]->Int32Value(context).ToChecked();

    uint8_t *mem = get_wasm_memory_base(context);
    const char *host_str = (char *)mem + str_offset;
    int32_t length = strlen(host_str);

    args.GetReturnValue().Set(wrap_scalar(context, length));
}

// Placeholder for write(): not implemented by the wasm JIT, so any call
// reports an internal error.
void wasm_jit_write_callback(const v8::FunctionCallbackInfo<v8::Value> &args) {
    internal_error << "WebAssembly JIT does not yet support the write() call.";
}

// Host-side getenv for wasm code. args[0] is the wasm address of the variable
// name. If the variable is set in the host environment, its value is copied
// into freshly allocated wasm memory and that wasm address is returned;
// otherwise (or if the allocation fails) a null wasm pointer (0) is returned.
void wasm_jit_getenv_callback(const v8::FunctionCallbackInfo<v8::Value> &args) {
    Isolate *isolate = args.GetIsolate();
    Local<Context> context = isolate->GetCurrentContext();
    HandleScope scope(isolate);

    const int32_t s = args[0]->Int32Value(context).ToChecked();

    // The name is only needed for the lookup itself, so it is safe to read it
    // through a base pointer obtained before any allocation happens.
    const char *e = getenv((const char *)get_wasm_memory_base(context) + s);
    if (!e) {
        args.GetReturnValue().Set(wrap_scalar<wasm32_ptr_t>(context, 0));
        return;
    }

    // TODO: this string is leaked
    const size_t len = strlen(e) + 1;
    wasm32_ptr_t r = v8_WasmMemoryObject_malloc(context, len);
    if (r) {
        // Fetch the base *after* allocating: the allocation may grow the wasm
        // memory, in which case a previously obtained host base pointer could
        // be stale (NOTE(review): presumed from the "grow" hook this file
        // registers; confirm against v8_WasmMemoryObject_malloc).
        uint8_t *base = get_wasm_memory_base(context);
        memcpy(base + r, e, len);
    }
    // On allocation failure r is 0, i.e. the same null result as "not set",
    // rather than writing through wasm address 0.
    args.GetReturnValue().Set(wrap_scalar(context, r));
}

// Host-side memcpy for wasm code: (dst, src, n) are wasm addresses / a byte
// count. Returns dst, as memcpy does.
void wasm_jit_memcpy_callback(const v8::FunctionCallbackInfo<v8::Value> &args) {
    Isolate *isolate = args.GetIsolate();
    Local<Context> context = isolate->GetCurrentContext();
    HandleScope scope(isolate);

    const int32_t dst_offset = args[0]->Int32Value(context).ToChecked();
    const int32_t src_offset = args[1]->Int32Value(context).ToChecked();
    const int32_t byte_count = args[2]->Int32Value(context).ToChecked();

    uint8_t *mem = get_wasm_memory_base(context);
    uint8_t *host_dst = mem + dst_offset;
    const uint8_t *host_src = mem + src_offset;

    memcpy(host_dst, host_src, byte_count);

    args.GetReturnValue().Set(wrap_scalar(context, dst_offset));
}

// Placeholder for fopen(): not implemented by the wasm JIT, so any call
// reports an internal error.
void wasm_jit_fopen_callback(const v8::FunctionCallbackInfo<v8::Value> &args) {
    internal_error << "WebAssembly JIT does not yet support the fopen() call.";
}

// Placeholder for fileno(): not implemented by the wasm JIT, so any call
// reports an internal error.
void wasm_jit_fileno_callback(const v8::FunctionCallbackInfo<v8::Value> &args) {
    internal_error << "WebAssembly JIT does not yet support the fileno() call.";
}

// Placeholder for fclose(): not implemented by the wasm JIT, so any call
// reports an internal error.
void wasm_jit_fclose_callback(const v8::FunctionCallbackInfo<v8::Value> &args) {
    internal_error << "WebAssembly JIT does not yet support the fclose() call.";
}

// Placeholder for fwrite(): not implemented by the wasm JIT, so any call
// reports an internal error.
void wasm_jit_fwrite_callback(const v8::FunctionCallbackInfo<v8::Value> &args) {
    internal_error << "WebAssembly JIT does not yet support the fwrite() call.";
}

// Host-side memset for wasm code: (s, c, n) are a wasm address, a fill value
// and a byte count. Returns s, as memset does.
void wasm_jit_memset_callback(const v8::FunctionCallbackInfo<v8::Value> &args) {
    Isolate *isolate = args.GetIsolate();
    Local<Context> context = isolate->GetCurrentContext();
    HandleScope scope(isolate);

    const int32_t dst_offset = args[0]->Int32Value(context).ToChecked();
    const int32_t fill_value = args[1]->Int32Value(context).ToChecked();
    const int32_t byte_count = args[2]->Int32Value(context).ToChecked();

    uint8_t *host_dst = get_wasm_memory_base(context) + dst_offset;
    memset(host_dst, fill_value, byte_count);

    args.GetReturnValue().Set(wrap_scalar(context, dst_offset));
}

// Host-side memcmp for wasm code: (s1, s2, n) are two wasm addresses and a
// byte count. Returns the host memcmp() result.
void wasm_jit_memcmp_callback(const v8::FunctionCallbackInfo<v8::Value> &args) {
    Isolate *isolate = args.GetIsolate();
    Local<Context> context = isolate->GetCurrentContext();
    HandleScope scope(isolate);

    const int32_t lhs_offset = args[0]->Int32Value(context).ToChecked();
    const int32_t rhs_offset = args[1]->Int32Value(context).ToChecked();
    const int32_t byte_count = args[2]->Int32Value(context).ToChecked();

    uint8_t *mem = get_wasm_memory_base(context);
    const uint8_t *host_lhs = mem + lhs_offset;
    const uint8_t *host_rhs = mem + rhs_offset;

    int comparison = memcmp(host_lhs, host_rhs, byte_count);

    args.GetReturnValue().Set(wrap_scalar(context, comparison));
}

// Host-side __cxa_atexit for wasm code: deliberately a no-op, so nothing is
// registered and no return value is set.
void wasm_jit___cxa_atexit_callback(const v8::FunctionCallbackInfo<v8::Value> &args) {
    // nothing
}

// Host-side __extendhfsf2 (half -> float conversion) for wasm code:
// args[0] carries the 16-bit pattern of the half value; the widened float is
// returned.
void wasm_jit___extendhfsf2_callback(const v8::FunctionCallbackInfo<v8::Value> &args) {
    Isolate *isolate = args.GetIsolate();
    Local<Context> context = isolate->GetCurrentContext();
    HandleScope scope(isolate);

    const uint16_t half_bits = args[0]->NumberValue(context).ToChecked();
    const float16_t half_value = float16_t::make_from_bits(half_bits);
    const float widened = static_cast<float>(half_value);

    args.GetReturnValue().Set(wrap_scalar(context, widened));
}

// Host-side __truncsfhf2 (float -> half conversion) for wasm code:
// args[0] is the float value; the 16-bit pattern of the resulting half is
// returned.
void wasm_jit___truncsfhf2_callback(const v8::FunctionCallbackInfo<v8::Value> &args) {
    Isolate *isolate = args.GetIsolate();
    Local<Context> context = isolate->GetCurrentContext();
    HandleScope scope(isolate);

    const float wide_value = args[0]->NumberValue(context).ToChecked();
    const float16_t narrowed(wide_value);
    const uint16_t half_bits = narrowed.to_bits();

    args.GetReturnValue().Set(wrap_scalar(context, half_bits));
}

// Generic V8 callback that forwards a one-argument math function to the host
// implementation `some_func`, converting the JS number to T on the way in and
// wrapping the T result on the way out.
template<typename T, T some_func(T)>
void wasm_jit_posix_math_callback(const v8::FunctionCallbackInfo<v8::Value> &args) {
    Isolate *isolate = args.GetIsolate();
    Local<Context> context = isolate->GetCurrentContext();
    HandleScope scope(isolate);

    const T operand = args[0]->NumberValue(context).ToChecked();
    const T result = some_func(operand);

    auto return_slot = args.GetReturnValue();
    return_slot.Set(wrap_scalar(context, result));
}

// Generic V8 callback that forwards a two-argument math function to the host
// implementation `some_func`, converting both JS numbers to T on the way in
// and wrapping the T result on the way out.
template<typename T, T some_func(T, T)>
void wasm_jit_posix_math2_callback(const v8::FunctionCallbackInfo<v8::Value> &args) {
    Isolate *isolate = args.GetIsolate();
    Local<Context> context = isolate->GetCurrentContext();
    HandleScope scope(isolate);

    const T lhs = args[0]->NumberValue(context).ToChecked();
    const T rhs = args[1]->NumberValue(context).ToChecked();
    const T result = some_func(lhs, rhs);

    auto return_slot = args.GetReturnValue();
    return_slot.Set(wrap_scalar(context, result));
}

// Internal-field indices of the per-extern "wrapper data" object that is
// attached to each extern callback and read back in v8_extern_wrapper.
enum ExternWrapperFieldSlots {
    kTrampolineWrap,  // External holding the host trampoline function pointer
    kArgTypesWrap     // ArrayBuffer holding the ExternArgType array
};

// Describes the type of one extern-function argument (or its return value).
// Use a POD here so we can stuff it all into an ArrayBuffer to avoid having
// to worry about lifetime management
struct ExternArgType {
    halide_type_t type;  // scalar type; ignored when is_void is set
    bool is_void;        // true only for the return-type entry of a void function
    bool is_buffer;      // true when the argument is a halide_buffer_t *
};

// V8 callback shared by all user-supplied extern functions. The callback's
// Data() is the wrapper object built by add_extern_callbacks: it carries the
// host trampoline pointer and an ArrayBuffer of ExternArgType entries (the
// first entry describes the return type, the rest describe the arguments).
// Arguments are marshalled from wasm into host storage, the trampoline is
// invoked, and the scalar result plus all buffer contents are copied back.
void v8_extern_wrapper(const v8::FunctionCallbackInfo<v8::Value> &args) {
    Isolate *isolate = args.GetIsolate();
    HandleScope scope(isolate);
    Local<Context> context = isolate->GetCurrentContext();
    Local<Object> wrapper_data = args.Data()->ToObject(context).ToLocalChecked();

    Local<External> trampoline_wrap = Local<External>::Cast(wrapper_data->GetInternalField(kTrampolineWrap));
    Local<ArrayBuffer> arg_types_wrap = Local<ArrayBuffer>::Cast(wrapper_data->GetInternalField(kArgTypesWrap));

    using TrampolineFn = void (*)(void **);
    TrampolineFn trampoline = (TrampolineFn)trampoline_wrap->Value();

    // The array holds one leading entry for the return type, hence the "- 1".
    size_t arg_types_len = (arg_types_wrap->ByteLength() / sizeof(ExternArgType)) - 1;
    const ExternArgType *arg_types = (const ExternArgType *)arg_types_wrap->GetContents().Data();
    // Consume the return-type entry; arg_types now points at the first argument.
    const ExternArgType ret_type = *arg_types++;

    // There's wasted space here, but that's ok.
    std::vector<Halide::Runtime::Buffer<>> buffers(arg_types_len);
    std::vector<uint64_t> scalars(arg_types_len);
    std::vector<void *> trampoline_args(arg_types_len);

    // Marshal each argument: buffers are converted into host-side buffers,
    // scalars are stored into a 64-bit slot; the trampoline receives pointers.
    for (size_t i = 0; i < arg_types_len; ++i) {
        if (arg_types[i].is_buffer) {
            const wasm32_ptr_t buf_ptr = args[i]->Int32Value(context).ToChecked();
            wasmbuf_to_hostbuf(context, buf_ptr, buffers[i]);
            trampoline_args[i] = buffers[i].raw_buffer();
        } else {
            dynamic_type_dispatch<ExtractAndStoreScalar>(arg_types[i].type, context, args[i], &scalars[i]);
            trampoline_args[i] = &scalars[i];
        }
    }

    // The return value (if any) is always scalar.
    uint64_t ret_val = 0;
    const bool has_retval = !ret_type.is_void;
    internal_assert(!ret_type.is_buffer);
    if (has_retval) {
        // The slot for the result is appended after the arguments.
        trampoline_args.push_back(&ret_val);
    }
    (*trampoline)(trampoline_args.data());

    if (has_retval) {
        dynamic_type_dispatch<LoadAndReturnScalar>(ret_type.type, context, (void *)&ret_val, args.GetReturnValue());
    }

    // Propagate buffer data backwards. Note that for arbitrary extern functions,
    // we have no idea which buffers might be "input only", so we copy all data for all of them.
    for (size_t i = 0; i < arg_types_len; ++i) {
        if (arg_types[i].is_buffer) {
            const wasm32_ptr_t buf_ptr = args[i]->Int32Value(context).ToChecked();
            copy_hostbuf_to_existing_wasmbuf(context, buffers[i], buf_ptr);
        }
    }
}

// Returns true for extern symbols that must not be registered as generic
// extern callbacks ("halide_print" and "halide_error").
bool should_skip_extern_symbol(const std::string &name) {
    static const std::set<std::string> skipped_symbols = {
        "halide_print",
        "halide_error"};
    return skipped_symbols.find(name) != skipped_symbols.end();
}

void add_extern_callbacks(const Local<Context> &context,
                          const JITExternMap &jit_externs,
                          const JITModule &trampolines,
                          Local<Object> &imports_dict) {
    Isolate *isolate = context->GetIsolate();
    Local<ObjectTemplate> extern_callback_template = ObjectTemplate::New(isolate);
    extern_callback_template->SetInternalFieldCount(4);
    for (const auto &it : jit_externs) {
        const auto &name = it.first;
        if (should_skip_extern_symbol(name)) {
            continue;
        }

        const auto &jit_extern = it.second;

        const auto &trampoline_symbol = trampolines.exports().find(name + kTrampolineSuffix);
        internal_assert(trampoline_symbol != trampolines.exports().end());

        const auto &sig = jit_extern.extern_c_function().signature();
        size_t arg_count = sig.arg_types().size();
        Local<ArrayBuffer> arg_types_wrap = ArrayBuffer::New(isolate, sizeof(ExternArgType) * (arg_count + 1));
        ExternArgType *arg_types = (ExternArgType *)arg_types_wrap->GetContents().Data();
        if (sig.is_void_return()) {
            // Type specified here will be ignored
            *arg_types++ = ExternArgType{{halide_type_int, 0, 0}, true, false};
        } else {
            const Type &t = sig.ret_type();
            const bool is_buffer = (t == type_of<halide_buffer_t *>());
            user_assert(t.lanes() == 1) << "Halide Extern functions cannot return vector values.";
            user_assert(!is_buffer) << "Halide Extern functions cannot return halide_buffer_t.";
            // TODO: the assertion below can be removed once we are able to marshal int64 values across the barrier
            user_assert(!(t.is_handle() && !is_buffer)) << "Halide Extern functions cannot return arbitrary pointers as arguments.";
            // TODO: the assertion below can be removed once we are able to marshal int64 values across the barrier
            user_assert(!(t.is_int_or_uint() && t.bits() == 64)) << "Halide Extern functions cannot accept 64-bit values as arguments.";
            *arg_types++ = ExternArgType{t, false, false};
        }
        for (size_t i = 0; i < arg_count; ++i) {
            const Type &t = sig.arg_types()[i];
            const bool is_buffer = (t == type_of<halide_buffer_t *>());
            user_assert(t.lanes() == 1) << "Halide Extern functions cannot accept vector values as arguments.";
            // TODO: the assertion below can be removed once we are able to marshal int64 values across the barrier
            user_assert(!(t.is_handle() && !is_buffer)) << "Halide Extern functions cannot accept arbitrary pointers as arguments.";
            // TODO: the assertion below can be removed once we are able to marshal int64 values across the barrier
            user_assert(!(t.is_int_or_uint() && t.bits() == 64)) << "Halide Extern functions cannot accept 64-bit values as arguments.";
            *arg_types++ = ExternArgType{t, false, is_buffer};
        }

        Local<Object> wrapper_data = extern_callback_template->NewInstance(context).ToLocalChecked();
        Local<External> trampoline_wrap(External::New(isolate, const_cast<void *>(trampoline_symbol->second.address)));
        wrapper_data->SetInternalField(kTrampolineWrap, trampoline_wrap);
        wrapper_data->SetInternalField(kArgTypesWrap, arg_types_wrap);

        Local<v8::String> key = NewLocalString(isolate, name.c_str());
        Local<v8::Function> value = FunctionTemplate::New(isolate, v8_extern_wrapper, wrapper_data)
                                        ->GetFunction(context)
                                        .ToLocalChecked();

        (void)imports_dict->Set(context, key, value).ToChecked();
    }
}

// Compiles `module` to a linked WebAssembly binary and returns its bytes.
//
// The module is lowered to LLVM IR, linked with the wasm JIT runtime, emitted
// as an object file, and then linked with lld into a .wasm whose entry point
// is `fn_name`. Both the object file and the linked output go through
// temporary files. The whole function runs under a process-wide mutex.
// A failed link is reported as an internal error.
std::vector<char> compile_to_wasm(const Module &module, const std::string &fn_name) {
    static std::mutex link_lock;
    std::lock_guard<std::mutex> lock(link_lock);

    llvm::LLVMContext context;
    std::unique_ptr<llvm::Module> fn_module;

    // Default wasm stack size is ~64k, but schedules with lots of
    // alloca usage (heavily inlined, or tracing enabled) can blow thru
    // this, which crashes in amusing ways, so ask for extra stack space
    // for the alloca usage.
    size_t stack_size = 65536;
    {
        std::unique_ptr<CodeGen_WebAssembly> cg(new CodeGen_WebAssembly(module.target()));
        cg->set_context(context);
        fn_module = cg->compile(module);
        stack_size += cg->get_requested_alloca_total();
    }

    stack_size = align_up(stack_size);
    wdebug(0) << "Requesting stack size of " << stack_size << "\n";

    std::unique_ptr<llvm::Module> llvm_module =
        link_with_wasm_jit_runtime(&context, module.target(), std::move(fn_module));

    llvm::SmallVector<char, 4096> object;
    llvm::raw_svector_ostream object_stream(object);
    compile_llvm_module_to_object(*llvm_module, object_stream);

    // TODO: surely there's a better way that doesn't require spooling things
    // out to temp files
    TemporaryFile obj_file("", ".o");
    write_entire_file(obj_file.pathname(), object.data(), object.size());
#if WASM_DEBUG_LEVEL
    // Keep the temp file around for inspection when debugging.
    obj_file.detach();
    wdebug(0) << "Dumping obj_file to " << obj_file.pathname() << "\n";
#endif

    TemporaryFile wasm_output("", ".wasm");

    std::string lld_arg_strs[] = {
        "HalideJITLinker",
        // For debugging purposes:
        // "--verbose",
        // "-error-limit=0",
        // "--print-gc-sections",
        "--export=__data_end",
        "--export=__heap_base",
        "--allow-undefined",
        "-zstack-size=" + std::to_string(stack_size),
        obj_file.pathname(),
        "--entry=" + fn_name,
        "-o",
        wasm_output.pathname()};

    // lld wants an array of C strings; lld_arg_strs keeps the storage alive.
    constexpr int c = sizeof(lld_arg_strs) / sizeof(lld_arg_strs[0]);
    const char *lld_args[c];
    for (int i = 0; i < c; ++i) {
        lld_args[i] = lld_arg_strs[i].c_str();
    }

#if LLVM_VERSION >= 110
    if (!lld::wasm::link(lld_args, /*CanExitEarly*/ false, llvm::outs(), llvm::errs())) {
        internal_error << "lld::wasm::link failed\n";
    }
#elif LLVM_VERSION >= 100
    std::string lld_errs_string;
    llvm::raw_string_ostream lld_errs(lld_errs_string);

    // Route lld's diagnostics into lld_errs so that the failure message below
    // actually contains them (previously they went to llvm::errs() and the
    // reported text was always empty).
    if (!lld::wasm::link(lld_args, /*CanExitEarly*/ false, llvm::outs(), lld_errs)) {
        internal_error << "lld::wasm::link failed: (" << lld_errs.str() << ")\n";
    }
#else
    std::string lld_errs_string;
    llvm::raw_string_ostream lld_errs(lld_errs_string);

    if (!lld::wasm::link(lld_args, /*CanExitEarly*/ false, lld_errs)) {
        internal_error << "lld::wasm::link failed: (" << lld_errs.str() << ")\n";
    }
#endif

#if WASM_DEBUG_LEVEL
    // Keep the temp file around for inspection when debugging.
    wasm_output.detach();
    wdebug(0) << "Dumping linked wasm to " << wasm_output.pathname() << "\n";
#endif

    return read_entire_file(wasm_output.pathname());
}

}  // namespace
}  // namespace Internal
}  // namespace Halide

#endif  // WITH_V8

namespace Halide {
namespace Internal {

// Reference-counted state behind a WasmModule: the compiled function, the
// data needed to call it, and (when built WITH_V8) the V8 isolate and handles
// that host it.
struct WasmModuleContents {
    // Intrusive reference count, exposed via the ref_count<> specialization.
    mutable RefCount ref_count;

    // Target the module was compiled for.
    const Target target;
    // Arguments of the compiled function, in the order run() expects them.
    const std::vector<Argument> arguments;
    JITExternMap jit_externs;
    std::vector<JITModule> extern_deps;
    // Host-side trampolines used to call the externs in jit_externs.
    JITModule trampolines;
    // Allocator state for the wasm memory; a pointer to it is stored in the
    // V8 context's embedder data by the constructor.
    BDMalloc bdmalloc;

#ifdef WITH_V8
    v8::Isolate *isolate = nullptr;
    v8::ArrayBuffer::Allocator *array_buffer_allocator = nullptr;
    v8::Persistent<v8::Context> v8_context;
    v8::Persistent<v8::Function> v8_function;
#endif

    WasmModuleContents(
        const Module &module,
        const std::vector<Argument> &arguments,
        const std::string &fn_name,
        const JITExternMap &jit_externs,
        const std::vector<JITModule> &extern_deps);

    // Runs the compiled function; args[i] points at the value of arguments[i].
    int run(const void **args);

    ~WasmModuleContents();
};

// Compiles `module` to WebAssembly and prepares it for execution.
//
// The member initializers copy the inputs and build the host trampolines for
// the externs. When built WITH_V8, the body then: initializes V8 once per
// process, creates a dedicated isolate and context, compiles the wasm bytes,
// builds the "env" import dictionary (runtime glue, libc-ish glue, math
// functions and user externs), instantiates the module, and stashes the
// exported function plus memory-related objects for later use by run().
// Without WITH_V8 the body only emits a debug message.
WasmModuleContents::WasmModuleContents(
    const Module &module,
    const std::vector<Argument> &arguments,
    const std::string &fn_name,
    const JITExternMap &jit_externs,
    const std::vector<JITModule> &extern_deps)
    : target(module.target()),
      arguments(arguments),
      jit_externs(jit_externs),
      extern_deps(extern_deps),
      trampolines(JITModule::make_trampolines_module(get_host_target(), jit_externs, kTrampolineSuffix, extern_deps)) {

    wdebug(0) << "Compiling wasm function " << fn_name << "\n";

#ifdef WITH_V8
    // V8 itself is initialized exactly once per process.
    static std::once_flag init_v8_once;
    std::call_once(init_v8_once, []() {
        // Initialize V8.
        V8::InitializeICU();
        static std::unique_ptr<Platform> platform = platform::NewDefaultPlatform();
        V8::InitializePlatform(platform.get());
        V8::Initialize();
        std::vector<std::string> flags = {
            // TODO: these need to match the flags we set in CodeGen_WebAssembly::mattrs().
            // Note that we currently enable all features that *might* be used
            // (eg we enable simd even though we might not use it) as we may well end
            // using different Halide Targets across our lifespan.
            "--experimental-wasm-sat-f2i-conversions",  // +nontrapping-fptoint
            "--experimental_wasm_se",                   // +sign-ext
            "--experimental_wasm_simd",                 // +simd128
            // "--experimental_wasm_bulk_memory",

            // Sometimes useful for debugging purposes:
            // "--print_all_exceptions=true",
            // "--abort_on_uncaught_exception",
            // "--trace-ignition-codegen",
            // "--trace_wasm_decoder",
            // "--no-liftoff",
            // "--wasm-interpret-all",
            // "--trace-wasm-memory",
        };
        for (const auto &f : flags) {
            V8::SetFlagsFromString(f.c_str(), f.size());
        }
    });

    // Owned by this object; deleted in the destructor.
    array_buffer_allocator = v8::ArrayBuffer::Allocator::NewDefaultAllocator();

    Isolate::CreateParams isolate_params;
    isolate_params.snapshot_blob = nullptr;
    isolate_params.array_buffer_allocator = array_buffer_allocator;
    // Create a new Isolate and make it the current one.
    isolate = Isolate::New(isolate_params);

    Locker locker(isolate);
    Isolate::Scope isolate_scope(isolate);

    // Create a stack-allocated handle scope.
    HandleScope handle_scope(isolate);

    Local<ObjectTemplate> global = ObjectTemplate::New(isolate);
    Local<Context> context = Context::New(isolate, NULL, global);
    // Keep the context alive beyond this handle scope.
    v8_context.Reset(isolate, context);

    Context::Scope context_scope(context);

    TryCatch try_catch(isolate);
    try_catch.SetCaptureMessage(true);
    try_catch.SetVerbose(true);

    Local<v8::String> fn_name_str = NewLocalString(isolate, fn_name.c_str());

    std::vector<char> final_wasm = compile_to_wasm(module, fn_name);

    // No serialized module is supplied, only the raw wire bytes.
    MaybeLocal<WasmModuleObject> maybe_compiled = WasmModuleObject::DeserializeOrCompile(isolate,
                                                                                         /* serialized_module */ {nullptr, 0},
                                                                                         /* wire_bytes */ {(const uint8_t *)final_wasm.data(), final_wasm.size()});

    Local<WasmModuleObject> compiled;
    if (!maybe_compiled.ToLocal(&compiled)) {
        // Versions of V8 prior to 7.5 or so don't propagate the exception properly,
        // so don't attempt to print the exception info if it's not present.
        if (try_catch.HasCaught()) {
            String::Utf8Value error(isolate, try_catch.Exception());
            internal_error << "Error compiling wasm: " << *error << "\n";
        } else {
            internal_error << "Error compiling wasm: <unknown>\n";
        }
    }

    // Functions the wasm module imports from "env", keyed by symbol name.
    Local<Object> imports_dict = Object::New(isolate);

    const auto add_callback = [&](const char *name, FunctionCallback f) {
        // Skip any leading :: nonsense that we needed to add
        // to disambiguate (say) ::sin() from Halide::sin()
        while (*name == ':')
            name++;
        Local<v8::String> key = NewLocalString(isolate, name);
        Local<v8::Function> value = FunctionTemplate::New(isolate, f)->GetFunction(context).ToLocalChecked();
        (void)imports_dict->Set(context, key, value).ToChecked();
    };

// Registers wasm_jit_<x>_callback under the import name "<x>".
#define ADD_CALLBACK(x) add_callback(#x, wasm_jit_##x##_callback);

    // Halide Runtime glue
    ADD_CALLBACK(halide_error);
    ADD_CALLBACK(halide_print);
    ADD_CALLBACK(halide_trace_helper);

    // libc-ish glue
    ADD_CALLBACK(__cxa_atexit)
    ADD_CALLBACK(abort)
    ADD_CALLBACK(fclose)
    ADD_CALLBACK(fileno)
    ADD_CALLBACK(fopen)
    ADD_CALLBACK(free)
    ADD_CALLBACK(fwrite)
    ADD_CALLBACK(getenv)
    ADD_CALLBACK(malloc)
    ADD_CALLBACK(memcmp)
    ADD_CALLBACK(memcpy)
    ADD_CALLBACK(memset)
    ADD_CALLBACK(strlen)
    ADD_CALLBACK(write)
    ADD_CALLBACK(__extendhfsf2)
    ADD_CALLBACK(__truncsfhf2)

#undef ADD_CALLBACK

// Register host math function f (one or two arguments of type t) under its own name.
#define ADD_POSIX_MATH(t, f) add_callback(#f, wasm_jit_posix_math_callback<t, f>);
#define ADD_POSIX_MATH2(t, f) add_callback(#f, wasm_jit_posix_math2_callback<t, f>);

    // math glue
    ADD_POSIX_MATH(double, ::acos)
    ADD_POSIX_MATH(double, ::acosh)
    ADD_POSIX_MATH(double, ::asin)
    ADD_POSIX_MATH(double, ::asinh)
    ADD_POSIX_MATH(double, ::atan)
    ADD_POSIX_MATH(double, ::atanh)
    ADD_POSIX_MATH(double, ::cos)
    ADD_POSIX_MATH(double, ::cosh)
    ADD_POSIX_MATH(double, ::exp)
    ADD_POSIX_MATH(double, ::log)
    ADD_POSIX_MATH(double, ::round)
    ADD_POSIX_MATH(double, ::sin)
    ADD_POSIX_MATH(double, ::sinh)
    ADD_POSIX_MATH(double, ::tan)
    ADD_POSIX_MATH(double, ::tanh)

    ADD_POSIX_MATH(float, ::acosf)
    ADD_POSIX_MATH(float, ::acoshf)
    ADD_POSIX_MATH(float, ::asinf)
    ADD_POSIX_MATH(float, ::asinhf)
    ADD_POSIX_MATH(float, ::atanf)
    ADD_POSIX_MATH(float, ::atanhf)
    ADD_POSIX_MATH(float, ::cosf)
    ADD_POSIX_MATH(float, ::coshf)
    ADD_POSIX_MATH(float, ::expf)
    ADD_POSIX_MATH(float, ::logf)
    ADD_POSIX_MATH(float, ::roundf)
    ADD_POSIX_MATH(float, ::sinf)
    ADD_POSIX_MATH(float, ::sinhf)
    ADD_POSIX_MATH(float, ::tanf)
    ADD_POSIX_MATH(float, ::tanhf)

    ADD_POSIX_MATH2(float, ::atan2f)
    ADD_POSIX_MATH2(double, ::atan2)
    ADD_POSIX_MATH2(float, ::powf)
    ADD_POSIX_MATH2(double, ::pow)

#undef ADD_POSIX_MATH
#undef ADD_POSIX_MATH2

    // User-supplied externs are routed through v8_extern_wrapper.
    add_extern_callbacks(context, jit_externs, trampolines, imports_dict);

    Local<Object> imports = Object::New(isolate);
    (void)imports->Set(context, NewLocalString(isolate, "env"), imports_dict).ToChecked();

    Local<Value> instance_args[2] = {compiled, imports};

    // Equivalent to the JS expression:
    //   new WebAssembly.Instance(compiled, imports).exports
    Local<Object> exports = context->Global()
                                ->Get(context, NewLocalString(isolate, "WebAssembly"))
                                .ToLocalChecked()
                                .As<Object>()
                                ->Get(context, NewLocalString(isolate, "Instance"))
                                .ToLocalChecked()
                                .As<Object>()
                                ->CallAsConstructor(context, 2, instance_args)
                                .ToLocalChecked()
                                .As<Object>()
                                ->Get(context, NewLocalString(isolate, "exports"))
                                .ToLocalChecked()
                                .As<Object>();

    Local<Value> function_value = exports->Get(context, fn_name_str).ToLocalChecked();
    Local<v8::Function> function = Local<v8::Function>::Cast(function_value);
    internal_assert(!function.IsEmpty());
    internal_assert(!function->IsNullOrUndefined());
    // Keep the exported function alive for run().
    v8_function.Reset(isolate, function);

    // Stash objects in the context's embedder data so the callbacks (which
    // only receive a context) can find them later.
    context->SetEmbedderData(kWasmMemoryObject, exports->Get(context, NewLocalString(isolate, "memory")).ToLocalChecked().As<Object>());
    context->SetAlignedPointerInEmbedderData(kBDMallocPtr, &bdmalloc);
    context->SetEmbedderData(kHeapBase, exports->Get(context, NewLocalString(isolate, "__heap_base")).ToLocalChecked().As<Object>());
    context->SetEmbedderData(kString_buffer, NewLocalString(isolate, "buffer"));
    context->SetEmbedderData(kString_grow, NewLocalString(isolate, "grow"));

    internal_assert(!try_catch.HasCaught());
#endif
}

// Runs the compiled wasm function.
//
// `args[i]` points at the host value for `arguments[i]`. Buffer arguments are
// copied into wasm memory before the call; if the function returns 0 they are
// copied back into the host buffers afterwards. The "__user_context" argument
// is replaced by a magic marker, with the real pointer parked in the
// context's embedder data. Returns the function's integer result.
// When built without WITH_V8 this reports an internal error.
int WasmModuleContents::run(const void **args) {
#ifdef WITH_V8
    Locker locker(isolate);
    Isolate::Scope isolate_scope(isolate);

    // Create a stack-allocated handle scope.
    HandleScope handle_scope(isolate);

    Local<Context> context = Local<Context>::New(isolate, v8_context);
    // Enter the context in which the wasm function was instantiated.
    Context::Scope context_scope(context);

    TryCatch try_catch(isolate);
    try_catch.SetCaptureMessage(true);
    try_catch.SetVerbose(true);

    // wasm addresses of the copied buffers; stays 0 for non-buffer arguments.
    std::vector<wasm32_ptr_t> wbufs(arguments.size(), 0);

    std::vector<v8::Local<Value>> js_args;
    for (size_t i = 0; i < arguments.size(); i++) {
        const Argument &arg = arguments[i];
        const void *arg_ptr = args[i];
        if (arg.is_buffer()) {
            halide_buffer_t *buf = (halide_buffer_t *)const_cast<void *>(arg_ptr);
            internal_assert(buf);
            wasm32_ptr_t wbuf = hostbuf_to_wasmbuf(context, buf);
            wbufs[i] = wbuf;
            js_args.push_back(wrap_scalar(context, wbuf));
        } else {
            if (arg.name == "__user_context") {
                // The wasm side only sees a marker value; get_jit_user_context
                // maps it back to the pointer stored here.
                js_args.push_back(wrap_scalar(context, kMagicJitUserContextValue));
                JITUserContext *jit_user_context = check_jit_user_context(*(JITUserContext **)const_cast<void *>(arg_ptr));
                context->SetAlignedPointerInEmbedderData(kJitUserContext, jit_user_context);
            } else {
                js_args.push_back(wrap_scalar(context, arg.type, arg_ptr));
            }
        }
    }

    Local<v8::Function> function = Local<v8::Function>::New(isolate, v8_function);
    MaybeLocal<Value> result = function->Call(context, context->Global(), js_args.size(), js_args.data());

    // An empty result means the call did not produce a value; report what
    // the TryCatch captured.
    if (result.IsEmpty()) {
        String::Utf8Value error(isolate, try_catch.Exception());
        String::Utf8Value message(isolate, try_catch.Message()->GetSourceLine(context).ToLocalChecked());

        internal_error << "Error running wasm: " << *error << " | Line: " << *message << "\n";
    }

    int r = result.ToLocalChecked()->Int32Value(context).ToChecked();
    if (r == 0) {
        // Update any output buffers
        for (size_t i = 0; i < arguments.size(); i++) {
            const Argument &arg = arguments[i];
            const void *arg_ptr = args[i];
            if (arg.is_buffer()) {
                halide_buffer_t *buf = (halide_buffer_t *)const_cast<void *>(arg_ptr);
                copy_wasmbuf_to_existing_hostbuf(context, wbufs[i], buf);
            }
        }
    }

    // Release the wasm-side buffer copies (entries for non-buffer arguments are 0).
    for (wasm32_ptr_t p : wbufs) {
        v8_WasmMemoryObject_free(context, p);
    }

    // Don't do this: things allocated by Halide runtime might need to persist
    // between multiple invocations of the same function.
    // bdmalloc.reset();

    return r;
#endif

    // Only reachable when built without WITH_V8.
    internal_error;
    return -1;
}

// Tears down the V8 state created by the constructor: releases the persistent
// handles while the isolate is locked and entered, disposes the isolate, and
// finally deletes the array-buffer allocator.
WasmModuleContents::~WasmModuleContents() {
#ifdef WITH_V8
    if (isolate != nullptr) {
        // TODO: Do we have to do this explicitly, or does disposing the Isolate handle it?
        {
            v8::Locker locker(isolate);
            v8::Isolate::Scope isolate_scope(isolate);

            v8_context.Reset();
            v8_function.Reset();
        }

        // The Locker/Scope above have been destroyed before the isolate is disposed.
        isolate->Dispose();
    }
    delete array_buffer_allocator;
#endif
}

// Specialization that returns the reference count embedded in a
// WasmModuleContents.
template<>
RefCount &ref_count<WasmModuleContents>(const WasmModuleContents *p) noexcept {
    return p->ref_count;
}

// Specialization that destroys a WasmModuleContents by deleting it.
template<>
void destroy<WasmModuleContents>(const WasmModuleContents *p) {
    delete p;
}

/** Returns true only when a JavaScript engine (WITH_V8 or WITH_SPIDERMONKEY)
 * was configured at build time and the target's architecture is WebAssembly. */
/*static*/
bool WasmModule::can_jit_target(const Target &target) {
#if defined(WITH_V8) || defined(WITH_SPIDERMONKEY)
    return target.arch == Target::WebAssembly;
#else
    return false;
#endif
}

/** Compile the given module into a WasmModule that can be run later.
 * Reports a user error when no JavaScript engine (WITH_V8 or
 * WITH_SPIDERMONKEY) was configured at build time. All arguments are
 * forwarded to the WasmModuleContents constructor. */
/*static*/
WasmModule WasmModule::compile(
    const Module &module,
    const std::vector<Argument> &arguments,
    const std::string &fn_name,
    const JITExternMap &jit_externs,
    const std::vector<JITModule> &extern_deps) {
#if !defined(WITH_V8) && !defined(WITH_SPIDERMONKEY)
    user_error << "Cannot run JITted JavaScript without configuring a JavaScript engine.";
    return WasmModule();
#endif

    WasmModule wasm_module;
    wasm_module.contents = new WasmModuleContents(module, arguments, fn_name, jit_externs, extern_deps);
    return wasm_module;
}

/** Run previously compiled wasm code with a set of arguments.
 * Asserts that this module has contents, then forwards to
 * WasmModuleContents::run and returns its result. */
int WasmModule::run(const void **args) {
    internal_assert(contents.defined());
    return contents->run(args);
}

}  // namespace Internal
}  // namespace Halide
back to top