https://github.com/halide/Halide
Tip revision: fbb1d0a0efc37a5b553f2435d0a4f24f82a5e656 authored by Aly, Mohamed on 01 June 2018, 17:28:21 UTC
Changes to make it more concise
Changes to make it more concise
Tip revision: fbb1d0a
PyIROperator.cpp
#include "PyIROperator.h"
#include "PyTuple.h"
namespace Halide {
namespace PythonBindings {
namespace {
// Converts the Python positional arguments args[start_offset:] into a vector
// of Exprs for print()/print_when(). Each argument is first tried as a Python
// string (pushed into the vector as an Expr built from the std::string); if
// that conversion fails, the argument is converted as an ordinary Expr.
//
// Throws py::value_error if start_offset exceeds the number of arguments.
// An argument that is neither a string nor convertible to Expr makes the
// final cast<Expr>() throw, and that exception propagates to the caller.
//
// TODO: clever template usage could generalize this to list-of-types-to-try.
std::vector<Expr> args_to_vector_for_print(const py::args &args, size_t start_offset = 0) {
    if (args.size() < start_offset) {
        throw py::value_error("Not enough arguments");
    }
    std::vector<Expr> v;
    v.reserve(args.size() - start_offset);
    for (size_t i = start_offset; i < args.size(); ++i) {
        // No way to see if a cast will work: just have to try and fail.
        // Normally we don't want string to be convertible to Expr, but in
        // this unusual case we do. Only a failed conversion (py::cast_error)
        // triggers the Expr fallback; any other exception propagates instead
        // of being silently swallowed.
        try {
            v.push_back(args[i].cast<std::string>());
        } catch (const py::cast_error &) {
            v.push_back(args[i].cast<Expr>());
        }
    }
    return v;
}
} // namespace
// Registers Halide's free-function IR operators (math, select, casts,
// printing, random numbers, etc.) as functions on the Python module `m`.
//
// The variadic operators (max, min, select, tuple_select) take their operands
// as py::args and fold them from the right, e.g. max(a, b, c) builds
// max(a, max(b, c)) and select(c0, v0, c1, v1, f) builds
// select(c0, v0, select(c1, v1, f)).
//
// Some names (atan, random_*) are registered more than once on purpose:
// pybind11 treats repeated m.def() calls under the same name as overloads.
void define_operators(py::module &m) {
    // max(a, b, ...): at least 2 arguments, folded from the right.
    m.def("max", [](py::args args) -> Expr {
        if (args.size() < 2) {
            throw py::value_error("max() must have at least 2 arguments");
        }
        int pos = static_cast<int>(args.size()) - 1;
        Expr value = args[pos--].cast<Expr>();
        while (pos >= 0) {
            value = max(args[pos--].cast<Expr>(), value);
        }
        return value;
    });
    // min(a, b, ...): at least 2 arguments, folded from the right.
    m.def("min", [](py::args args) -> Expr {
        if (args.size() < 2) {
            throw py::value_error("min() must have at least 2 arguments");
        }
        int pos = static_cast<int>(args.size()) - 1;
        Expr value = args[pos--].cast<Expr>();
        while (pos >= 0) {
            value = min(args[pos--].cast<Expr>(), value);
        }
        return value;
    });
    m.def("clamp", &clamp);
    m.def("abs", &abs);
    m.def("absd", &absd);
    // select(c0, v0, c1, v1, ..., false_value): an odd number (>= 3) of
    // arguments, i.e. condition/value pairs followed by the final false value.
    m.def("select", [](py::args args) -> Expr {
        if (args.size() < 3) {
            throw py::value_error("select() must have at least 3 arguments");
        }
        if ((args.size() % 2) == 0) {
            throw py::value_error("select() must have an odd number of arguments");
        }
        int pos = static_cast<int>(args.size()) - 1;
        Expr false_value = args[pos--].cast<Expr>();
        while (pos > 0) {
            Expr true_value = args[pos--].cast<Expr>();
            Expr condition = args[pos--].cast<Expr>();
            false_value = select(condition, true_value, false_value);
        }
        return false_value;
    });
    // tuple_select(c0, t0, c1, t1, ..., false_value): like select(), but the
    // values are Tuples and the result is returned as a Python tuple.
    m.def("tuple_select", [](py::args args) -> py::tuple {
        _halide_user_assert(args.size() >= 3)
            << "tuple_select() must have at least 3 arguments";
        _halide_user_assert((args.size() % 2) != 0)
            << "tuple_select() must have an odd number of arguments";

        int pos = static_cast<int>(args.size()) - 1;
        Tuple false_value = args[pos--].cast<Tuple>();
        bool has_tuple_cond = false, has_expr_cond = false;
        while (pos > 0) {
            Tuple true_value = args[pos--].cast<Tuple>();
            // Note that 'condition' can be either Expr or Tuple, but must be
            // consistent across all condition arguments of one call.
            py::object py_cond = args[pos--];
            Expr expr_cond;
            // Placeholder built from the still-undefined expr_cond; it is only
            // used when expr_cond stays undefined, i.e. when the Tuple
            // conversion below succeeds.
            Tuple tuple_cond(expr_cond);
            // No way to see if a cast will work: try Tuple first and fall back
            // to Expr. NOTE(review): the catch-all is kept deliberately because
            // the Tuple conversion (PyTuple.h) may report failure with exception
            // types other than py::cast_error; narrow it once confirmed.
            try {
                tuple_cond = py_cond.cast<Tuple>();
                has_tuple_cond = true;
            } catch (...) {
                expr_cond = py_cond.cast<Expr>();
                has_expr_cond = true;
            }
            // Reject mixed condition kinds as soon as they are seen, rather
            // than after the whole select chain has been built.
            _halide_user_assert(!(has_tuple_cond && has_expr_cond))
                << "tuple_select() may not mix Expr and Tuple for the condition elements.";
            if (expr_cond.defined()) {
                false_value = tuple_select(expr_cond, true_value, false_value);
            } else {
                false_value = tuple_select(tuple_cond, true_value, false_value);
            }
        }
        return to_python_tuple(false_value);
    });
    m.def("sin", &sin);
    m.def("asin", &asin);
    m.def("cos", &cos);
    m.def("acos", &acos);
    m.def("tan", &tan);
    m.def("atan", &atan);
    // Intentional: also expose atan2 under the name "atan", so that the
    // two-argument call atan(y, x) resolves as an overload.
    m.def("atan", &atan2);
    m.def("atan2", &atan2);
    m.def("sinh", &sinh);
    m.def("asinh", &asinh);
    m.def("cosh", &cosh);
    m.def("acosh", &acosh);
    m.def("tanh", &tanh);
    m.def("atanh", &atanh);
    m.def("sqrt", &sqrt);
    m.def("hypot", &hypot);
    m.def("exp", &exp);
    m.def("log", &log);
    m.def("pow", &pow);
    m.def("erf", &erf);
    m.def("fast_log", &fast_log);
    m.def("fast_exp", &fast_exp);
    m.def("fast_pow", &fast_pow);
    m.def("fast_inverse", &fast_inverse);
    m.def("fast_inverse_sqrt", &fast_inverse_sqrt);
    m.def("floor", &floor);
    m.def("ceil", &ceil);
    m.def("round", &round);
    m.def("trunc", &trunc);
    m.def("fract", &fract);
    m.def("is_nan", &is_nan);
    // static_cast selects the (Type, Expr) overload; unlike a C-style cast it
    // refuses to compile (rather than reinterpret the pointer) on a mismatch.
    m.def("reinterpret", static_cast<Expr (*)(Type, Expr)>(&reinterpret));
    m.def("cast", static_cast<Expr (*)(Type, Expr)>(&cast));
    // print(*args): each argument may be a Python string or anything
    // convertible to Expr.
    m.def("print", [](py::args args) -> Expr {
        return print(args_to_vector_for_print(args));
    });
    // print_when(condition, *args): same argument handling as print().
    m.def("print_when", [](Expr condition, py::args args) -> Expr {
        return print_when(condition, args_to_vector_for_print(args));
    }, py::arg("condition"));
    // require(condition, value, *args): `value` followed by any extra
    // positional args form the vector passed to Halide's require().
    m.def("require", [](Expr condition, Expr value, py::args args) -> Expr {
        auto v = args_to_vector<Expr>(args);
        v.insert(v.begin(), value);
        return require(condition, v);
    }, py::arg("condition"), py::arg("value"));
    m.def("lerp", &lerp);
    m.def("popcount", &popcount);
    m.def("count_leading_zeros", &count_leading_zeros);
    m.def("count_trailing_zeros", &count_trailing_zeros);
    m.def("div_round_to_zero", &div_round_to_zero);
    m.def("mod_round_to_zero", &mod_round_to_zero);
    // Random-number operators, each registered with a zero-argument form and
    // a seeded form. They are bound through lambdas rather than function-
    // pointer casts: a C-style cast such as (Expr (*)()) &random_float still
    // compiles when the C++ function actually takes a (defaulted) seed
    // parameter, in which case it is a reinterpret_cast and calling through it
    // is undefined behavior. A lambda lets the compiler resolve the call
    // normally, including default arguments.
    m.def("random_float", []() -> Expr { return random_float(); });
    m.def("random_uint", []() -> Expr { return random_uint(); });
    m.def("random_int", []() -> Expr { return random_int(); });
    m.def("random_float", [](Expr seed) -> Expr { return random_float(seed); }, py::arg("seed"));
    m.def("random_uint", [](Expr seed) -> Expr { return random_uint(seed); }, py::arg("seed"));
    m.def("random_int", [](Expr seed) -> Expr { return random_int(seed); }, py::arg("seed"));
    m.def("undef", static_cast<Expr (*)(Type)>(&undef));
    // memoize_tag(result, *cache_key_values)
    m.def("memoize_tag", [](Expr result, py::args cache_key_values) -> Expr {
        return Internal::memoize_tag_helper(result, args_to_vector<Expr>(cache_key_values));
    }, py::arg("result"));
    m.def("likely", &likely);
    m.def("likely_if_innermost", &likely_if_innermost);
    m.def("saturating_cast", static_cast<Expr (*)(Type, Expr)>(&saturating_cast));
}
} // namespace PythonBindings
} // namespace Halide