// This file is a part of Julia. License is MIT: https://julialang.org/license #include "llvm-version.h" #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "codegen_shared.h" #include "julia.h" #define DEBUG_TYPE "propagate_julia_addrspaces" using namespace llvm; /* This pass performs propagation of addrspace information that is legal from the frontend definition, but illegal by general IR semantics. In particular, this includes: - Changing the address space of a load/store if the base pointer is in an untracked address space - Commuting GEPs and addrspace casts This is most useful for removing superfluous casts that can inhibit LLVM optimizations. */ struct PropagateJuliaAddrspaces : public FunctionPass, public InstVisitor { static char ID; DenseMap LiftingMap; SmallPtrSet Visited; std::vector ToDelete; std::vector> ToInsert; PropagateJuliaAddrspaces() : FunctionPass(ID) {}; public: bool runOnFunction(Function &F) override; Value *LiftPointer(Value *V, Type *LocTy = nullptr, Instruction *InsertPt=nullptr); void visitMemop(Instruction &I, Type *T, unsigned OpIndex); void visitLoadInst(LoadInst &LI); void visitStoreInst(StoreInst &SI); void visitAtomicCmpXchgInst(AtomicCmpXchgInst &SI); void visitAtomicRMWInst(AtomicRMWInst &SI); void visitMemSetInst(MemSetInst &MI); void visitMemTransferInst(MemTransferInst &MTI); private: void PoisonValues(std::vector &Worklist); }; bool PropagateJuliaAddrspaces::runOnFunction(Function &F) { visit(F); for (auto it : ToInsert) it.first->insertBefore(it.second); for (Instruction *I : ToDelete) I->eraseFromParent(); ToInsert.clear(); ToDelete.clear(); LiftingMap.clear(); Visited.clear(); return true; } static unsigned getValueAddrSpace(Value *V) { return cast(V->getType())->getAddressSpace(); } static bool isSpecialAS(unsigned AS) { return AddressSpace::FirstSpecial <= AS && AS <= AddressSpace::LastSpecial; } void PropagateJuliaAddrspaces::PoisonValues(std::vector &Worklist) { while (!Worklist.empty()) { Value *CurrentV = Worklist.back(); Worklist.pop_back(); for (Value *User : CurrentV->users()) { if (Visited.count(User)) continue; Visited.insert(CurrentV); Worklist.push_back(User); } } } Value *PropagateJuliaAddrspaces::LiftPointer(Value *V, Type *LocTy, Instruction *InsertPt) { SmallVector Stack; std::vector Worklist; std::set LocalVisited; Worklist.push_back(V); // Follow pointer casts back, see if we're based on a pointer in // an untracked address space, in which case we're allowed to drop // intermediate addrspace casts. while (!Worklist.empty()) { Value *CurrentV = Worklist.back(); Worklist.pop_back(); if (LocalVisited.count(CurrentV)) { continue; } while (true) { if (auto *BCI = dyn_cast(CurrentV)) CurrentV = BCI->getOperand(0); else if (auto *ACI = dyn_cast(CurrentV)) { CurrentV = ACI->getOperand(0); if (!isSpecialAS(getValueAddrSpace(ACI))) break; } else if (auto *GEP = dyn_cast(CurrentV)) { if (LiftingMap.count(GEP)) { CurrentV = LiftingMap[GEP]; break; } else if (Visited.count(GEP)) { return nullptr; } Stack.push_back(GEP); LocalVisited.insert(GEP); CurrentV = GEP->getOperand(0); } else if (auto *Phi = dyn_cast(CurrentV)) { if (LiftingMap.count(Phi)) { break; } for (Value *Incoming : Phi->incoming_values()) { Worklist.push_back(Incoming); } Stack.push_back(Phi); LocalVisited.insert(Phi); break; } else if (auto *Select = dyn_cast(CurrentV)) { if (LiftingMap.count(Select)) { break; } else if (Visited.count(Select)) { return nullptr; } // Push one of the branches onto the worklist, continue with the other one // directly Worklist.push_back(Select->getOperand(2)); Stack.push_back(Select); LocalVisited.insert(Select); CurrentV = Select->getOperand(1); } else if (isa(CurrentV)) { // It's always legal to lift null pointers into any address space break; } else { // Ok, we've reached a leaf - check if it is eligible for lifting if (!CurrentV->getType()->isPointerTy() || isSpecialAS(getValueAddrSpace(CurrentV))) { // If not, poison all (recursive) users of this value, to prevent // looking at them again in future iterations. Worklist.clear(); Worklist.push_back(CurrentV); Visited.insert(CurrentV); PoisonValues(Worklist); return nullptr; } break; } } } // Go through and insert lifted versions of all instructions on the list. std::vector ToRevisit; for (Value *V : Stack) { if (LiftingMap.count(V)) continue; if (isa(V) || isa(V) || isa(V)) { Instruction *InstV = cast(V); Instruction *NewV = InstV->clone(); ToInsert.push_back(std::make_pair(NewV, InstV)); Type *NewRetTy = cast(InstV->getType())->getElementType()->getPointerTo(0); NewV->mutateType(NewRetTy); LiftingMap[InstV] = NewV; ToRevisit.push_back(NewV); } } auto CollapseCastsAndLift = [&](Value *CurrentV, Instruction *InsertPt) -> Value * { PointerType *TargetType = cast(CurrentV->getType())->getElementType()->getPointerTo(0); while (!LiftingMap.count(CurrentV)) { if (isa(CurrentV)) CurrentV = cast(CurrentV)->getOperand(0); else if (isa(CurrentV)) CurrentV = cast(CurrentV)->getOperand(0); else break; } if (isa(CurrentV)) { return ConstantPointerNull::get(TargetType); } if (LiftingMap.count(CurrentV)) CurrentV = LiftingMap[CurrentV]; if (CurrentV->getType() != TargetType) { auto *BCI = new BitCastInst(CurrentV, TargetType); ToInsert.push_back(std::make_pair(BCI, InsertPt)); CurrentV = BCI; } return CurrentV; }; // Now go through and update the operands for (Value *V : ToRevisit) { if (GetElementPtrInst *NewGEP = dyn_cast(V)) { NewGEP->setOperand(GetElementPtrInst::getPointerOperandIndex(), CollapseCastsAndLift(NewGEP->getOperand(GetElementPtrInst::getPointerOperandIndex()), NewGEP)); } else if (PHINode *NewPhi = dyn_cast(V)) { for (size_t i = 0; i < NewPhi->getNumIncomingValues(); ++i) { NewPhi->setIncomingValue(i, CollapseCastsAndLift(NewPhi->getIncomingValue(i), NewPhi->getIncomingBlock(i)->getTerminator())); } } else if (SelectInst *NewSelect = dyn_cast(V)) { NewSelect->setOperand(1, CollapseCastsAndLift(NewSelect->getOperand(1), NewSelect)); NewSelect->setOperand(2, CollapseCastsAndLift(NewSelect->getOperand(2), NewSelect)); } else { assert(false && "Shouldn't have reached here"); } } return CollapseCastsAndLift(V, InsertPt); } void PropagateJuliaAddrspaces::visitMemop(Instruction &I, Type *T, unsigned OpIndex) { Value *Original = I.getOperand(OpIndex); unsigned AS = Original->getType()->getPointerAddressSpace(); if (!isSpecialAS(AS)) return; Value *Replacement = LiftPointer(Original, T, &I); if (!Replacement) return; I.setOperand(OpIndex, Replacement); } void PropagateJuliaAddrspaces::visitLoadInst(LoadInst &LI) { visitMemop(LI, LI.getType(), LoadInst::getPointerOperandIndex()); } void PropagateJuliaAddrspaces::visitStoreInst(StoreInst &SI) { visitMemop(SI, SI.getValueOperand()->getType(), StoreInst::getPointerOperandIndex()); } void PropagateJuliaAddrspaces::visitAtomicCmpXchgInst(AtomicCmpXchgInst &SI) { visitMemop(SI, SI.getNewValOperand()->getType(), AtomicCmpXchgInst::getPointerOperandIndex()); } void PropagateJuliaAddrspaces::visitAtomicRMWInst(AtomicRMWInst &SI) { visitMemop(SI, SI.getType(), AtomicRMWInst::getPointerOperandIndex()); } void PropagateJuliaAddrspaces::visitMemSetInst(MemSetInst &MI) { unsigned AS = MI.getDestAddressSpace(); if (!isSpecialAS(AS)) return; Value *Replacement = LiftPointer(MI.getRawDest()); if (!Replacement) return; Function *TheFn = Intrinsic::getDeclaration(MI.getModule(), Intrinsic::memset, {Replacement->getType(), MI.getOperand(1)->getType()}); MI.setCalledFunction(TheFn); MI.setArgOperand(0, Replacement); } void PropagateJuliaAddrspaces::visitMemTransferInst(MemTransferInst &MTI) { unsigned DestAS = MTI.getDestAddressSpace(); unsigned SrcAS = MTI.getSourceAddressSpace(); if (!isSpecialAS(DestAS) && !isSpecialAS(SrcAS)) return; Value *Dest = MTI.getRawDest(); if (isSpecialAS(DestAS)) { Value *Replacement = LiftPointer(Dest, cast(Dest->getType())->getElementType(), &MTI); if (Replacement) Dest = Replacement; } Value *Src = MTI.getRawSource(); if (isSpecialAS(SrcAS)) { Value *Replacement = LiftPointer(Src, cast(Src->getType())->getElementType(), &MTI); if (Replacement) Src = Replacement; } if (Dest == MTI.getRawDest() && Src == MTI.getRawSource()) return; Function *TheFn = Intrinsic::getDeclaration(MTI.getModule(), MTI.getIntrinsicID(), {Dest->getType(), Src->getType(), MTI.getOperand(2)->getType()}); MTI.setCalledFunction(TheFn); MTI.setArgOperand(0, Dest); MTI.setArgOperand(1, Src); } char PropagateJuliaAddrspaces::ID = 0; static RegisterPass X("PropagateJuliaAddrspaces", "Propagate (non-)rootedness information", false, false); Pass *createPropagateJuliaAddrspaces() { return new PropagateJuliaAddrspaces(); } extern "C" JL_DLLEXPORT void LLVMExtraAddPropagateJuliaAddrspaces(LLVMPassManagerRef PM) { unwrap(PM)->add(createPropagateJuliaAddrspaces()); }