Revision 3657cf5f363fd64aeaf06432e62e3960800927b0 authored by Andrew Adams on 26 January 2024, 17:26:12 UTC, committed by GitHub on 26 January 2024, 17:26:12 UTC
* Fix bounds_of_nested_lanes

bounds_of_nested_lanes assumed that one layer of nested vectorization
could be removed at a time. When faced with the expression:

min(ramp(x8(a), x8(b), 5), x40(27))

It panicked, because on the left hand side it reduced the bounds to
x8(a) ... x8(a) + x8(b) * 4, and on the right hand side it reduced the
bounds to 27. It then attempted to take a min of mismatched types.

In general we can't assume that binary operators on nested vectors have
the same nesting structure on both sides, so I just rewrote it to reduce
directly to a scalar.

Fixes #8038
1 parent 4590a09
Raw File
SelectGPUAPI.cpp
#include "SelectGPUAPI.h"
#include "DeviceInterface.h"
#include "IRMutator.h"

namespace Halide {
namespace Internal {

namespace {

class SelectGPUAPI : public IRMutator {
    using IRMutator::visit;

    DeviceAPI default_api, parent_api;

    Expr visit(const Call *op) override {
        if (op->name == "halide_default_device_interface") {
            return make_device_interface_call(default_api);
        } else {
            return IRMutator::visit(op);
        }
    }

    Stmt visit(const For *op) override {
        DeviceAPI selected_api = op->device_api;
        if (op->device_api == DeviceAPI::Default_GPU) {
            selected_api = default_api;
        }

        DeviceAPI old_parent_api = parent_api;
        parent_api = selected_api;
        Stmt stmt = IRMutator::visit(op);
        parent_api = old_parent_api;

        op = stmt.as<For>();
        internal_assert(op);

        if (op->device_api != selected_api) {
            return For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, selected_api, op->body);
        }
        return stmt;
    }

public:
    SelectGPUAPI(const Target &t) {
        default_api = get_default_device_api_for_target(t);
        parent_api = DeviceAPI::Host;
    }
};

}  // namespace

/**
 * Run the SelectGPUAPI mutator, configured for target t, over statement s
 * and return the mutated statement.
 */
Stmt select_gpu_api(const Stmt &s, const Target &t) {
    SelectGPUAPI selector(t);
    return selector.mutate(s);
}

}  // namespace Internal
}  // namespace Halide
back to top