Revision 16ddff55efc02d37b713eed569d435bdc4f5dfb7 authored by Andrew Adams on 31 August 2023, 22:21:03 UTC, committed by Andrew Adams on 31 August 2023, 22:21:03 UTC
1 parent ef9a7d8
deferred_loop_level.cpp
#include "Halide.h"
using namespace Halide;
using namespace Halide::Internal;
// IR visitor that walks a lowered statement and asserts that particular
// calls and stores appear inside the expected For loops.
//
// The visitor tracks the name of the innermost For loop currently being
// visited. A call to "sin_f32" or a store to a buffer whose name begins with
// "inner" must be inside a loop whose name starts with inner_loop_level; a
// call to "cos_f32" or a store to a buffer whose name begins with "outer"
// must be inside a loop whose name starts with outer_loop_level. A store to
// any other buffer is treated as a failure.
class CheckLoopLevels : public IRVisitor {
public:
    CheckLoopLevels(const std::string &inner_loop_level,
                    const std::string &outer_loop_level)
        : inner_loop_level(inner_loop_level),
          outer_loop_level(outer_loop_level) {
    }

private:
    using IRVisitor::visit;

    // Expected loop-name prefixes for the inner and outer stages.
    const std::string inner_loop_level, outer_loop_level;

    // Name of the innermost For loop enclosing the node being visited
    // (empty while outside of any loop).
    std::string inside_for_loop;

    void visit(const For *op) override {
        // Remember the enclosing loop so it can be restored once this loop's
        // body has been traversed.
        const std::string enclosing_loop = inside_for_loop;
        inside_for_loop = op->name;
        IRVisitor::visit(op);
        inside_for_loop = enclosing_loop;
    }

    void visit(const Call *op) override {
        IRVisitor::visit(op);

        const std::string &callee = op->name;
        if (callee == "sin_f32") {
            _halide_user_assert(starts_with(inside_for_loop, inner_loop_level));
            return;
        }
        if (callee == "cos_f32") {
            _halide_user_assert(starts_with(inside_for_loop, outer_loop_level));
        }
        // Any other call is irrelevant to this check.
    }

    void visit(const Store *op) override {
        IRVisitor::visit(op);

        // Classify the store by the first five characters of the buffer name.
        const std::string name_prefix = op->name.substr(0, 5);
        if (name_prefix == "inner") {
            _halide_user_assert(starts_with(inside_for_loop, inner_loop_level));
            return;
        }
        if (name_prefix == "outer") {
            _halide_user_assert(starts_with(inside_for_loop, outer_loop_level));
            return;
        }
        // Only stores to the inner/outer buffers are expected here.
        _halide_user_assert(0);
    }
};
// Pure variables shared by every Func definition and schedule in this file.
Var x("x");
Var y("y");
Var c("c");
struct Test {
Func inner, outer;
LoopLevel inner_compute_at, inner_store_at;
explicit Test(int i) {
// We use specific calls as proxies for verifying that compute_at
// happens where we expect: sin() for the inner function, cos()
// for the outer one; these are chosen mainly because they won't
// ever get generated incidentally by the lowering code as part of
// general code structure.
inner = Func("inner" + std::to_string(i));
inner(x, y, c) = sin(cast<float>(x + y + c));
inner.compute_at(inner_compute_at).store_at(inner_store_at);
outer = Func("outer" + std::to_string(i));
outer(x, y, c) = cos(cast<float>(inner(x, y, c)));
}
void check(const std::string &inner_loop_level,
const std::string &outer_loop_level) {
Buffer<float> result = outer.realize({1, 1, 1});
Module m = outer.compile_to_module({outer.infer_arguments()});
CheckLoopLevels c(inner_loop_level, outer_loop_level);
m.functions().front().body.accept(&c);
}
};
int main(int argc, char **argv) {
// Test that LoopLevels set after being specified still take effect.
{
Test t(1);
t.inner_compute_at.set(LoopLevel(t.outer, x));
t.inner_store_at.set(LoopLevel(t.outer, x));
t.check("outer1.s0.x", "outer1.s0.x");
}
// Same as before, but using inlined() for both inner LoopLevels.
{
Test t(2);
t.inner_compute_at.set(LoopLevel::inlined());
t.inner_store_at.set(LoopLevel::inlined());
t.check("outer2.s0.x", "outer2.s0.x");
}
// Same as before, but using root() for both inner LoopLevels.
{
Test t(3);
t.inner_compute_at.set(LoopLevel::root());
t.inner_store_at.set(LoopLevel::root());
t.check("inner3.s0.x", "outer3.s0.x");
}
// Same as before, but using different store_at and compute_at()
{
Test t(4);
t.inner_compute_at.set(LoopLevel(t.outer, y));
t.inner_store_at.set(LoopLevel::root());
t.check("inner4.s0.x", "outer4.s0.x");
}
// Same as before, but using inlined for store_at() [equivalent to omitting
// the store_at() call entirely] and non-inlined for compute_at
{
Test t(5);
t.inner_compute_at.set(LoopLevel(t.outer, y));
t.inner_store_at.set(LoopLevel::inlined());
t.check("inner5.s0.x", "outer5.s0.x");
}
printf("Success!\n");
return 0;
}
![swh spinner](/static/img/swh-spinner.gif)
Computing file changes ...