https://github.com/halide/Halide
Raw File
Tip revision: 2cd9c914b6bc29672e369850a5cc750294585306 authored by Steven Johnson on 31 March 2020, 23:46:32 UTC
WIP
Tip revision: 2cd9c91
UnsafePromises.cpp
#include "UnsafePromises.h"

#include <sstream>

#include "IRMutator.h"
#include "IROperator.h"

namespace Halide {
namespace Internal {

namespace {

/** Lowers Call::unsafe_promise_clamped(value, min, max) intrinsics.
 *
 * When check is false, each such call is replaced by its (recursively
 * lowered) value argument. When check is true, the value is instead wrapped
 * in a Call::require intrinsic whose condition is
 * value >= min && value <= max; if that condition fails, the error
 * expression calls halide_error_requirement_failed with the promise's text.
 * All other Calls are handled by the default IRMutator traversal. */
class LowerUnsafePromises : public IRMutator {
    using IRMutator::visit;

    Expr visit(const Call *op) override {
        if (op->is_intrinsic(Call::unsafe_promise_clamped)) {
            // Mutate each argument exactly once. Mutating a condition that
            // mentions args[0] twice and then mutating args[0] again would
            // re-lower any promise nested inside args[0] three times per
            // nesting level, i.e. exponentially in the nesting depth.
            Expr value = mutate(op->args[0]);
            if (!check) {
                return value;
            }
            Expr min_val = mutate(op->args[1]);
            Expr max_val = mutate(op->args[2]);

            // The error message shows the promise as it appears in the input
            // IR, i.e. before any nested promises are lowered.
            Expr promised = op->args[0] >= op->args[1] && op->args[0] <= op->args[2];
            std::ostringstream promise_expr_text;
            promise_expr_text << promised;
            Expr cond_as_string = StringImm::make(promise_expr_text.str());
            Expr promise_broken_error =
                Call::make(Int(32),
                           "halide_error_requirement_failed",
                           {cond_as_string, StringImm::make("from unsafe_promise_clamped")},
                           Call::Extern);

            // Runtime condition, built from the already-lowered arguments.
            Expr is_clamped = value >= min_val && value <= max_val;
            return Call::make(op->args[0].type(),
                              Call::require,
                              {is_clamped, value, promise_broken_error},
                              Call::PureIntrinsic);
        }
        return IRMutator::visit(op);
    }

    // If true, promises are turned into runtime require() checks instead of
    // being stripped.
    const bool check;

public:
    /** @param check Whether to insert a runtime check for each promise. */
    explicit LowerUnsafePromises(bool check)
        : check(check) {
    }
};

/** Strips Call::promise_clamped intrinsics: each such call is replaced by its
 * (recursively mutated) first argument. Every other Call goes through the
 * default IRMutator traversal. */
class LowerSafePromises : public IRMutator {
    using IRMutator::visit;

    Expr visit(const Call *op) override {
        // Anything that is not a promise_clamped call takes the default path.
        if (!op->is_intrinsic(Call::promise_clamped)) {
            return IRMutator::visit(op);
        }
        // Drop the promise wrapper and keep only the promised value.
        const Expr &promised_value = op->args[0];
        return mutate(promised_value);
    }
};

}  // namespace

/** Lowers every unsafe_promise_clamped intrinsic in s. Runtime checks of the
 * promises are inserted only when t has the CheckUnsafePromises feature. */
Stmt lower_unsafe_promises(const Stmt &s, const Target &t) {
    const bool insert_checks = t.has_feature(Target::CheckUnsafePromises);
    LowerUnsafePromises lowerer(insert_checks);
    return lowerer.mutate(s);
}

/** Replaces every promise_clamped intrinsic in s with its first argument. */
Stmt lower_safe_promises(const Stmt &s) {
    LowerSafePromises stripper;
    return stripper.mutate(s);
}

}  // namespace Internal
}  // namespace Halide
back to top