// PyScheduleMethods.h
#ifndef HALIDE_PYTHON_BINDINGS_PYSCHEDULEMETHODS_H
#define HALIDE_PYTHON_BINDINGS_PYSCHEDULEMETHODS_H

#include "PyHalide.h"

namespace Halide {
namespace PythonBindings {

// Methods that are defined on both Func and Stage.
// Registers the scheduling methods that are defined on both Func and Stage
// onto a pybind11 class wrapper.
//
// PythonClass must expose a nested `type` naming the wrapped C++ class T
// (as py::class_<T> does); every `.def` below binds a member of T.
//
// Overload selection: each overloaded C++ method is pinned to a single
// signature with static_cast to a member-function-pointer type. Unlike a
// C-style cast, static_cast cannot silently degrade into reinterpret_cast if
// the C++ signature changes (e.g. a parameter becomes `const Expr &`), so a
// mismatch is a compile error rather than undefined behavior at call time.
//
// Registration order matters: pybind11 tries overloads of the same Python
// name in the order they were registered, so the sequence of .def() calls
// below is part of the observable behavior and must be preserved.
template <typename PythonClass>
HALIDE_NEVER_INLINE void add_schedule_methods(PythonClass &class_instance) {
    using T = typename PythonClass::type;

    class_instance

    .def("compute_with", static_cast<T &(T::*)(Stage, VarOrRVar, const std::vector<std::pair<VarOrRVar, LoopAlignStrategy>> &)>(&T::compute_with),
        py::arg("stage"), py::arg("var"), py::arg("align"))
    .def("compute_with", static_cast<T &(T::*)(Stage, VarOrRVar, LoopAlignStrategy)>(&T::compute_with),
        py::arg("stage"), py::arg("var"), py::arg("align") = LoopAlignStrategy::Auto)

    .def("unroll", static_cast<T &(T::*)(VarOrRVar)>(&T::unroll),
        py::arg("var"))
    .def("unroll", static_cast<T &(T::*)(VarOrRVar, Expr, TailStrategy)>(&T::unroll),
        py::arg("var"), py::arg("factor"), py::arg("tail") = TailStrategy::Auto)

    .def("split", static_cast<T &(T::*)(VarOrRVar, VarOrRVar, VarOrRVar, Expr, TailStrategy)>(&T::split),
        py::arg("old"), py::arg("outer"), py::arg("inner"), py::arg("factor"), py::arg("tail") = TailStrategy::Auto)

    .def("fuse", &T::fuse,
        py::arg("inner"), py::arg("outer"), py::arg("fused"))

    .def("serial", &T::serial,
        py::arg("var"))

    .def("tile", static_cast<T &(T::*)(VarOrRVar, VarOrRVar, VarOrRVar, VarOrRVar, VarOrRVar, VarOrRVar, Expr, Expr, TailStrategy)>(&T::tile),
        py::arg("x"), py::arg("y"), py::arg("xo"), py::arg("yo"), py::arg("xi"), py::arg("yi"), py::arg("xfactor"), py::arg("yfactor"), py::arg("tail") = TailStrategy::Auto)
    .def("tile", static_cast<T &(T::*)(VarOrRVar, VarOrRVar, VarOrRVar, VarOrRVar, Expr, Expr, TailStrategy)>(&T::tile),
        py::arg("x"), py::arg("y"), py::arg("xi"), py::arg("yi"), py::arg("xfactor"), py::arg("yfactor"), py::arg("tail") = TailStrategy::Auto)

    // The list form is registered first; the variadic form accepts
    // reorder(x, y, ...) by collecting the positional args into a vector.
    .def("reorder", static_cast<T &(T::*)(const std::vector<VarOrRVar> &)>(&T::reorder), py::arg("vars"))
    .def("reorder", [](T &t, py::args args) -> T & {
        return t.reorder(args_to_vector<VarOrRVar>(args));
    })

    .def("parallel", static_cast<T &(T::*)(VarOrRVar)>(&T::parallel),
        py::arg("var"))
    .def("parallel", static_cast<T &(T::*)(VarOrRVar, Expr, TailStrategy)>(&T::parallel),
        py::arg("var"), py::arg("task_size"), py::arg("tail") = TailStrategy::Auto)

    .def("vectorize", static_cast<T &(T::*)(VarOrRVar)>(&T::vectorize),
        py::arg("var"))
    .def("vectorize", static_cast<T &(T::*)(VarOrRVar, Expr, TailStrategy)>(&T::vectorize),
        py::arg("var"), py::arg("factor"), py::arg("tail") = TailStrategy::Auto)

    .def("gpu_blocks", static_cast<T &(T::*)(VarOrRVar, DeviceAPI)>(&T::gpu_blocks),
        py::arg("block_x"), py::arg("device_api") = DeviceAPI::Default_GPU)
    .def("gpu_blocks", static_cast<T &(T::*)(VarOrRVar, VarOrRVar, DeviceAPI)>(&T::gpu_blocks),
        py::arg("block_x"), py::arg("block_y"), py::arg("device_api") = DeviceAPI::Default_GPU)
    .def("gpu_blocks", static_cast<T &(T::*)(VarOrRVar, VarOrRVar, VarOrRVar, DeviceAPI)>(&T::gpu_blocks),
        py::arg("block_x"), py::arg("block_y"), py::arg("block_z"), py::arg("device_api") = DeviceAPI::Default_GPU)

    .def("gpu", static_cast<T &(T::*)(VarOrRVar, VarOrRVar, DeviceAPI)>(&T::gpu),
        py::arg("block_x"), py::arg("thread_x"), py::arg("device_api") = DeviceAPI::Default_GPU)
    .def("gpu", static_cast<T &(T::*)(VarOrRVar, VarOrRVar, VarOrRVar, VarOrRVar, DeviceAPI)>(&T::gpu),
        py::arg("block_x"), py::arg("block_y"), py::arg("thread_x"), py::arg("thread_y"), py::arg("device_api") = DeviceAPI::Default_GPU)
    .def("gpu", static_cast<T &(T::*)(VarOrRVar, VarOrRVar, VarOrRVar, VarOrRVar, VarOrRVar, VarOrRVar, DeviceAPI)>(&T::gpu),
        py::arg("block_x"), py::arg("block_y"), py::arg("block_z"), py::arg("thread_x"), py::arg("thread_y"), py::arg("thread_z"), py::arg("device_api") = DeviceAPI::Default_GPU)

    .def("gpu_threads", static_cast<T &(T::*)(VarOrRVar, DeviceAPI)>(&T::gpu_threads),
        py::arg("thread_x"), py::arg("device_api") = DeviceAPI::Default_GPU)
    .def("gpu_threads", static_cast<T &(T::*)(VarOrRVar, VarOrRVar, DeviceAPI)>(&T::gpu_threads),
        py::arg("thread_x"), py::arg("thread_y"), py::arg("device_api") = DeviceAPI::Default_GPU)
    .def("gpu_threads", static_cast<T &(T::*)(VarOrRVar, VarOrRVar, VarOrRVar, DeviceAPI)>(&T::gpu_threads),
        py::arg("thread_x"), py::arg("thread_y"), py::arg("thread_z"), py::arg("device_api") = DeviceAPI::Default_GPU)
    .def("gpu_single_thread", static_cast<T &(T::*)(DeviceAPI)>(&T::gpu_single_thread),
        py::arg("device_api") = DeviceAPI::Default_GPU)

    .def("gpu_tile", static_cast<T &(T::*)(VarOrRVar, VarOrRVar, VarOrRVar, Expr, TailStrategy, DeviceAPI)>(&T::gpu_tile),
        py::arg("x"), py::arg("bx"), py::arg("tx"), py::arg("x_size"),
        py::arg("tail") = TailStrategy::Auto, py::arg("device_api") = DeviceAPI::Default_GPU)

    .def("gpu_tile", static_cast<T &(T::*)(VarOrRVar, VarOrRVar, Expr, TailStrategy, DeviceAPI)>(&T::gpu_tile),
        py::arg("x"), py::arg("tx"), py::arg("x_size"),
        py::arg("tail") = TailStrategy::Auto, py::arg("device_api") = DeviceAPI::Default_GPU)

    .def("gpu_tile", static_cast<T &(T::*)(VarOrRVar, VarOrRVar, VarOrRVar, VarOrRVar, VarOrRVar, VarOrRVar, Expr, Expr, TailStrategy, DeviceAPI)>(&T::gpu_tile),
        py::arg("x"), py::arg("y"), py::arg("bx"), py::arg("by"), py::arg("tx"), py::arg("ty"), py::arg("x_size"), py::arg("y_size"),
        py::arg("tail") = TailStrategy::Auto, py::arg("device_api") = DeviceAPI::Default_GPU)

    .def("gpu_tile", static_cast<T &(T::*)(VarOrRVar, VarOrRVar, VarOrRVar, VarOrRVar, Expr, Expr, TailStrategy, DeviceAPI)>(&T::gpu_tile),
        py::arg("x"), py::arg("y"), py::arg("tx"), py::arg("ty"), py::arg("x_size"), py::arg("y_size"),
        py::arg("tail") = TailStrategy::Auto, py::arg("device_api") = DeviceAPI::Default_GPU)

    .def("gpu_tile", static_cast<T &(T::*)(VarOrRVar, VarOrRVar, VarOrRVar, VarOrRVar, VarOrRVar, VarOrRVar, VarOrRVar, VarOrRVar, VarOrRVar, Expr, Expr, Expr, TailStrategy, DeviceAPI)>(&T::gpu_tile),
        py::arg("x"), py::arg("y"), py::arg("z"), py::arg("bx"), py::arg("by"), py::arg("bz"), py::arg("tx"), py::arg("ty"), py::arg("tz"), py::arg("x_size"), py::arg("y_size"), py::arg("z_size"),
        py::arg("tail") = TailStrategy::Auto, py::arg("device_api") = DeviceAPI::Default_GPU)
    .def("gpu_tile", static_cast<T &(T::*)(VarOrRVar, VarOrRVar, VarOrRVar, VarOrRVar, VarOrRVar, VarOrRVar, Expr, Expr, Expr, TailStrategy, DeviceAPI)>(&T::gpu_tile),
        py::arg("x"), py::arg("y"), py::arg("z"), py::arg("tx"), py::arg("ty"), py::arg("tz"), py::arg("x_size"), py::arg("y_size"), py::arg("z_size"),
        py::arg("tail") = TailStrategy::Auto, py::arg("device_api") = DeviceAPI::Default_GPU)

    .def("rename", &T::rename,
        py::arg("old_name"), py::arg("new_name"))

    .def("specialize", &T::specialize,
        py::arg("condition"))
    .def("specialize_fail", &T::specialize_fail,
        py::arg("message"))

    .def("allow_race_conditions", &T::allow_race_conditions)
    .def("hexagon", &T::hexagon, py::arg("x") = Var::outermost())

    .def("prefetch", static_cast<T &(T::*)(const Func &, VarOrRVar, Expr, PrefetchBoundStrategy)>(&T::prefetch),
        py::arg("func"), py::arg("var"), py::arg("offset") = 1, py::arg("strategy") = PrefetchBoundStrategy::GuardWithIf)
    .def("prefetch", [](T &t, const ImageParam &image, VarOrRVar var, Expr offset, PrefetchBoundStrategy strategy) -> T & {
        // The C++ overload for parameters is a template and cannot be taken by
        // address, so it is wrapped in a lambda; only ImageParam is exposed for now.
        return t.prefetch(image, var, offset, strategy);
    }, py::arg("image"), py::arg("var"), py::arg("offset") = 1, py::arg("strategy") = PrefetchBoundStrategy::GuardWithIf)

    .def("source_location", &T::source_location)
    ;
}

}  // namespace PythonBindings
}  // namespace Halide

#endif  // HALIDE_PYTHON_BINDINGS_PYSCHEDULEMETHODS_H