https://github.com/halide/Halide
Raw File
Tip revision: 1006c4e5eb7eb391155e847e1d9839fe9df1568f authored by Andrew Adams on 31 May 2023, 21:40:28 UTC
Fix operator/ on ModulusRemainder
Tip revision: 1006c4e
CodeGen_RISCV.cpp
#include "CSE.h"
#include "CodeGen_Internal.h"
#include "CodeGen_Posix.h"
#include "ConciseCasts.h"
#include "Debug.h"
#include "IREquality.h"
#include "IRMatch.h"
#include "IROperator.h"
#include "IRPrinter.h"
#include "LLVM_Headers.h"
#include "Simplify.h"
#include "Substitute.h"
#include "Util.h"

namespace Halide {
namespace Internal {

using std::string;
using std::vector;

#if defined(WITH_RISCV)

namespace {

constexpr int max_intrinsic_args = 4;

struct IntrinsicArgPattern {
    enum TypePattern {
        Undefined,      // Invalid value for sentinel.
        Fixed,          // Argument is a fixed width vector.
        Scalable,       // Argument is a scalable vector.
        AllTypeWidths,  // Argument generalizes to all bit widths of type.
    } type_pattern;

    Type type;
    int relative_scale;

    IntrinsicArgPattern(const Type &type)
        : type_pattern(type.is_vector() ? Fixed : Scalable),
          type(type), relative_scale(1) {
    }
    IntrinsicArgPattern(halide_type_code_t code)
        : type_pattern(AllTypeWidths),
          type(code, 8, 1), relative_scale(1) {
    }
    IntrinsicArgPattern(halide_type_code_t code, int relative_scale)
        : type_pattern(AllTypeWidths),
          type(code, 8, 1), relative_scale(relative_scale) {
    }
    IntrinsicArgPattern()
        : type_pattern(Undefined),
          type(), relative_scale(0) {
    }
};

struct RISCVIntrinsic {
    const char *riscv_name;
    IntrinsicArgPattern ret_type;
    const char *name;
    IntrinsicArgPattern arg_types[max_intrinsic_args];
    int flags;
    enum {
        AddVLArg = 1 << 0,          // Add a constant full size vector length argument
        RoundDown = 1 << 1,         // Set vxrm rounding mode to down (rdn) before intrinsic.
        RoundUp = 1 << 2,           // Set vxrm rounding mode to up (rdu) before intrinsic.
        MangleReturnType = 1 << 3,  // Put return type mangling at start of type list.
        ReverseBinOp = 1 << 4,      // Switch first two arguments to handle asymmetric ops.
    };
};

/** Resolve a type pattern into a concrete Halide type.
 * Fixed patterns are returned unchanged. Every other pattern gets an element
 * width of base_bits * relative_scale * type_width_scale and enough lanes to
 * span vector_bits * relative_scale bits. */
Type concretize_fixed_or_scalable(const IntrinsicArgPattern &f_or_v, int type_width_scale, int vector_bits) {
    const IntrinsicArgPattern &pattern = f_or_v;
    if (pattern.type_pattern != IntrinsicArgPattern::Fixed) {
        const int element_bits = pattern.type.bits() * pattern.relative_scale * type_width_scale;
        const int total_bits = vector_bits * pattern.relative_scale;
        return Type(pattern.type.code(), element_bits, total_bits / element_bits);
    }
    return pattern.type;
}

// Produce LLVM IR intrisic type name mangling for Halide type, with vector codegen info provided.
std::string mangle_vector_argument_type(const Type &arg_type, bool scalable, int effective_vscale) {
    std::string result;
    if (arg_type.is_vector()) {
        int lanes = arg_type.lanes();
        if (!scalable) {
            result = "v" + std::to_string(lanes);
        } else {
            result = "nxv" + std::to_string(lanes / effective_vscale);
        }
    }
    if (arg_type.is_int() || arg_type.is_uint()) {
        result += "i";
    } else {
        result += "f";
    }
    result += std::to_string(arg_type.bits());

    return result;
}

/** A code generator that emits RISC-V code from a given Halide stmt. */
class CodeGen_RISCV final : public CodeGen_Posix {
public:
    /** Create a RISC-V code generator. Processor features can be
     * enabled using the appropriate flags in the target struct. */
    explicit CodeGen_RISCV(const Target &);

