swh:1:snp:2c68c8bd649bf1bd2cf3bf7bd4f98d247b82b5dc
Raw File
Tip revision: d0cdc156d85f67bf1e20753e1a38b9f57d6c8582 authored by Andrew Adams on 27 October 2021, 00:51:39 UTC
Improve comments
Tip revision: d0cdc15
DeviceInterface.cpp
#include "DeviceInterface.h"
#include "IR.h"
#include "IROperator.h"
#include "JITModule.h"
#include "Target.h"
#include "runtime/HalideBuffer.h"

using namespace Halide;
using namespace Halide::Internal;

namespace Halide {

namespace {

template<typename fn_type>
bool lookup_runtime_routine(const std::string &name,
                            const Target &target,
                            fn_type &result) {
    std::vector<JITModule> runtime(
        JITSharedRuntime::get(nullptr, target.with_feature(Target::JIT)));

    for (const auto &module : runtime) {
        std::map<std::string, JITModule::Symbol>::const_iterator f =
            module.exports().find(name);
        if (f != module.exports().end()) {
            result = reinterpret_bits<fn_type>(f->second.address);
            return true;
        }
    }
    return false;
}

}  // namespace

bool host_supports_target_device(const Target &t) {
    const DeviceAPI d = t.get_required_device_api();
    if (d == DeviceAPI::None) {
        // If the target requires no DeviceAPI, then
        // the host trivially supports the target device.
        return true;
    }

    const struct halide_device_interface_t *i = get_device_interface_for_device_api(d, t);
    if (!i) {
        debug(1) << "host_supports_device_api: get_device_interface_for_device_api() failed for d=" << (int)d << " t=" << t << "\n";
        return false;
    }

    Halide::Runtime::Buffer<uint8_t> temp(8, 8, 3);
    temp.fill(0);
    temp.set_host_dirty();

    Halide::JITHandlers handlers;
    handlers.custom_error = [](JITUserContext *user_context, const char *msg) {
        debug(1) << "host_supports_device_api: saw error (" << msg << ")\n";
    };
    Halide::JITHandlers old_handlers = Halide::Internal::JITSharedRuntime::set_default_handlers(handlers);

    int result = temp.copy_to_device(i);

    Halide::Internal::JITSharedRuntime::set_default_handlers(old_handlers);

    if (result != 0) {
        debug(1) << "host_supports_device_api: copy_to_device() failed for with result=" << result << " for d=" << (int)d << " t=" << t << "\n";
        return false;
    }
    return true;
}

/** Return the runtime's halide_device_interface_t for device API \p d on
 * target \p t, or nullptr if it cannot be obtained.
 *
 * DeviceAPI::Default_GPU is first resolved via
 * get_default_device_api_for_target(t). The interface is then fetched by
 * calling the JIT runtime routine "halide_<name>_device_interface".
 *
 * \param d           Device API to look up (may be Default_GPU).
 * \param t           Target used for resolution, support checks and the
 *                    runtime symbol lookup.
 * \param error_site  When non-null, each failure is reported through
 *                    user_error, mentioning this string as the caller.
 *                    When null, failures silently return nullptr. */
const halide_device_interface_t *get_device_interface_for_device_api(DeviceAPI d,
                                                                     const Target &t,
                                                                     const char *error_site) {

    if (d == DeviceAPI::Default_GPU) {
        d = get_default_device_api_for_target(t);
        if (d == DeviceAPI::Host) {
            if (error_site) {
                user_error
                    << "get_device_interface_for_device_api called from "
                    << error_site
                    << " requested a default GPU but no GPU feature is specified in target ("
                    << t.to_string()
                    << ").\n";
            }
            return nullptr;
        }
    }

    // Runtime symbol infix for each device API that has a device interface.
    std::string name;
    switch (d) {
    case DeviceAPI::Metal:
        name = "metal";
        break;
    case DeviceAPI::OpenCL:
        name = "opencl";
        break;
    case DeviceAPI::CUDA:
        name = "cuda";
        break;
    case DeviceAPI::OpenGLCompute:
        name = "openglcompute";
        break;
    case DeviceAPI::Hexagon:
        name = "hexagon";
        break;
    case DeviceAPI::HexagonDma:
        name = "hexagon_dma";
        break;
    case DeviceAPI::D3D12Compute:
        name = "d3d12compute";
        break;
    default:
        if (error_site) {
            user_error
                << "get_device_interface_for_device_api called from "
                << error_site
                << " requested unknown DeviceAPI ("
                << static_cast<int>(d)
                << ").\n";
        }
        return nullptr;
    }

    if (!t.supports_device_api(d)) {
        if (error_site) {
            user_error
                << "get_device_interface_for_device_api called from "
                << error_site
                << " DeviceAPI ("
                << name
                << ") is not supported by target ("
                << t.to_string()
                << ").\n";
        }
        return nullptr;
    }

    // Initialized so that a failed lookup never leaves an indeterminate
    // pointer, and a found-but-null symbol is not called (that would be UB).
    using device_interface_getter = const halide_device_interface_t *(*)();
    device_interface_getter fn = nullptr;
    if (lookup_runtime_routine("halide_" + name + "_device_interface", t, fn) && fn != nullptr) {
        return fn();
    }

    if (error_site) {
        user_error
            << "get_device_interface_for_device_api called from "
            << error_site
            << " cannot find runtime or device interface symbol for "
            << name
            << ".\n";
    }
    return nullptr;
}

/** Choose the device API implied by the features of \p target.
 *
 * Features are tested in a fixed priority order and the first match wins:
 * Metal, OpenCL, CUDA, OpenGLCompute, HVX (only when the target architecture
 * is not itself Hexagon), HexagonDma, D3D12Compute. Returns DeviceAPI::Host
 * when none of them apply. */
DeviceAPI get_default_device_api_for_target(const Target &target) {
    if (target.has_feature(Target::Metal)) {
        return DeviceAPI::Metal;
    }
    if (target.has_feature(Target::OpenCL)) {
        return DeviceAPI::OpenCL;
    }
    if (target.has_feature(Target::CUDA)) {
        return DeviceAPI::CUDA;
    }
    if (target.has_feature(Target::OpenGLCompute)) {
        return DeviceAPI::OpenGLCompute;
    }

    // HVX selects the Hexagon device API only for non-Hexagon host architectures.
    const bool hvx_offload = target.arch != Target::Hexagon && target.has_feature(Target::HVX);
    if (hvx_offload) {
        return DeviceAPI::Hexagon;
    }

    if (target.has_feature(Target::HexagonDma)) {
        return DeviceAPI::HexagonDma;
    }
    if (target.has_feature(Target::D3D12Compute)) {
        return DeviceAPI::D3D12Compute;
    }
    return DeviceAPI::Host;
}

namespace Internal {

/** Build an IR expression evaluating to the halide_device_interface_t
 * pointer for \p device_api.
 *
 * - DeviceAPI::Host yields a zero (null) pointer constant.
 * - Supported device APIs yield an extern Call to the runtime symbol
 *   "halide_<api>_device_interface". For OpenCL, the image variant is used
 *   when \p memory_type is MemoryType::GPUTexture.
 * - DeviceAPI::Default_GPU yields a call to the placeholder symbol
 *   "halide_default_device_interface", to be resolved later.
 * - Any other value triggers internal_error. */
Expr make_device_interface_call(DeviceAPI device_api, MemoryType memory_type) {
    const Type interface_ptr_type = type_of<const halide_device_interface_t *>();

    if (device_api == DeviceAPI::Host) {
        return make_zero(interface_ptr_type);
    }

    // Name of the extern runtime symbol returning the interface for this API.
    const auto symbol_name = [&]() -> std::string {
        switch (device_api) {
        case DeviceAPI::CUDA:
            return "halide_cuda_device_interface";
        case DeviceAPI::OpenCL:
            return memory_type == MemoryType::GPUTexture ?
                       "halide_opencl_image_device_interface" :
                       "halide_opencl_device_interface";
        case DeviceAPI::Metal:
            return "halide_metal_device_interface";
        case DeviceAPI::OpenGLCompute:
            return "halide_openglcompute_device_interface";
        case DeviceAPI::Hexagon:
            return "halide_hexagon_device_interface";
        case DeviceAPI::HexagonDma:
            return "halide_hexagon_dma_device_interface";
        case DeviceAPI::D3D12Compute:
            return "halide_d3d12compute_device_interface";
        case DeviceAPI::Default_GPU:
            // Will be resolved later
            return "halide_default_device_interface";
        default:
            internal_error << "Bad DeviceAPI " << static_cast<int>(device_api) << "\n";
            return std::string();
        }
    };

    return Call::make(interface_ptr_type, symbol_name(), {}, Call::Extern);
}

}  // namespace Internal

}  // namespace Halide
back to top