Revision 16ddff55efc02d37b713eed569d435bdc4f5dfb7 authored by Andrew Adams on 31 August 2023, 22:21:03 UTC, committed by Andrew Adams on 31 August 2023, 22:21:03 UTC
1 parent ef9a7d8
memoize.cpp
#include "Halide.h"
#include "HalideRuntime.h"
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
using namespace Halide;
// External functions to track whether the cache is working.
int call_count = 0;
extern "C" HALIDE_EXPORT_SYMBOL int count_calls(halide_buffer_t *out) {
if (!out->is_bounds_query()) {
call_count++;
Halide::Runtime::Buffer<uint8_t>(*out).fill(42);
}
return 0;
}
int call_count_with_arg = 0;
extern "C" HALIDE_EXPORT_SYMBOL int count_calls_with_arg(uint8_t val, halide_buffer_t *out) {
if (!out->is_bounds_query()) {
call_count_with_arg++;
Halide::Runtime::Buffer<uint8_t>(*out).fill(val);
}
return 0;
}
int call_count_with_arg_parallel[8];
extern "C" HALIDE_EXPORT_SYMBOL int count_calls_with_arg_parallel(uint8_t val, halide_buffer_t *out) {
if (!out->is_bounds_query()) {
call_count_with_arg_parallel[out->dim[2].min]++;
Halide::Runtime::Buffer<uint8_t>(*out).fill(val);
}
return 0;
}
int call_count_staged[4];
extern "C" HALIDE_EXPORT_SYMBOL int count_calls_staged(int32_t stage, uint8_t val, halide_buffer_t *in, halide_buffer_t *out) {
if (in->is_bounds_query()) {
for (int i = 0; i < out->dimensions; i++) {
in->dim[i] = out->dim[i];
}
} else if (!out->is_bounds_query()) {
assert(stage < static_cast<int32_t>(sizeof(call_count_staged) / sizeof(call_count_staged[0])));
call_count_staged[stage]++;
Halide::Runtime::Buffer<uint8_t> out_buf(*out), in_buf(*in);
out_buf.for_each_value([&](uint8_t &out, uint8_t &in) { out = in + val; }, in_buf);
}
return 0;
}
extern "C" HALIDE_EXPORT_SYMBOL int computed_eviction_key(int a) {
return 2020 + a;
}
HalideExtern_1(int, computed_eviction_key, int);
void *(*default_malloc)(JITUserContext *, size_t);
void (*default_free)(JITUserContext *, void *);
// A flaky allocator that wraps the built-in runtime one.
void *flaky_malloc(JITUserContext *user_context, size_t x) {
if ((rand() % 4) == 0) {
return nullptr;
} else {
return default_malloc(user_context, x);
}
}
void simple_free(JITUserContext *user_context, void *ptr) {
return default_free(user_context, ptr);
}
bool error_occured = false;
void record_error(JITUserContext *user_context, const char *msg) {
error_occured = true;
}
int main(int argc, char **argv) {
{
call_count = 0;
Func count_calls;
count_calls.define_extern("count_calls", {}, UInt(8), 2);
Func f, f_memoized;
f_memoized() = count_calls(0, 0);
f() = f_memoized();
f_memoized.compute_root().memoize();
Buffer<uint8_t> result1 = f.realize();
Buffer<uint8_t> result2 = f.realize();
assert(result1(0) == 42);
assert(result2(0) == 42);
assert(call_count == 1);
}
{
call_count = 0;
Param<int32_t> coord;
Func count_calls;
count_calls.define_extern("count_calls", {}, UInt(8), 2);
Func f, g;
Var x, y;
f() = count_calls(coord, coord);
f.compute_root().memoize();
g(x, y) = f();
coord.set(0);
Buffer<uint8_t> out1 = g.realize({256, 256});
Buffer<uint8_t> out2 = g.realize({256, 256});
for (int32_t i = 0; i < 256; i++) {
for (int32_t j = 0; j < 256; j++) {
assert(out1(i, j) == 42);
assert(out2(i, j) == 42);
}
}
assert(call_count == 1);
coord.set(1);
Buffer<uint8_t> out3 = g.realize({256, 256});
Buffer<uint8_t> out4 = g.realize({256, 256});
for (int32_t i = 0; i < 256; i++) {
for (int32_t j = 0; j < 256; j++) {
assert(out3(i, j) == 42);
assert(out4(i, j) == 42);
}
}
assert(call_count == 2);
}
{
call_count = 0;
Func count_calls;
count_calls.define_extern("count_calls", {}, UInt(8), 2);
Func f;
Var x, y;
f(x, y) = count_calls(x, y) + count_calls(x, y);
count_calls.compute_root().memoize();
Buffer<uint8_t> out1 = f.realize({256, 256});
Buffer<uint8_t> out2 = f.realize({256, 256});
for (int32_t i = 0; i < 256; i++) {
for (int32_t j = 0; j < 256; j++) {
assert(out1(i, j) == (42 + 42));
assert(out2(i, j) == (42 + 42));
}
}
assert(call_count == 1);
}
call_count = 0;
{
Func count_calls_23;
count_calls_23.define_extern("count_calls_with_arg", {cast<uint8_t>(23)}, UInt(8), 2);
Func count_calls_42;
count_calls_42.define_extern("count_calls_with_arg", {cast<uint8_t>(42)}, UInt(8), 2);
Func f;
Var x, y;
f(x, y) = count_calls_23(x, y) + count_calls_42(x, y);
count_calls_23.compute_root().memoize();
count_calls_42.compute_root().memoize();
Buffer<uint8_t> out1 = f.realize({256, 256});
Buffer<uint8_t> out2 = f.realize({256, 256});
for (int32_t i = 0; i < 256; i++) {
for (int32_t j = 0; j < 256; j++) {
assert(out1(i, j) == (23 + 42));
assert(out2(i, j) == (23 + 42));
}
}
assert(call_count_with_arg == 2);
}
{
Param<uint8_t> val1;
Param<uint8_t> val2;
call_count_with_arg = 0;
Func count_calls_val1;
count_calls_val1.define_extern("count_calls_with_arg", {val1}, UInt(8), 2);
Func count_calls_val2;
count_calls_val2.define_extern("count_calls_with_arg", {val2}, UInt(8), 2);
Func f;
Var x, y;
f(x, y) = count_calls_val1(x, y) + count_calls_val2(x, y);
count_calls_val1.compute_root().memoize();
count_calls_val2.compute_root().memoize();
val1.set(23);
val2.set(42);
Buffer<uint8_t> out1 = f.realize({256, 256});
Buffer<uint8_t> out2 = f.realize({256, 256});
val1.set(42);
Buffer<uint8_t> out3 = f.realize({256, 256});
val1.set(23);
Buffer<uint8_t> out4 = f.realize({256, 256});
val1.set(42);
Buffer<uint8_t> out5 = f.realize({256, 256});
val2.set(57);
Buffer<uint8_t> out6 = f.realize({256, 256});
for (int32_t i = 0; i < 256; i++) {
for (int32_t j = 0; j < 256; j++) {
assert(out1(i, j) == (23 + 42));
assert(out2(i, j) == (23 + 42));
assert(out3(i, j) == (42 + 42));
assert(out4(i, j) == (23 + 42));
assert(out5(i, j) == (42 + 42));
assert(out6(i, j) == (42 + 57));
}
}
assert(call_count_with_arg == 4);
}
{
Param<float> val;
call_count_with_arg = 0;
Func count_calls;
count_calls.define_extern("count_calls_with_arg", {cast<uint8_t>(val)}, UInt(8), 2);
Func f;
Var x, y;
f(x, y) = count_calls(x, y) + count_calls(x, y);
count_calls.compute_root().memoize();
val.set(23.0f);
Buffer<uint8_t> out1 = f.realize({256, 256});
val.set(23.4f);
Buffer<uint8_t> out2 = f.realize({256, 256});
for (int32_t i = 0; i < 256; i++) {
for (int32_t j = 0; j < 256; j++) {
assert(out1(i, j) == (23 + 23));
assert(out2(i, j) == (23 + 23));
}
}
assert(call_count_with_arg == 2);
}
{
Param<float> val;
call_count_with_arg = 0;
Func count_calls;
count_calls.define_extern("count_calls_with_arg", {memoize_tag(cast<uint8_t>(val))}, UInt(8), 2);
Func f;
Var x, y;
f(x, y) = count_calls(x, y) + count_calls(x, y);
count_calls.compute_root().memoize();
val.set(23.0f);
Buffer<uint8_t> out1 = f.realize({256, 256});
val.set(23.4f);
Buffer<uint8_t> out2 = f.realize({256, 256});
for (int32_t i = 0; i < 256; i++) {
for (int32_t j = 0; j < 256; j++) {
assert(out1(i, j) == (23 + 23));
assert(out2(i, j) == (23 + 23));
}
}
assert(call_count_with_arg == 1);
}
{
// Case with bounds computed not equal to bounds realized.
Param<float> val;
Param<int32_t> index;
call_count_with_arg = 0;
Func count_calls;
count_calls.define_extern("count_calls_with_arg", {cast<uint8_t>(val)}, UInt(8), 2);
Func f, g, h;
Var x;
f(x) = count_calls(x, 0) + cast<uint8_t>(x);
g(x) = f(x);
h(x) = g(4) + g(index);
f.compute_root().memoize();
g.vectorize(x, 8).compute_at(h, x);
val.set(23.0f);
index.set(2);
Buffer<uint8_t> out1 = h.realize({1});
assert(out1(0) == (uint8_t)(2 * 23 + 4 + 2));
assert(call_count_with_arg == 3);
index.set(4);
out1 = h.realize({1});
assert(out1(0) == (uint8_t)(2 * 23 + 4 + 4));
assert(call_count_with_arg == 4);
}
{
// Test Tuple case
Param<float> val;
call_count_with_arg = 0;
Func count_calls;
count_calls.define_extern("count_calls_with_arg", {cast<uint8_t>(val)}, UInt(8), 2);
Func f;
Var x, y, xi, yi;
f(x, y) = Tuple(count_calls(x, y) + cast<uint8_t>(x), x);
count_calls.compute_root().memoize();
f.compute_root().memoize();
Func g;
g(x, y) = Tuple(f(x, y)[0] + f(x - 1, y)[0] + f(x + 1, y)[0], f(x, y)[1]);
val.set(23.0f);
Realization out = g.realize({128, 128});
Buffer<uint8_t> out0 = out[0];
Buffer<int32_t> out1 = out[1];
for (int32_t i = 0; i < 100; i++) {
for (int32_t j = 0; j < 100; j++) {
assert(out0(i, j) == (uint8_t)(3 * 23 + i + (i - 1) + (i + 1)));
assert(out1(i, j) == i);
}
}
out = g.realize({128, 128});
out0 = out[0];
out1 = out[1];
for (int32_t i = 0; i < 100; i++) {
for (int32_t j = 0; j < 100; j++) {
assert(out0(i, j) == (uint8_t)(3 * 23 + i + (i - 1) + (i + 1)));
assert(out1(i, j) == i);
}
}
assert(call_count_with_arg == 1);
}
{
// Test cache eviction
Param<float> val;
call_count_with_arg = 0;
Func count_calls;
count_calls.define_extern("count_calls_with_arg", {cast<uint8_t>(val)}, UInt(8), 2);
Func f;
Var x, y, xi, yi;
f(x, y) = count_calls(x, y) + cast<uint8_t>(x);
count_calls.compute_root().memoize();
Func g;
g(x, y) = f(x, y) + f(x - 1, y) + f(x + 1, y);
Internal::JITSharedRuntime::memoization_cache_set_size(1000000);
for (int v = 0; v < 1000; v++) {
int r = rand() % 256;
val.set((float)r);
Buffer<uint8_t> out1 = g.realize({128, 128});
for (int32_t i = 0; i < 100; i++) {
for (int32_t j = 0; j < 100; j++) {
assert(out1(i, j) == (uint8_t)(3 * r + i + (i - 1) + (i + 1)));
}
}
}
// TODO work out an assertion on call count here.
printf("Call count is %d.\n", call_count_with_arg);
// Return cache size to default.
Internal::JITSharedRuntime::memoization_cache_set_size(0);
}
{
// Test flushing entire cache with a single element larger than the cache
Param<float> val;
call_count_with_arg = 0;
Func count_calls;
count_calls.define_extern("count_calls_with_arg", {cast<uint8_t>(val)}, UInt(8), 2);
Func f;
Var x, y, xi, yi;
f(x, y) = count_calls(x, y) + cast<uint8_t>(x);
count_calls.compute_root().memoize();
Func g;
g(x, y) = f(x, y) + f(x - 1, y) + f(x + 1, y);
Internal::JITSharedRuntime::memoization_cache_set_size(1000000);
for (int v = 0; v < 1000; v++) {
int r = rand() % 256;
val.set((float)r);
Buffer<uint8_t> out1 = g.realize({128, 128});
for (int32_t i = 0; i < 100; i++) {
for (int32_t j = 0; j < 100; j++) {
assert(out1(i, j) == (uint8_t)(3 * r + i + (i - 1) + (i + 1)));
}
}
}
// TODO work out an assertion on call count here.
printf("Call count before oversize realize is %d.\n", call_count_with_arg);
call_count_with_arg = 0;
Buffer<uint8_t> big = g.realize({1024, 1024});
Buffer<uint8_t> big2 = g.realize({1024, 1024});
// TODO work out an assertion on call count here.
printf("Call count after oversize realize is %d.\n", call_count_with_arg);
call_count_with_arg = 0;
for (int v = 0; v < 1000; v++) {
int r = rand() % 256;
val.set((float)r);
Buffer<uint8_t> out1 = g.realize({128, 128});
for (int32_t i = 0; i < 100; i++) {
for (int32_t j = 0; j < 100; j++) {
assert(out1(i, j) == (uint8_t)(3 * r + i + (i - 1) + (i + 1)));
}
}
}
printf("Call count is %d.\n", call_count_with_arg);
// Return cache size to default.
Internal::JITSharedRuntime::memoization_cache_set_size(0);
}
{
// Test parallel cache access
Param<float> val;
Func count_calls;
count_calls.define_extern("count_calls_with_arg_parallel", {cast<uint8_t>(val)}, UInt(8), 3);
Func f;
Var x, y;
// Ensure that all calls map to the same cache key, but pass a thread ID
// through to avoid having to do locking or an atomic add
f(x, y) = count_calls(x, y % 4, memoize_tag(y / 16, 0)) + cast<uint8_t>(x);
Func g;
g(x, y) = f(x, y) + f(x - 1, y) + f(x + 1, y);
count_calls.compute_at(f, y).memoize();
f.compute_at(g, y).memoize();
g.parallel(y, 16);
val.set(23.0f);
Internal::JITSharedRuntime::memoization_cache_set_size(1000000);
Buffer<uint8_t> out = g.realize({128, 128});
for (int32_t i = 0; i < 128; i++) {
for (int32_t j = 0; j < 128; j++) {
assert(out(i, j) == (uint8_t)(3 * 23 + i + (i - 1) + (i + 1)));
}
}
// TODO work out an assertion on call counts here.
for (int i = 0; i < 8; i++) {
printf("Call count for thread %d is %d.\n", i, call_count_with_arg_parallel[i]);
}
// Return cache size to default.
Internal::JITSharedRuntime::memoization_cache_set_size(0);
}
{
// Test multiple argument memoize_tag. This can be unsafe but
// models cases where one uses a hash of image data as part of
// a tag to memoize an expensive computation.
ImageParam input(UInt(8), 1);
Param<int> key;
Func f, g;
RDom extent(input);
g() = memoize_tag(sum(input(extent)), key);
f() = g() + 42;
g.compute_root().memoize();
Buffer<uint8_t> in(10);
input.set(in);
in.fill(42);
key.set(0);
Buffer<uint8_t> result = f.realize();
assert(result() == (462 % 256));
// Change image data without channging tag
in(0) = 41;
result = f.realize();
// Result is likely stale. This is not strictly guaranteed due to e.g.
// cache size. Hence allow correct value to make test express the
// contract.
assert((result() == (462 % 256)) ||
(result() == (461 % 256)));
// Change tag, thus ensuring correct result.
key.set(1);
result = f.realize();
assert(result() == (461 % 256));
}
{
Param<float> val;
Func f;
Var x, y;
f(x, y) = cast<uint8_t>((x << 8) + y);
Func prev_func = f;
Func stage[4];
for (int i = 0; i < 4; i++) {
std::vector<ExternFuncArgument> args(3);
args[0] = cast<int32_t>(i);
args[1] = cast<int32_t>(val);
args[2] = prev_func;
stage[i].define_extern("count_calls_staged",
args,
UInt(8), 2);
prev_func = stage[i];
}
f.compute_root();
for (int i = 0; i < 3; i++) {
stage[i].compute_root();
}
stage[3].compute_root().memoize();
Func output;
output(_) = stage[3](_);
val.set(23.0f);
Buffer<uint8_t> result = output.realize({128, 128});
for (int32_t i = 0; i < 128; i++) {
for (int32_t j = 0; j < 128; j++) {
assert(result(i, j) == (uint8_t)((i << 8) + j + 4 * 23));
}
}
for (int i = 0; i < 4; i++) {
printf("Call count for stage %d is %d.\n", i, call_count_staged[i]);
}
result = output.realize({128, 128});
for (int32_t i = 0; i < 128; i++) {
for (int32_t j = 0; j < 128; j++) {
assert(result(i, j) == (uint8_t)((i << 8) + j + 4 * 23));
}
}
for (int i = 0; i < 4; i++) {
printf("Call count for stage %d is %d.\n", i, call_count_staged[i]);
}
}
if (get_jit_target_from_environment().arch == Target::WebAssembly) {
printf("[SKIP] WebAssembly JIT does not support custom allocators.\n");
return 0;
} else {
// Test out of memory handling.
// Get the runtime's malloc and free. We need to use the ones
// in the runtime to ensure the matching free is called when
// we release all the runtimes at the end.
JITUserContext ctx;
Internal::JITSharedRuntime::populate_jit_handlers(&ctx, JITHandlers{});
default_malloc = ctx.handlers.custom_malloc;
default_free = ctx.handlers.custom_free;
Param<float> val;
Func count_calls;
count_calls.define_extern("count_calls_with_arg", {cast<uint8_t>(val)}, UInt(8), 2);
Func f;
Var x, y, xi, yi;
f(x, y) = Tuple(count_calls(x, y) + cast<uint8_t>(x), x);
count_calls.compute_root().memoize();
f.compute_root().memoize();
Func g;
g(x, y) = Tuple(f(x, y)[0] + f(x - 1, y)[0] + f(x + 1, y)[0], f(x, y)[1]);
Pipeline pipe(g);
pipe.jit_handlers().custom_error = record_error;
pipe.jit_handlers().custom_malloc = flaky_malloc;
pipe.jit_handlers().custom_free = simple_free;
int total_errors = 0;
int completed = 0;
for (int trial = 0; trial < 100; trial++) {
call_count_with_arg = 0;
error_occured = false;
val.set(23.0f + trial);
Realization out = pipe.realize({16, 16});
if (error_occured) {
total_errors++;
} else {
Buffer<uint8_t> out0 = out[0];
Buffer<int32_t> out1 = out[1];
for (int32_t i = 0; i < 16; i++) {
for (int32_t j = 0; j < 16; j++) {
assert(out0(i, j) == (uint8_t)(3 * (23 + trial) + i + (i - 1) + (i + 1)));
assert(out1(i, j) == i);
}
}
error_occured = false;
out = pipe.realize({16, 16});
if (error_occured) {
total_errors++;
} else {
out0 = out[0];
out1 = out[1];
for (int32_t i = 0; i < 16; i++) {
for (int32_t j = 0; j < 16; j++) {
assert(out0(i, j) == (uint8_t)(3 * (23 + trial) + i + (i - 1) + (i + 1)));
assert(out1(i, j) == i);
}
}
assert(call_count_with_arg == 1);
completed++;
}
}
}
printf("In 100 attempts with flaky malloc, %d errors and %d full completions occured.\n",
total_errors, completed);
}
{
call_count = 0;
Func count_calls;
count_calls.define_extern("count_calls", {}, UInt(8), 2);
ImageParam input(UInt(8), 1);
Func f, f_memoized;
f_memoized() = count_calls(0, 0) + cast<uint8_t>(input.dim(0).extent());
f_memoized.compute_root().memoize();
f() = f_memoized();
Buffer<uint8_t> in_one(1);
input.set(in_one);
Buffer<uint8_t> result1 = f.realize();
Buffer<uint8_t> result2 = f.realize();
assert(result1(0) == 43);
assert(result2(0) == 43);
assert(call_count == 1);
Buffer<uint8_t> in_ten(10);
input.set(in_ten);
result1 = f.realize();
result2 = f.realize();
assert(result1(0) == 52);
assert(result2(0) == 52);
assert(call_count == 2);
}
// Test cache eviction.
{
call_count = 0;
Func count_calls;
count_calls.define_extern("count_calls", {}, UInt(8), 2);
Param<void *> p;
Func f, memoized_one, memoized_two, memoized_three;
memoized_one() = count_calls(0, 0);
memoized_two() = count_calls(1, 1);
memoized_three() = count_calls(3, 3);
memoized_one.compute_root().memoize(EvictionKey(1));
memoized_two.compute_root().memoize(EvictionKey(p));
// The called extern here would usually take user_context and extact a value
// from within, but JIT mostly subsumes user_context, so this is just an example.
memoized_three.compute_root().memoize(EvictionKey(computed_eviction_key(5)));
f() = memoized_one() + memoized_two() + memoized_three();
p.set((void *)&call_count);
Buffer<uint8_t> result1 = f.realize();
Buffer<uint8_t> result2 = f.realize();
assert(result1(0) == 126);
assert(result2(0) == 126);
assert(call_count == 3);
Internal::JITSharedRuntime::memoization_cache_evict(1);
result1 = f.realize();
assert(result1(0) == 126);
assert(call_count == 4);
Internal::JITSharedRuntime::memoization_cache_evict(1);
result1 = f.realize();
assert(result1(0) == 126);
assert(call_count == 5);
Internal::JITSharedRuntime::memoization_cache_evict(1);
Internal::JITSharedRuntime::memoization_cache_evict((uint64_t)(uintptr_t)&call_count);
result1 = f.realize();
assert(result1(0) == 126);
assert(call_count == 7);
Internal::JITSharedRuntime::memoization_cache_evict(2025);
result1 = f.realize();
assert(result1(0) == 126);
assert(call_count == 8);
}
Internal::JITSharedRuntime::release_all();
printf("Success!\n");
return 0;
}
![swh spinner](/static/img/swh-spinner.gif)
Computing file changes ...