// This file is a part of Julia. License is MIT: https://julialang.org/license #define DEBUG_TYPE "combine_muladd" #undef DEBUG #include "llvm-version.h" #include #include #include #include #include #include #include #include #include #include "julia.h" #include "julia_assert.h" using namespace llvm; /** * Combine * ``` * %v0 = fmul ... %a, %b * %v = fadd fast ... %v0, %c * ``` * to * `%v = call fast @llvm.fmuladd.<...>(... %a, ... %b, ... %c)` * when `%v0` has no other use */ struct CombineMulAdd : public FunctionPass { static char ID; CombineMulAdd() : FunctionPass(ID) {} private: bool runOnFunction(Function &F) override; }; // Return true if this function shouldn't be called again on the other operand // This will always return false on LLVM 5.0+ static bool checkCombine(Module *m, Instruction *addOp, Value *maybeMul, Value *addend, bool negadd, bool negres) { auto mulOp = dyn_cast(maybeMul); if (!mulOp || mulOp->getOpcode() != Instruction::FMul) return false; if (!mulOp->hasOneUse()) return false; #if JL_LLVM_VERSION >= 50000 // On 5.0+ we only need to mark the mulOp as contract and the backend will do the work for us. auto fmf = mulOp->getFastMathFlags(); fmf.setAllowContract(true); mulOp->copyFastMathFlags(fmf); return false; #else IRBuilder<> builder(m->getContext()); builder.SetInsertPoint(addOp); auto mul1 = mulOp->getOperand(0); auto mul2 = mulOp->getOperand(1); Value *muladdf = Intrinsic::getDeclaration(m, Intrinsic::fmuladd, addOp->getType()); if (negadd) { auto newaddend = builder.CreateFNeg(addend); // Might be a const if (auto neginst = dyn_cast(newaddend)) neginst->setHasUnsafeAlgebra(true); addend = newaddend; } Instruction *newv = builder.CreateCall(muladdf, {mul1, mul2, addend}); newv->setHasUnsafeAlgebra(true); if (negres) { // Shouldn't be a constant newv = cast(builder.CreateFNeg(newv)); newv->setHasUnsafeAlgebra(true); } addOp->replaceAllUsesWith(newv); addOp->eraseFromParent(); mulOp->eraseFromParent(); return true; #endif } bool CombineMulAdd::runOnFunction(Function &F) { Module *m = F.getParent(); for (auto &BB: F) { for (auto it = BB.begin(); it != BB.end();) { auto &I = *it; it++; switch (I.getOpcode()) { case Instruction::FAdd: { #if JL_LLVM_VERSION >= 60000 if (!I.isFast()) #else if (!I.hasUnsafeAlgebra()) #endif continue; checkCombine(m, &I, I.getOperand(0), I.getOperand(1), false, false) || checkCombine(m, &I, I.getOperand(1), I.getOperand(0), false, false); break; } case Instruction::FSub: { #if JL_LLVM_VERSION >= 60000 if (!I.isFast()) #else if (!I.hasUnsafeAlgebra()) #endif continue; checkCombine(m, &I, I.getOperand(0), I.getOperand(1), true, false) || checkCombine(m, &I, I.getOperand(1), I.getOperand(0), true, true); break; } default: break; } } } return true; } char CombineMulAdd::ID = 0; static RegisterPass X("CombineMulAdd", "Combine mul and add to muladd", false /* Only looks at CFG */, false /* Analysis Pass */); Pass *createCombineMulAddPass() { return new CombineMulAdd(); }