    /** Define an internal, always-inline LLVM function that calls the
     * llvm.riscv.* intrinsic described by \p intrin, with element widths
     * scaled by \p bit_width_scale. */
    llvm::Function *define_riscv_intrinsic_wrapper(const RISCVIntrinsic &intrin,
                                                   int bit_width_scale);

protected:
    using CodeGen_Posix::visit;

    void init_module() override;

    string mcpu_target() const override;
    string mcpu_tune() const override;
    string mattrs() const override;
    string mabi() const override;
    bool use_soft_float_abi() const override;
    int native_vector_bits() const override;
    int maximum_vector_bits() const override;
    int target_vscale() const override;
};

CodeGen_RISCV::CodeGen_RISCV(const Target &t)
    : CodeGen_Posix(t) {
    // Opt in to LLVM's VP (vector-predicated) intrinsics in the base code generator.
    this->use_llvm_vp_intrinsics = true;
}

string CodeGen_RISCV::mcpu_target() const {
    // No specific CPU is requested; an empty name is returned.
    return string{};
}

string CodeGen_RISCV::mcpu_tune() const {
    // Tune for the same CPU that is targeted (currently an empty name).
    const string tuned_cpu = mcpu_target();
    return tuned_cpu;
}

string CodeGen_RISCV::mattrs() const {
    // Base extensions, matching the default march "rv[32|64]imafdc":
    //   +m integer multiplication and division, +a atomics,
    //   +f single- and +d double-precision floating point,
    //   +c compressed instructions.
    string attrs = "+m,+a,+f,+d,+c";
    if (!target.has_feature(Target::RVV)) {
        return attrs;
    }

    // Vector extension, plus the Zvl<N>b extension when the target pins a
    // vector width (only emitted for LLVM 16 and later).
    attrs += ",+v";
#if LLVM_VERSION >= 160
    if (target.vector_bits != 0) {
        attrs += ",+zvl";
        attrs += std::to_string(target.vector_bits);
        attrs += "b";
    }
#endif
    return attrs;
}

string CodeGen_RISCV::mabi() const {
    // The integer ABI name follows the target word size. The "d" suffix
    // (hard double-precision float ABI) is dropped when SoftFloatABI is requested.
    string abi = (target.bits == 32) ? "ilp32" : "lp64";
    if (target.has_feature(Target::SoftFloatABI)) {
        return abi;
    }
    return abi + "d";
}

bool CodeGen_RISCV::use_soft_float_abi() const {
    // Soft float exactly when the target asks for it; mabi() omits "d" in that case.
    const bool soft_float = target.has_feature(Target::SoftFloatABI);
    return soft_float;
}

int CodeGen_RISCV::native_vector_bits() const {
    // A native width is reported only when RVV is enabled and the target
    // specifies its vector width; otherwise 0.
    const bool has_known_width = target.vector_bits != 0 &&
                                 target.has_feature(Target::RVV);
    return has_known_width ? target.vector_bits : 0;
}

int CodeGen_RISCV::maximum_vector_bits() const {
    // Allow vectors up to eight native vector widths.
    // NOTE(review): presumably mirrors the largest RVV register group (LMUL = 8); confirm.
    constexpr int max_native_multiple = 8;
    return max_native_multiple * native_vector_bits();
}

int CodeGen_RISCV::target_vscale() const {
    // No fixed vscale unless RVV is enabled and the vector width is known.
    if (target.vector_bits == 0 || !target.has_feature(Target::RVV)) {
        return 0;
    }

    // vscale is the number of 64-bit blocks in one vector register, so the
    // width must be a multiple of 64.
    internal_assert((target.vector_bits % 64) == 0);
    return target.vector_bits / 64;
}

const RISCVIntrinsic intrinsic_defs[] = {
    {"vaadd", Type::Int, "halving_add", {Type::Int, Type::Int}, RISCVIntrinsic::AddVLArg | RISCVIntrinsic::RoundDown},
    {"vaaddu", Type::UInt, "halving_add", {Type::UInt, Type::UInt}, RISCVIntrinsic::AddVLArg | RISCVIntrinsic::RoundDown},
    {"vaadd", Type::Int, "rounding_halving_add", {Type::Int, Type::Int}, RISCVIntrinsic::AddVLArg | RISCVIntrinsic::RoundUp},
    {"vaaddu", Type::UInt, "rounding_halving_add", {Type::UInt, Type::UInt}, RISCVIntrinsic::AddVLArg | RISCVIntrinsic::RoundUp},
    {"vwadd", {Type::Int, 2}, "widening_add", {Type::Int, Type::Int}, RISCVIntrinsic::AddVLArg | RISCVIntrinsic::MangleReturnType},
    {"vwaddu", {Type::UInt, 2}, "widening_add", {Type::UInt, Type::UInt}, RISCVIntrinsic::AddVLArg | RISCVIntrinsic::MangleReturnType},
    {"vwsub", {Type::Int, 2}, "widening_sub", {Type::Int, Type::Int}, RISCVIntrinsic::AddVLArg | RISCVIntrinsic::MangleReturnType},
    {"vwsubu", {Type::UInt, 2}, "widening_sub", {Type::UInt, Type::UInt}, RISCVIntrinsic::AddVLArg | RISCVIntrinsic::MangleReturnType},
    {"vwmul", {Type::Int, 2}, "widening_mul", {Type::Int, Type::Int}, RISCVIntrinsic::AddVLArg | RISCVIntrinsic::MangleReturnType},
    {"vwmulu", {Type::UInt, 2}, "widening_mul", {Type::UInt, Type::UInt}, RISCVIntrinsic::AddVLArg | RISCVIntrinsic::MangleReturnType},
    {"vwmulsu", {Type::Int, 2}, "widening_mul", {Type::Int, Type::UInt}, RISCVIntrinsic::AddVLArg | RISCVIntrinsic::MangleReturnType},
    {"vwmulsu", {Type::Int, 2}, "widening_mul", {Type::UInt, Type::Int}, RISCVIntrinsic::AddVLArg | RISCVIntrinsic::MangleReturnType | RISCVIntrinsic::ReverseBinOp},
};

void CodeGen_RISCV::init_module() {
    CodeGen_Posix::init_module();

    int effective_vscale = target_vscale();
    if (effective_vscale != 0) {
        for (const RISCVIntrinsic &intrin : intrinsic_defs) {
            std::vector<Type> arg_types;
            arg_types.reserve(max_intrinsic_args);
            if (intrin.ret_type.type_pattern == IntrinsicArgPattern::AllTypeWidths) {
                // Iterate over 8/16/32/64 bit integer type widths via log2 shift amount.
                // TODO: Will need to add floating point bit widths when an intrinsic is added.
                //     Not doing this now as it is there would be no coverage, it requires
                //     deciding whether to get floatness from an argument or return type,
                //     and it probably has to check target flags to figure out Float(16)
                //     and BFloat(16) availability.
                for (int log2_of_scale = 0; log2_of_scale < 4; log2_of_scale++) {
                    int bit_width_scale = 1 << log2_of_scale;

                    Type ret_type = concretize_fixed_or_scalable(intrin.ret_type, bit_width_scale,
                                                                 target.vector_bits);

                    if ((intrin.ret_type.relative_scale * bit_width_scale * intrin.ret_type.type.bits()) > 64) {
                        break;
                    }

                    for (const auto &arg_type : intrin.arg_types) {
                        if (arg_type.type_pattern == IntrinsicArgPattern::Undefined) {
                            break;
                        }
                        if ((arg_type.relative_scale * bit_width_scale * arg_type.type.bits()) > 64) {
                            break;
                        }
                        arg_types.push_back(concretize_fixed_or_scalable(arg_type, bit_width_scale,
                                                                         target.vector_bits));
                    }
                    llvm::Function *intrin_impl = define_riscv_intrinsic_wrapper(intrin, bit_width_scale);
                    declare_intrin_overload(intrin.name, ret_type, intrin_impl, arg_types);
                    arg_types.clear();
                }
            } else {
                llvm::Function *intrin_impl = define_riscv_intrinsic_wrapper(intrin, 1);
                Type ret_type = concretize_fixed_or_scalable(intrin.ret_type, 1,
                                                             target.vector_bits);
                for (const auto &arg_type : intrin.arg_types) {
                    if (arg_type.type_pattern == IntrinsicArgPattern::Undefined) {
                        break;
                    }
                    arg_types.push_back(concretize_fixed_or_scalable(arg_type, 1, target.vector_bits));
                }
                declare_intrin_overload(intrin.name, ret_type, intrin_impl, arg_types);
                arg_types.clear();
            }
        }
    }
}

/** Define an internal, always-inline LLVM function that calls the
 * llvm.riscv.<riscv_name> intrinsic described by `intrin` at the given element
 * width scale.
 *
 * The wrapper takes only the Halide-visible operands. It supplies the
 * intrinsic's leading passthru operand (undef) itself. When AddVLArg is set it
 * also supplies a trailing constant vector length equal to the full lane count
 * of the result. If the row asks for a fixed-point rounding mode, the wrapper
 * first writes vxrm via inline asm. */
llvm::Function *CodeGen_RISCV::define_riscv_intrinsic_wrapper(const RISCVIntrinsic &intrin,
                                                              int bit_width_scale) {
    const int effective_vscale = target_vscale();

    llvm::Type *xlen_type = target.bits == 32 ? i32_t : i64_t;

    // LLVM type for a concretized Halide type. Scalable vectors are built as
    // <vscale x N x T>, so their lane count is divided by the vscale.
    auto llvm_type_for = [&](const Type &t, bool scalable) -> llvm::Type * {
        if (!t.is_vector()) {
            return llvm_type_of(t);
        }
        int lanes = t.lanes();
        if (scalable) {
            lanes /= effective_vscale;
        }
        return llvm::VectorType::get(llvm_type_of(t.element_of()), lanes, scalable);
    };

    // Produce intrinsic name and type mangling.
    std::string mangled_name = "llvm.riscv.";
    mangled_name += intrin.riscv_name;
    Type ret_type = concretize_fixed_or_scalable(intrin.ret_type, bit_width_scale,
                                                 target.vector_bits);
    const bool ret_scalable = (intrin.ret_type.type_pattern != IntrinsicArgPattern::Fixed);
    if (intrin.flags & RISCVIntrinsic::MangleReturnType) {
        mangled_name += "." + mangle_vector_argument_type(ret_type, ret_scalable, effective_vscale);
    }
    llvm::Type *llvm_ret_type = llvm_type_for(ret_type, ret_scalable);

    // Intrinsic signature: passthru (same type as the result), the operands,
    // then (with AddVLArg) an XLEN-sized vector length.
    std::vector<llvm::Type *> llvm_arg_types;
    llvm_arg_types.push_back(llvm_ret_type);
    int num_operands = 0;
    for (const auto &arg_type_pattern : intrin.arg_types) {
        if (arg_type_pattern.type_pattern == IntrinsicArgPattern::Undefined) {
            break;
        }
        Type arg_type = concretize_fixed_or_scalable(arg_type_pattern, bit_width_scale, target.vector_bits);
        const bool scalable = (arg_type_pattern.type_pattern != IntrinsicArgPattern::Fixed);
        mangled_name += "." + mangle_vector_argument_type(arg_type, scalable, effective_vscale);
        llvm_arg_types.push_back(llvm_type_for(arg_type, scalable));
        num_operands++;
    }
    const bool reverse_operands = (intrin.flags & RISCVIntrinsic::ReverseBinOp) != 0;
    if (reverse_operands) {
        internal_assert(num_operands >= 2);
        std::swap(llvm_arg_types[1], llvm_arg_types[2]);
    }
    const bool add_vl = (intrin.flags & RISCVIntrinsic::AddVLArg) != 0;
    if (add_vl) {
        mangled_name += (target.bits == 64) ? ".i64" : ".i32";
        llvm_arg_types.push_back(xlen_type);
    }

    llvm::Function *inner =
        get_llvm_intrin(llvm_ret_type, mangled_name, llvm_arg_types);
    llvm::FunctionType *inner_ty = inner->getFunctionType();

    // The wrapper exposes only the operands. It drops the leading passthru slot
    // and any trailing vector length, both of which it supplies itself.
    std::vector<llvm::Type *> wrapper_arg_types(llvm_arg_types.begin() + 1,
                                                llvm_arg_types.begin() + 1 + num_operands);

    string wrapper_name = unique_name(std::string(intrin.name) + "_wrapper");
    llvm::FunctionType *wrapper_ty = llvm::FunctionType::get(
        inner_ty->getReturnType(), wrapper_arg_types, false);
    llvm::Function *wrapper =
        llvm::Function::Create(wrapper_ty, llvm::GlobalValue::InternalLinkage,
                               wrapper_name, module.get());
    llvm::BasicBlock *block =
        llvm::BasicBlock::Create(module->getContext(), "entry", wrapper);
    llvm::IRBuilderBase::InsertPoint here = builder->saveIP();
    builder->SetInsertPoint(block);

    // Set vector fixed-point rounding flag if needed for intrinsic.
    const bool round_down = intrin.flags & RISCVIntrinsic::RoundDown;
    const bool round_up = intrin.flags & RISCVIntrinsic::RoundUp;
    if (round_down || round_up) {
        internal_assert(!(round_down && round_up));
        // vxrm encodings: 2 = rdn (round down), 0 = rnu (round to nearest, ties up).
        llvm::Value *rounding_mode = llvm::ConstantInt::get(xlen_type, round_down ? 2 : 0);
        // See https://github.com/riscv/riscv-v-spec/releases/download/v1.0/riscv-v-spec-1.0.pdf page 15
        // for discussion of fixed-point rounding mode.
        // TODO: When LLVM finally fixes the instructions to take rounding modes,
        // this will have to change to passing the rounding mode to the intrinsic.
        // https://github.com/halide/Halide/issues/7123
        llvm::FunctionType *csrw_llvm_type = llvm::FunctionType::get(void_t, {xlen_type}, false);
        llvm::InlineAsm *inline_csrw = llvm::InlineAsm::get(csrw_llvm_type, "csrw vxrm,${0:z}", "rJ,~{memory}", true);
        builder->CreateCall(inline_csrw, {rounding_mode});
    }

    // Call the LLVM intrinsic. The operand list mirrors llvm_arg_types: an undef
    // passthru, the wrapper's operands (first two swapped for ReverseBinOp), and
    // the full lane count as the vector length when AddVLArg is set.
    std::vector<llvm::Value *> call_args;
    call_args.reserve(llvm_arg_types.size());
    call_args.push_back(llvm::UndefValue::get(llvm_ret_type));
    for (llvm::Argument &arg : wrapper->args()) {
        call_args.push_back(&arg);
    }
    if (reverse_operands) {
        std::swap(call_args[1], call_args[2]);
    }
    if (add_vl) {
        call_args.push_back(llvm::ConstantInt::get(xlen_type, ret_type.lanes()));
    }
    llvm::Value *ret = builder->CreateCall(inner, call_args);
    builder->CreateRet(ret);

    // Always inline these wrappers.
    wrapper->addFnAttr(llvm::Attribute::AlwaysInline);

    builder->restoreIP(here);

    // NOTE(review): the wrapper may contain a side-effecting inline-asm vxrm
    // write yet is marked as not accessing memory; confirm the optimizer cannot
    // separate that write from the intrinsic it guards.
    function_does_not_access_memory(wrapper);
    wrapper->addFnAttr(llvm::Attribute::NoUnwind);

    // verifyFunction returns true when the function is malformed.
    internal_assert(!llvm::verifyFunction(*wrapper))
        << "Malformed RISC-V intrinsic wrapper for " << mangled_name << "\n";
    return wrapper;
}

}  // anonymous namespace

std::unique_ptr<CodeGen_Posix> new_CodeGen_RISCV(const Target &target) {
    // Hand back the RISC-V code generator through its CodeGen_Posix base.
    std::unique_ptr<CodeGen_Posix> codegen = std::make_unique<CodeGen_RISCV>(target);
    return codegen;
}

#else  // WITH_RISCV

std::unique_ptr<CodeGen_Posix> new_CodeGen_RISCV(const Target & /*target*/) {
    // This build was configured without the RISC-V backend; report that to the user.
    user_error << "RISCV not enabled for this build of Halide.\n";
    return nullptr;
}

#endif  // WITH_RISCV

}  // namespace Internal
}  // namespace Halide
back to top