//===-- AMDGPUReplaceLDSUseWithPointer.cpp --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces all the uses of LDS within non-kernel functions by
// corresponding pointer counterparts.
//
// The main motivation behind this pass is to *prevent* the subsequent LDS
// lowering pass from directly packing LDS (assume large LDS) into a struct
// type, which would otherwise cause a huge amount of memory to be allocated
// for the struct instance within every kernel.
//
// A brief sketch of the algorithm implemented in this pass is given below:
//
//   1. Collect all the LDS defined in the module which qualify for pointer
//      replacement, say it is, LDSGlobals set.
//
//   2. Collect all the reachable callees for each kernel defined in the module,
//      say it is, KernelToCallees map.
//
//   3. FOR (each global GV from LDSGlobals set) DO
//        LDSUsedNonKernels = Collect all non-kernel functions which use GV.
//        FOR (each kernel K in KernelToCallees map) DO
//          ReachableCallees = KernelToCallees[K]
//          ReachableAndLDSUsedCallees =
//              SetIntersect(LDSUsedNonKernels, ReachableCallees)
//          IF (ReachableAndLDSUsedCallees is not empty) THEN
//            Pointer = Create a pointer to point-to GV if not created.
//            Initialize Pointer to point-to GV within kernel K.
//          ENDIF
//        ENDFOR
//        Replace all uses of GV within non kernel functions by Pointer.
// ENFOR // // LLVM IR example: // // Input IR: // // @lds = internal addrspace(3) global [4 x i32] undef, align 16 // // define internal void @f0() { // entry: // %gep = getelementptr inbounds [4 x i32], [4 x i32] addrspace(3)* @lds, // i32 0, i32 0 // ret void // } // // define protected amdgpu_kernel void @k0() { // entry: // call void @f0() // ret void // } // // Output IR: // // @lds = internal addrspace(3) global [4 x i32] undef, align 16 // @lds.ptr = internal unnamed_addr addrspace(3) global i16 undef, align 2 // // define internal void @f0() { // entry: // %0 = load i16, i16 addrspace(3)* @lds.ptr, align 2 // %1 = getelementptr i8, i8 addrspace(3)* null, i16 %0 // %2 = bitcast i8 addrspace(3)* %1 to [4 x i32] addrspace(3)* // %gep = getelementptr inbounds [4 x i32], [4 x i32] addrspace(3)* %2, // i32 0, i32 0 // ret void // } // // define protected amdgpu_kernel void @k0() { // entry: // store i16 ptrtoint ([4 x i32] addrspace(3)* @lds to i16), // i16 addrspace(3)* @lds.ptr, align 2 // call void @f0() // ret void // } // //===----------------------------------------------------------------------===// #include "AMDGPU.h" #include "GCNSubtarget.h" #include "Utils/AMDGPUBaseInfo.h" #include "Utils/AMDGPULDSUtils.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetOperations.h" #include "llvm/CodeGen/TargetPassConfig.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InlineAsm.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicsAMDGPU.h" #include "llvm/IR/ReplaceConstant.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/Debug.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/ModuleUtils.h" #include #include #define DEBUG_TYPE "amdgpu-replace-lds-use-with-pointer" using namespace llvm; namespace { class ReplaceLDSUseImpl { Module &M; 
LLVMContext &Ctx; const DataLayout &DL; Constant *LDSMemBaseAddr; DenseMap LDSToPointer; DenseMap> LDSToNonKernels; DenseMap> KernelToCallees; DenseMap> KernelToLDSPointers; DenseMap KernelToInitBB; DenseMap> FunctionToLDSToReplaceInst; // Collect LDS which requires their uses to be replaced by pointer. std::vector collectLDSRequiringPointerReplace() { // Collect LDS which requires module lowering. std::vector LDSGlobals = AMDGPU::findVariablesToLower(M); // Remove LDS which don't qualify for replacement. LDSGlobals.erase(std::remove_if(LDSGlobals.begin(), LDSGlobals.end(), [&](GlobalVariable *GV) { return shouldIgnorePointerReplacement(GV); }), LDSGlobals.end()); return LDSGlobals; } // Returns true if uses of given LDS global within non-kernel functions should // be keep as it is without pointer replacement. bool shouldIgnorePointerReplacement(GlobalVariable *GV) { // LDS whose size is very small and doesn`t exceed pointer size is not worth // replacing. if (DL.getTypeAllocSize(GV->getValueType()) <= 2) return true; // LDS which is not used from non-kernel function scope or it is used from // global scope does not qualify for replacement. LDSToNonKernels[GV] = AMDGPU::collectNonKernelAccessorsOfLDS(GV); return LDSToNonKernels[GV].empty(); // FIXME: When GV is used within all (or within most of the kernels), then // it does not make sense to create a pointer for it. } // Insert new global LDS pointer which points to LDS. GlobalVariable *createLDSPointer(GlobalVariable *GV) { // LDS pointer which points to LDS is already created? return it. auto PointerEntry = LDSToPointer.insert(std::make_pair(GV, nullptr)); if (!PointerEntry.second) return PointerEntry.first->second; // We need to create new LDS pointer which points to LDS. // // Each CU owns at max 64K of LDS memory, so LDS address ranges from 0 to // 2^16 - 1. Hence 16 bit pointer is enough to hold the LDS address. 
auto *I16Ty = Type::getInt16Ty(Ctx); GlobalVariable *LDSPointer = new GlobalVariable( M, I16Ty, false, GlobalValue::InternalLinkage, UndefValue::get(I16Ty), GV->getName() + Twine(".ptr"), nullptr, GlobalVariable::NotThreadLocal, AMDGPUAS::LOCAL_ADDRESS); LDSPointer->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); LDSPointer->setAlignment(AMDGPU::getAlign(DL, LDSPointer)); // Mark that an associated LDS pointer is created for LDS. LDSToPointer[GV] = LDSPointer; return LDSPointer; } // Split entry basic block in such a way that only lane 0 of each wave does // the LDS pointer initialization, and return newly created basic block. BasicBlock *activateLaneZero(Function *K) { // If the entry basic block of kernel K is already splitted, then return // newly created basic block. auto BasicBlockEntry = KernelToInitBB.insert(std::make_pair(K, nullptr)); if (!BasicBlockEntry.second) return BasicBlockEntry.first->second; // Split entry basic block of kernel K. auto *EI = &(*(K->getEntryBlock().getFirstInsertionPt())); IRBuilder<> Builder(EI); Value *Mbcnt = Builder.CreateIntrinsic(Intrinsic::amdgcn_mbcnt_lo, {}, {Builder.getInt32(-1), Builder.getInt32(0)}); Value *Cond = Builder.CreateICmpEQ(Mbcnt, Builder.getInt32(0)); Instruction *WB = cast( Builder.CreateIntrinsic(Intrinsic::amdgcn_wave_barrier, {}, {})); BasicBlock *NBB = SplitBlockAndInsertIfThen(Cond, WB, false)->getParent(); // Mark that the entry basic block of kernel K is splitted. KernelToInitBB[K] = NBB; return NBB; } // Within given kernel, initialize given LDS pointer to point to given LDS. void initializeLDSPointer(Function *K, GlobalVariable *GV, GlobalVariable *LDSPointer) { // If LDS pointer is already initialized within K, then nothing to do. auto PointerEntry = KernelToLDSPointers.insert( std::make_pair(K, SmallPtrSet())); if (!PointerEntry.second) if (PointerEntry.first->second.contains(LDSPointer)) return; // Insert instructions at EI which initialize LDS pointer to point-to LDS // within kernel K. 
// // That is, convert pointer type of GV to i16, and then store this converted // i16 value within LDSPointer which is of type i16*. auto *EI = &(*(activateLaneZero(K)->getFirstInsertionPt())); IRBuilder<> Builder(EI); Builder.CreateStore(Builder.CreatePtrToInt(GV, Type::getInt16Ty(Ctx)), LDSPointer); // Mark that LDS pointer is initialized within kernel K. KernelToLDSPointers[K].insert(LDSPointer); } // We have created an LDS pointer for LDS, and initialized it to point-to LDS // within all relevent kernels. Now replace all the uses of LDS within // non-kernel functions by LDS pointer. void replaceLDSUseByPointer(GlobalVariable *GV, GlobalVariable *LDSPointer) { SmallVector LDSUsers(GV->users()); for (auto *U : LDSUsers) { // When `U` is a constant expression, it is possible that same constant // expression exists within multiple instructions, and within multiple // non-kernel functions. Collect all those non-kernel functions and all // those instructions within which `U` exist. auto FunctionToInsts = AMDGPU::getFunctionToInstsMap(U, false /*=CollectKernelInsts*/); for (auto FI = FunctionToInsts.begin(), FE = FunctionToInsts.end(); FI != FE; ++FI) { Function *F = FI->first; auto &Insts = FI->second; for (auto *I : Insts) { // If `U` is a constant expression, then we need to break the // associated instruction into a set of separate instructions by // converting constant expressions into instructions. SmallPtrSet UserInsts; if (U == I) { // `U` is an instruction, conversion from constant expression to // set of instructions is *not* required. UserInsts.insert(I); } else { // `U` is a constant expression, convert it into corresponding set // of instructions. auto *CE = cast(U); convertConstantExprsToInstructions(I, CE, &UserInsts); } // Go through all the user instrutions, if LDS exist within them as an // operand, then replace it by replace instruction. 
for (auto *II : UserInsts) { auto *ReplaceInst = getReplacementInst(F, GV, LDSPointer); II->replaceUsesOfWith(GV, ReplaceInst); } } } } } // Create a set of replacement instructions which together replace LDS within // non-kernel function F by accessing LDS indirectly using LDS pointer. Value *getReplacementInst(Function *F, GlobalVariable *GV, GlobalVariable *LDSPointer) { // If the instruction which replaces LDS within F is already created, then // return it. auto LDSEntry = FunctionToLDSToReplaceInst.insert( std::make_pair(F, DenseMap())); if (!LDSEntry.second) { auto ReplaceInstEntry = LDSEntry.first->second.insert(std::make_pair(GV, nullptr)); if (!ReplaceInstEntry.second) return ReplaceInstEntry.first->second; } // Get the instruction insertion point within the beginning of the entry // block of current non-kernel function. auto *EI = &(*(F->getEntryBlock().getFirstInsertionPt())); IRBuilder<> Builder(EI); // Insert required set of instructions which replace LDS within F. auto *V = Builder.CreateBitCast( Builder.CreateGEP( Builder.getInt8Ty(), LDSMemBaseAddr, Builder.CreateLoad(LDSPointer->getValueType(), LDSPointer)), GV->getType()); // Mark that the replacement instruction which replace LDS within F is // created. FunctionToLDSToReplaceInst[F][GV] = V; return V; } public: ReplaceLDSUseImpl(Module &M) : M(M), Ctx(M.getContext()), DL(M.getDataLayout()) { LDSMemBaseAddr = Constant::getIntegerValue( PointerType::get(Type::getInt8Ty(M.getContext()), AMDGPUAS::LOCAL_ADDRESS), APInt(32, 0)); } // Entry-point function which interface ReplaceLDSUseImpl with outside of the // class. bool replaceLDSUse(); private: // For a given LDS from collected LDS globals set, replace its non-kernel // function scope uses by pointer. bool replaceLDSUse(GlobalVariable *GV); }; // For given LDS from collected LDS globals set, replace its non-kernel function // scope uses by pointer. 
bool ReplaceLDSUseImpl::replaceLDSUse(GlobalVariable *GV) { // Holds all those non-kernel functions within which LDS is being accessed. SmallPtrSet &LDSAccessors = LDSToNonKernels[GV]; // The LDS pointer which points to LDS and replaces all the uses of LDS. GlobalVariable *LDSPointer = nullptr; // Traverse through each kernel K, check and if required, initialize the // LDS pointer to point to LDS within K. for (auto KI = KernelToCallees.begin(), KE = KernelToCallees.end(); KI != KE; ++KI) { Function *K = KI->first; SmallPtrSet Callees = KI->second; // Compute reachable and LDS used callees for kernel K. set_intersect(Callees, LDSAccessors); // None of the LDS accessing non-kernel functions are reachable from // kernel K. Hence, no need to initialize LDS pointer within kernel K. if (Callees.empty()) continue; // We have found reachable and LDS used callees for kernel K, and we need to // initialize LDS pointer within kernel K, and we need to replace LDS use // within those callees by LDS pointer. // // But, first check if LDS pointer is already created, if not create one. LDSPointer = createLDSPointer(GV); // Initialize LDS pointer to point to LDS within kernel K. initializeLDSPointer(K, GV, LDSPointer); } // We have not found reachable and LDS used callees for any of the kernels, // and hence we have not created LDS pointer. if (!LDSPointer) return false; // We have created an LDS pointer for LDS, and initialized it to point-to LDS // within all relevent kernels. Now replace all the uses of LDS within // non-kernel functions by LDS pointer. replaceLDSUseByPointer(GV, LDSPointer); return true; } // Entry-point function which interface ReplaceLDSUseImpl with outside of the // class. bool ReplaceLDSUseImpl::replaceLDSUse() { // Collect LDS which requires their uses to be replaced by pointer. std::vector LDSGlobals = collectLDSRequiringPointerReplace(); // No LDS to pointer-replace. Nothing to do. 
if (LDSGlobals.empty()) return false; // Collect reachable callee set for each kernel defined in the module. AMDGPU::collectReachableCallees(M, KernelToCallees); if (KernelToCallees.empty()) { // Either module does not have any kernel definitions, or none of the kernel // has a call to non-kernel functions, or we could not resolve any of the // call sites to proper non-kernel functions, because of the situations like // inline asm calls. Nothing to replace. return false; } // For every LDS from collected LDS globals set, replace its non-kernel // function scope use by pointer. bool Changed = false; for (auto *GV : LDSGlobals) Changed |= replaceLDSUse(GV); return Changed; } class AMDGPUReplaceLDSUseWithPointer : public ModulePass { public: static char ID; AMDGPUReplaceLDSUseWithPointer() : ModulePass(ID) { initializeAMDGPUReplaceLDSUseWithPointerPass( *PassRegistry::getPassRegistry()); } bool runOnModule(Module &M) override; void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired(); } }; } // namespace char AMDGPUReplaceLDSUseWithPointer::ID = 0; char &llvm::AMDGPUReplaceLDSUseWithPointerID = AMDGPUReplaceLDSUseWithPointer::ID; INITIALIZE_PASS_BEGIN( AMDGPUReplaceLDSUseWithPointer, DEBUG_TYPE, "Replace within non-kernel function use of LDS with pointer", false /*only look at the cfg*/, false /*analysis pass*/) INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) INITIALIZE_PASS_END( AMDGPUReplaceLDSUseWithPointer, DEBUG_TYPE, "Replace within non-kernel function use of LDS with pointer", false /*only look at the cfg*/, false /*analysis pass*/) bool AMDGPUReplaceLDSUseWithPointer::runOnModule(Module &M) { ReplaceLDSUseImpl LDSUseReplacer{M}; return LDSUseReplacer.replaceLDSUse(); } ModulePass *llvm::createAMDGPUReplaceLDSUseWithPointerPass() { return new AMDGPUReplaceLDSUseWithPointer(); } PreservedAnalyses AMDGPUReplaceLDSUseWithPointerPass::run(Module &M, ModuleAnalysisManager &AM) { ReplaceLDSUseImpl LDSUseReplacer{M}; LDSUseReplacer.replaceLDSUse(); 
return PreservedAnalyses::all(); }