Raw File
ClampUnsafeAccesses.cpp
#include "ClampUnsafeAccesses.h"
#include "IREquality.h"
#include "IRMutator.h"
#include "IRPrinter.h"
#include "Simplify.h"

namespace Halide::Internal {

namespace {

struct ClampUnsafeAccesses : IRMutator {
    const std::map<std::string, Function> &env;
    FuncValueBounds &func_bounds;
    std::vector<std::string> realizes_inside_current_producer;

    ClampUnsafeAccesses(const std::map<std::string, Function> &env, FuncValueBounds &func_bounds)
        : env(env), func_bounds(func_bounds) {
    }

    bool is_realize_inside_current_producer(const std::string &n) const {
        return std::find(realizes_inside_current_producer.begin(),
                         realizes_inside_current_producer.end(), n) != realizes_inside_current_producer.end();
    }

protected:
    using IRMutator::visit;

    Stmt visit(const ProducerConsumer *op) override {
        if (op->is_producer) {
            ScopedValue<std::vector<std::string>> r(realizes_inside_current_producer, {});
            return IRMutator::visit(op);
        } else {
            return IRMutator::visit(op);
        }
    }

    Stmt visit(const Realize *op) override {
        realizes_inside_current_producer.push_back(op->name);
        auto new_stmt = IRMutator::visit(op);
        realizes_inside_current_producer.pop_back();
        return new_stmt;
    }

    Expr visit(const Let *let) override {
        return visit_let<Let, Expr>(let);
    }

    Stmt visit(const LetStmt *let) override {
        return visit_let<LetStmt, Stmt>(let);
    }

    Expr visit(const Variable *var) override {
        if (is_inside_indexing && let_var_inside_indexing.contains(var->name)) {
            let_var_inside_indexing.ref(var->name) = true;
        }
        return var;
    }

    Expr visit(const Call *call) override {
        // Note that if the call's realization is inside the current producer (i.e. it's caller),
        // the compute bounds of this call should be known to cover all loaded values, so we
        // should be able to safely skip the clamp injection (see #6297).
        if (call->call_type == Call::Halide && is_inside_indexing && !is_realize_inside_current_producer(call->name)) {
            auto bounds = func_bounds.at({call->name, call->value_index});
            if (bounds_smaller_than_type(bounds, call->type)) {
                // TODO(#6297): check that the clamped function's allocation bounds might be wider than its compute bounds
                auto [new_args, changed] = mutate_with_changes(call->args);
                Expr new_call = !changed ? call : Call::make(call->type, call->name, new_args, call->call_type, call->func, call->value_index, call->image, call->param);
                return Max::make(Min::make(new_call, std::move(bounds.max)), std::move(bounds.min));
            }
        }

        ScopedValue<bool> s(is_inside_indexing,
                            (is_inside_indexing ||
                             call->call_type == Call::Halide ||
                             call->call_type == Call::Image));
        return IRMutator::visit(call);
    }

private:
    template<typename L, typename Body>
    Body visit_let(const L *let) {
        ScopedBinding<bool> binding(let_var_inside_indexing, let->name, false);
        Body body = mutate(let->body);

        ScopedValue<bool> s(is_inside_indexing, is_inside_indexing || let_var_inside_indexing.get(let->name));
        Expr value = mutate(let->value);

        return L::make(let->name, std::move(value), std::move(body));
    }

    bool bounds_smaller_than_type(const Interval &bounds, Type type) {
        return bounds.is_bounded() && !(equal(bounds.min, type.min()) && equal(bounds.max, type.max()));
    }

    /**
     * A let-var is marked "true" if is used somewhere in an indexing expression.
     * visit_let will process its value binding with is_inside_indexing set when
     * this is the case.
     */
    Scope<bool> let_var_inside_indexing;
    bool is_inside_indexing = false;
};

}  // namespace

Stmt clamp_unsafe_accesses(const Stmt &s, const std::map<std::string, Function> &env, FuncValueBounds &func_bounds) {
    return ClampUnsafeAccesses(env, func_bounds).mutate(s);
}

}  // namespace Halide::Internal
back to top