Revision 16ddff55efc02d37b713eed569d435bdc4f5dfb7 authored by Andrew Adams on 31 August 2023, 22:21:03 UTC, committed by Andrew Adams on 31 August 2023, 22:21:03 UTC
1 parent ef9a7d8
callable_generator.cpp
#include "Halide.h"
#include <iostream>
#include <stdio.h>
using namespace Halide;
namespace {
void check(int r) {
assert(r == 0);
}
bool custom_malloc_called = false;
bool custom_free_called = false;
void *my_malloc(JITUserContext *user_context, size_t x) {
custom_malloc_called = true;
void *orig = malloc(x + 32);
void *ptr = (void *)((((size_t)orig + 32) >> 5) << 5);
((void **)ptr)[-1] = orig;
return ptr;
}
void my_free(JITUserContext *user_context, void *ptr) {
custom_free_called = true;
free(((void **)ptr)[-1]);
}
int call_counter = 0;
extern "C" HALIDE_EXPORT_SYMBOL float my_extern_func(int x, float y) {
call_counter++;
return x * y;
}
HalideExtern_2(float, my_extern_func, int, float);
} // namespace
int main(int argc, char **argv) {
const Target t = get_jit_target_from_environment();
const GeneratorContext context(t);
{
class TestGen1 : public Generator<TestGen1> {
public:
Input<Buffer<uint8_t, 2>> img_{"img"};
Input<int32_t> int_{"int"};
Input<float> float_{"float"};
Output<Buffer<uint8_t, 2>> out_{"out"};
void generate() {
Var x("x"), y("y");
out_(x, y) = img_(x, y) + cast<uint8_t>(int_ / float_);
}
};
Buffer<uint8_t> in1(10, 10);
Buffer<uint8_t> in2(10, 10);
for (int i = 0; i < 10; i++) {
for (int j = 0; j < 10; j++) {
in1(i, j) = i + j * 10;
in2(i, j) = i * 10 + j;
}
}
auto gen = TestGen1::create(context);
Callable c = gen->compile_to_callable();
Buffer<uint8_t> out1(10, 10);
check(c(in1, 42, 1.0f, out1));
Buffer<uint8_t> out2(10, 10);
check(c(in2, 22, 2.0f, out2));
Buffer<uint8_t> out3(10, 10);
check(c(in1, 12, 1.0f, out3));
Buffer<uint8_t> out4(10, 10);
check(c(in2, 16, 1.0f, out4));
for (int i = 0; i < 10; i++) {
for (int j = 0; j < 10; j++) {
assert(out1(i, j) == i + j * 10 + 42);
assert(out2(i, j) == i * 10 + j + 11);
assert(out3(i, j) == i + j * 10 + 12);
assert(out4(i, j) == i * 10 + j + 16);
}
}
// Test bounds inference
Buffer<uint8_t> in_bounds(nullptr, 1, 1);
Buffer<uint8_t> out_bounds(nullptr, 20, 20);
check(c(in_bounds, 42, 1.0f, out_bounds));
assert(in_bounds.defined());
assert(in_bounds.dim(0).extent() == 20);
assert(in_bounds.dim(1).extent() == 20);
assert(in1.dim(0).extent() == 10);
assert(in1.dim(1).extent() == 10);
}
// Override Halide's malloc and free (except under wasm),
// make sure that Callable freezes the values
if (t.arch != Target::WebAssembly) {
custom_malloc_called = false;
custom_free_called = false;
class TestGen2 : public Generator<TestGen2> {
public:
Output<Buffer<int, 1>> out_{"out"};
void generate() {
Var x("x");
Func f;
f(x) = x;
out_(x) = f(x);
f.compute_root();
}
};
JITHandlers my_jit_handlers;
my_jit_handlers.custom_malloc = my_malloc;
my_jit_handlers.custom_free = my_free;
auto gen = TestGen2::create(context);
Callable c = gen->compile_to_callable(&my_jit_handlers);
Buffer<int> im(100000);
check(c(im));
assert(custom_malloc_called);
assert(custom_free_called);
}
// Check that Param<void*> works with Callables
if (t.arch != Target::WebAssembly) {
class TestGen3 : public Generator<TestGen3> {
public:
GeneratorParam<bool> vectorize_{"vectorize", false};
Input<void *> handle_{"handle"};
Output<Buffer<uint64_t, 1>> out_{"out"};
void generate() {
Var x("x");
out_(x) = reinterpret<uint64_t>(handle_);
if (vectorize_) {
out_.vectorize(x, 4);
}
}
};
auto gen_1 = TestGen3::create(context);
gen_1->vectorize_.set(false);
auto gen_2 = TestGen3::create(context);
gen_2->vectorize_.set(true);
Callable c1 = gen_1->compile_to_callable();
Callable c2 = gen_2->compile_to_callable();
int foo = 0;
Buffer<uint64_t> out1(4);
// Create a dummy JITUserContext here just to test that
// passing one explicitly works correctly.
JITUserContext empty;
check(c1(&empty, &foo, out1));
Buffer<uint64_t> out2(4);
check(c2(&foo, out2));
uint64_t correct = (uint64_t)((uintptr_t)(&foo));
for (int x = 0; x < out1.width(); x++) {
if (out1(x) != correct) {
printf("out1(%d) = %llu instead of %llu\n",
x,
(long long unsigned)out1(x),
(long long unsigned)correct);
exit(1);
}
if (out2(x) != correct) {
printf("out2(%d) = %llu instead of %llu\n",
x,
(long long unsigned)out2(x),
(long long unsigned)correct);
exit(1);
}
}
}
// Check that JITExterns works with Callables
if (t.arch != Target::WebAssembly) {
call_counter = 0;
class TestGen4 : public Generator<TestGen4> {
public:
Output<Buffer<float, 2>> out_{"out"};
void generate() {
Func f;
f.define_extern("extern_func", {user_context_value()}, Float(32), 2);
Var x("x"), y("y");
out_(x, y) = f(x, y);
}
};
Var x, y;
Func monitor;
monitor(x, y) = my_extern_func(x, cast<float>(y));
const std::map<std::string, JITExtern> my_jit_externs = {
{"extern_func", JITExtern{monitor}}};
auto gen = TestGen4::create(context);
Callable c = gen->compile_to_callable(nullptr, &my_jit_externs);
Buffer<float> imf(32, 32);
check(c(imf));
// Check the result was what we expected
for (int i = 0; i < 32; i++) {
for (int j = 0; j < 32; j++) {
float correct = (float)(i * j);
float delta = imf(i, j) - correct;
if (delta < -0.001 || delta > 0.001) {
printf("imf[%d, %d] = %f instead of %f\n", i, j, imf(i, j), correct);
exit(1);
}
}
}
if (call_counter != 32 * 32) {
printf("In pipeline_set_jit_externs_func, my_func was called %d times instead of %d\n", call_counter, 32 * 32);
exit(1);
}
}
printf("Success!\n");
}
![swh spinner](/static/img/swh-spinner.gif)
Computing file changes ...