From 5ae8f2dc2d4d4f51d24b2d19b51fcffee08d82b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= Date: Wed, 15 May 2024 17:10:13 +0200 Subject: [PATCH 1/5] [SPIR-V] Add pass to merge convergence region exit targets MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The structurizer required regions to be SESE: single entry, single exit. This new pass transforms multiple-exit regions into single-exit regions. ``` +---+ | A | +---+ / \ +---+ +---+ | B | | C | A, B & C belongs to the same convergence region. +---+ +---+ | | +---+ +---+ | D | | E | C & D belongs to the parent convergence region. +---+ +---+ This means B & C are the exit blocks of the region. \ / And D & E the targets of those exits. \ / | +---+ | F | +---+ ``` This pass would assign one value per exit target: B = 0 C = 1 Then, create one variable per exit block (B, C), and assign it to the correct value: in B, the variable will have the value 0, and in C, the value 1. Then, we'd create a new block H, with a PHI node to gather those 2 variables, and a switch, to route to the correct target. Finally, the branches in B and C are updated to exit to this new block. ``` +---+ | A | +---+ / \ +---+ +---+ | B | | C | +---+ +---+ \ / +---+ | H | +---+ / \ +---+ +---+ | D | | E | +---+ +---+ \ / \ / | +---+ | F | +---+ ``` Note: the variable is set depending on the condition used to branch. If B's terminator was conditional, the variable would be set using a SELECT. All internal edges of a region are left intact, only exiting edges are updated. Signed-off-by: Nathan Gauër --- llvm/lib/Target/SPIRV/CMakeLists.txt | 1 + llvm/lib/Target/SPIRV/SPIRV.h | 1 + llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 23 ++ llvm/lib/Target/SPIRV/SPIRVInstrInfo.td | 2 +- .../SPIRV/SPIRVMergeRegionExitTargets.cpp | 290 ++++++++++++++++++ llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp | 1 + .../SPIRV/structurizer/merge-exit-break.ll | 84 +++++ .../merge-exit-convergence-in-break.ll | 94 ++++++ .../structurizer/merge-exit-multiple-break.ll | 103 +++++++ .../merge-exit-simple-white-identity.ll | 49 +++ 10 files changed, 647 insertions(+), 1 deletion(-) create mode 100644 llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp create mode 100644 llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll create mode 100644 llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll create mode 100644 llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll create mode 100644 llvm/test/CodeGen/SPIRV/structurizer/merge-exit-simple-white-identity.ll diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt index 7001ac382f41c..35a463a89ec64 100644 --- a/llvm/lib/Target/SPIRV/CMakeLists.txt +++ b/llvm/lib/Target/SPIRV/CMakeLists.txt @@ -24,6 +24,7 @@ add_llvm_target(SPIRVCodeGen SPIRVInstrInfo.cpp SPIRVInstructionSelector.cpp SPIRVStripConvergentIntrinsics.cpp + SPIRVMergeRegionExitTargets.cpp SPIRVISelLowering.cpp SPIRVLegalizerInfo.cpp SPIRVMCInstLower.cpp diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h index fb8580cd47c01..e597a1dc8dc06 100644 --- a/llvm/lib/Target/SPIRV/SPIRV.h +++ b/llvm/lib/Target/SPIRV/SPIRV.h @@ -20,6 +20,7 @@ class InstructionSelector; class RegisterBankInfo; ModulePass *createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM); +FunctionPass *createSPIRVMergeRegionExitTargetsPass(); FunctionPass *createSPIRVStripConvergenceIntrinsicsPass(); FunctionPass *createSPIRVRegularizerPass(); FunctionPass *createSPIRVPreLegalizerPass(); diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp index 32df2403dfe52..057cdd7a3ee2c 100644 --- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp @@ -150,6 +150,16 @@ class SPIRVEmitIntrinsics ModulePass::getAnalysisUsage(AU); } }; + +bool isConvergenceIntrinsic(const Instruction *I) { + const auto *II = dyn_cast(I); + if (!II) + return false; + + return II->getIntrinsicID() == Intrinsic::experimental_convergence_entry || + II->getIntrinsicID() == Intrinsic::experimental_convergence_loop || + II->getIntrinsicID() == Intrinsic::experimental_convergence_anchor; +} } // namespace char SPIRVEmitIntrinsics::ID = 0; @@ -1074,6 +1084,10 @@ void SPIRVEmitIntrinsics::processGlobalValue(GlobalVariable &GV, void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I, IRBuilder<> &B) { + // Don't assign types to LLVM tokens. + if (isConvergenceIntrinsic(I)) + return; + reportFatalOnTokenType(I); if (!isPointerTy(I->getType()) || !requireAssignType(I) || isa(I)) @@ -1092,6 +1106,10 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I, void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I, IRBuilder<> &B) { + // Don't assign types to LLVM tokens. + if (isConvergenceIntrinsic(I)) + return; + reportFatalOnTokenType(I); Type *Ty = I->getType(); if (!Ty->isVoidTy() && !isPointerTy(Ty) && requireAssignType(I)) { @@ -1319,6 +1337,11 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) { I = visit(*I); if (!I) continue; + + // Don't emit intrinsics for convergence operations. + if (isConvergenceIntrinsic(I)) + continue; + processInstrAfterVisit(I, B); } diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td index 151d0ec1fe569..6634481daf12e 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td @@ -615,7 +615,7 @@ def OpFwidthCoarse: UnOp<"OpFwidthCoarse", 215>; def OpPhi: Op<245, (outs ID:$res), (ins TYPE:$type, ID:$var0, ID:$block0, variable_ops), "$res = OpPhi $type $var0 $block0">; def OpLoopMerge: Op<246, (outs), (ins ID:$merge, ID:$continue, LoopControl:$lc, variable_ops), - "OpLoopMerge $merge $merge $continue $lc">; + "OpLoopMerge $merge $continue $lc">; def OpSelectionMerge: Op<247, (outs), (ins ID:$merge, SelectionControl:$sc), "OpSelectionMerge $merge $sc">; def OpLabel: Op<248, (outs ID:$label), (ins), "$label = OpLabel">; diff --git a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp new file mode 100644 index 0000000000000..13781e24f0d42 --- /dev/null +++ b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp @@ -0,0 +1,290 @@ +//===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Merge the multiple exit targets of a convergence region into a single block. +// Each exit target will be assigned a constant value, and a phi node + switch +// will allow the new exit target to re-route to the correct basic block. +// +//===----------------------------------------------------------------------===// + +#include "Analysis/SPIRVConvergenceRegionAnalysis.h" +#include "SPIRV.h" +#include "SPIRVSubtarget.h" +#include "SPIRVTargetMachine.h" +#include "SPIRVUtils.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/CodeGen/IntrinsicLowering.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/IntrinsicsSPIRV.h" +#include "llvm/InitializePasses.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/LoopSimplify.h" +#include "llvm/Transforms/Utils/LowerMemIntrinsics.h" + +using namespace llvm; + +namespace llvm { +void initializeSPIRVMergeRegionExitTargetsPass(PassRegistry &); +} // namespace llvm + +namespace llvm { + +class SPIRVMergeRegionExitTargets : public FunctionPass { +public: + static char ID; + + SPIRVMergeRegionExitTargets() : FunctionPass(ID) { + initializeSPIRVMergeRegionExitTargetsPass(*PassRegistry::getPassRegistry()); + }; + + // Gather all the successors of |BB|. + // This function asserts if the terminator neither a branch, switch or return. + std::unordered_set gatherSuccessors(BasicBlock *BB) { + std::unordered_set output; + auto *T = BB->getTerminator(); + + if (auto *BI = dyn_cast(T)) { + output.insert(BI->getSuccessor(0)); + if (BI->isConditional()) + output.insert(BI->getSuccessor(1)); + return output; + } + + if (auto *SI = dyn_cast(T)) { + output.insert(SI->getDefaultDest()); + for (auto &Case : SI->cases()) { + output.insert(Case.getCaseSuccessor()); + } + return output; + } + + if (auto *RI = dyn_cast(T)) + return output; + + assert(false && "Unhandled terminator type."); + return output; + } + + /// Create a value in BB set to the value associated with the branch the block + /// terminator will take. + llvm::Value *createExitVariable( + BasicBlock *BB, + const std::unordered_map &TargetToValue) { + auto *T = BB->getTerminator(); + if (auto *RI = dyn_cast(T)) { + return nullptr; + } + + IRBuilder<> Builder(BB); + Builder.SetInsertPoint(T); + + if (auto *BI = dyn_cast(T)) { + + BasicBlock *LHSTarget = BI->getSuccessor(0); + BasicBlock *RHSTarget = + BI->isConditional() ? BI->getSuccessor(1) : nullptr; + + Value *LHS = TargetToValue.count(LHSTarget) != 0 + ? TargetToValue.at(LHSTarget) + : nullptr; + Value *RHS = TargetToValue.count(RHSTarget) != 0 + ? TargetToValue.at(RHSTarget) + : nullptr; + + if (LHS == nullptr || RHS == nullptr) + return LHS == nullptr ? RHS : LHS; + return Builder.CreateSelect(BI->getCondition(), LHS, RHS); + } + + // TODO: add support for switch cases. + assert(false && "Unhandled terminator type."); + } + + /// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|. + void replaceBranchTargets(BasicBlock *BB, + const std::unordered_set ToReplace, + BasicBlock *NewTarget) { + auto *T = BB->getTerminator(); + if (auto *RI = dyn_cast(T)) + return; + + if (auto *BI = dyn_cast(T)) { + for (size_t i = 0; i < BI->getNumSuccessors(); i++) { + if (ToReplace.count(BI->getSuccessor(i)) != 0) + BI->setSuccessor(i, NewTarget); + } + return; + } + + if (auto *SI = dyn_cast(T)) { + for (size_t i = 0; i < SI->getNumSuccessors(); i++) { + if (ToReplace.count(SI->getSuccessor(i)) != 0) + SI->setSuccessor(i, NewTarget); + } + return; + } + + assert(false && "Unhandled terminator type."); + } + + // Run the pass on the given convergence region, ignoring the sub-regions. + // Returns true if the CFG changed, false otherwise. + bool runOnConvergenceRegionNoRecurse(LoopInfo &LI, + const SPIRV::ConvergenceRegion *CR) { + // Gather all the exit targets for this region. + std::unordered_set ExitTargets; + for (BasicBlock *Exit : CR->Exits) { + for (BasicBlock *Target : gatherSuccessors(Exit)) { + if (CR->Blocks.count(Target) == 0) + ExitTargets.insert(Target); + } + } + + // If we have zero or one exit target, nothing do to. + if (ExitTargets.size() <= 1) + return false; + + // Create the new single exit target. + auto F = CR->Entry->getParent(); + auto NewExitTarget = BasicBlock::Create(F->getContext(), "new.exit", F); + IRBuilder<> Builder(NewExitTarget); + + // CodeGen output needs to be stable. Using the set as-is would order + // the targets differently depending on the allocation pattern. + // Sorting per basic-block ordering in the function. + std::vector SortedExitTargets; + std::vector SortedExits; + for (BasicBlock &BB : *F) { + if (ExitTargets.count(&BB) != 0) + SortedExitTargets.push_back(&BB); + if (CR->Exits.count(&BB) != 0) + SortedExits.push_back(&BB); + } + + // Creating one constant per distinct exit target. This will be route to the + // correct target. + std::unordered_map TargetToValue; + for (BasicBlock *Target : SortedExitTargets) + TargetToValue.emplace(Target, Builder.getInt32(TargetToValue.size())); + + // Creating one variable per exit node, set to the constant matching the + // targeted external block. + std::vector> ExitToVariable; + for (auto Exit : SortedExits) { + llvm::Value *Value = createExitVariable(Exit, TargetToValue); + ExitToVariable.emplace_back(std::make_pair(Exit, Value)); + } + + // Gather the correct value depending on the exit we came from. + llvm::PHINode *node = + Builder.CreatePHI(Builder.getInt32Ty(), ExitToVariable.size()); + for (auto [BB, Value] : ExitToVariable) { + node->addIncoming(Value, BB); + } + + // Creating the switch to jump to the correct exit target. + std::vector> CasesList( + TargetToValue.begin(), TargetToValue.end()); + llvm::SwitchInst *Sw = + Builder.CreateSwitch(node, CasesList[0].first, CasesList.size() - 1); + for (size_t i = 1; i < CasesList.size(); i++) + Sw->addCase(CasesList[i].second, CasesList[i].first); + + // Fix exit branches to redirect to the new exit. + for (auto Exit : CR->Exits) + replaceBranchTargets(Exit, ExitTargets, NewExitTarget); + + return true; + } + + /// Run the pass on the given convergence region and sub-regions (DFS). + /// Returns true if a region/sub-region was modified, false otherwise. + /// This returns as soon as one region/sub-region has been modified. + bool runOnConvergenceRegion(LoopInfo &LI, + const SPIRV::ConvergenceRegion *CR) { + for (auto *Child : CR->Children) + if (runOnConvergenceRegion(LI, Child)) + return true; + + return runOnConvergenceRegionNoRecurse(LI, CR); + } + +#if !NDEBUG + /// Validates each edge exiting the region has the same destination basic + /// block. + void validateRegionExits(const SPIRV::ConvergenceRegion *CR) { + for (auto *Child : CR->Children) + validateRegionExits(Child); + + std::unordered_set ExitTargets; + for (auto *Exit : CR->Exits) { + auto Set = gatherSuccessors(Exit); + for (auto *BB : Set) { + if (CR->Blocks.count(BB) == 0) + ExitTargets.insert(BB); + } + } + + assert(ExitTargets.size() <= 1); + } +#endif + + virtual bool runOnFunction(Function &F) override { + LoopInfo &LI = getAnalysis().getLoopInfo(); + const auto *TopLevelRegion = + getAnalysis() + .getRegionInfo() + .getTopLevelRegion(); + + // FIXME: very inefficient method: each time a region is modified, we bubble + // back up, and recompute the whole convergence region tree. Once the + // algorithm is completed and test coverage good enough, rewrite this pass + // to be efficient instead of simple. + bool modified = false; + while (runOnConvergenceRegion(LI, TopLevelRegion)) { + TopLevelRegion = getAnalysis() + .getRegionInfo() + .getTopLevelRegion(); + modified = true; + } + + F.dump(); +#if !NDEBUG + validateRegionExits(TopLevelRegion); +#endif + return modified; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + FunctionPass::getAnalysisUsage(AU); + } +}; +} // namespace llvm + +char SPIRVMergeRegionExitTargets::ID = 0; + +INITIALIZE_PASS_BEGIN(SPIRVMergeRegionExitTargets, "split-region-exit-blocks", + "SPIRV split region exit blocks", false, false) +INITIALIZE_PASS_DEPENDENCY(LoopSimplify) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass) + +INITIALIZE_PASS_END(SPIRVMergeRegionExitTargets, "split-region-exit-blocks", + "SPIRV split region exit blocks", false, false) + +FunctionPass *llvm::createSPIRVMergeRegionExitTargetsPass() { + return new SPIRVMergeRegionExitTargets(); +} diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp index ae8baa3f11913..d0e51caf46e73 100644 --- a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp @@ -164,6 +164,7 @@ void SPIRVPassConfig::addIRPasses() { // - all loop exits are dominated by the loop pre-header. // - loops have a single back-edge. addPass(createLoopSimplifyPass()); + addPass(createSPIRVMergeRegionExitTargetsPass()); } TargetPassConfig::addIRPasses(); diff --git a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll new file mode 100644 index 0000000000000..b3fcdc978625f --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll @@ -0,0 +1,84 @@ +; RUN: llc -mtriple=spirv-unknown-vulkan-compute -O0 %s -o - | FileCheck %s --match-full-lines + +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-G1" +target triple = "spirv-unknown-vulkan-compute" + +define internal spir_func void @main() #0 { + +; CHECK: OpDecorate %[[#builtin:]] BuiltIn SubgroupLocalInvocationId +; CHECK-DAG: %[[#int_ty:]] = OpTypeInt 32 0 +; CHECK-DAG: %[[#pint_ty:]] = OpTypePointer Function %[[#int_ty]] +; CHECK-DAG: %[[#bool_ty:]] = OpTypeBool +; CHECK-DAG: %[[#int_0:]] = OpConstant %[[#int_ty]] 0 +; CHECK-DAG: %[[#int_1:]] = OpConstant %[[#int_ty]] 1 +; CHECK-DAG: %[[#int_10:]] = OpConstant %[[#int_ty]] 10 + +; CHECK: %[[#entry:]] = OpLabel +; CHECK: %[[#idx:]] = OpVariable %[[#pint_ty]] Function +; CHECK: OpStore %[[#idx]] %[[#int_0]] Aligned 4 +; CHECK: OpBranch %[[#while_cond:]] +entry: + %0 = call token @llvm.experimental.convergence.entry() + %idx = alloca i32, align 4 + store i32 0, ptr %idx, align 4 + br label %while.cond + +; CHECK: %[[#while_cond]] = OpLabel +; CHECK: %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#idx]] Aligned 4 +; CHECK: %[[#cmp:]] = OpINotEqual %[[#bool_ty]] %[[#tmp]] %[[#int_10]] +; CHECK: OpBranchConditional %[[#cmp]] %[[#while_body:]] %[[#new_end:]] +while.cond: + %1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %0) ] + %2 = load i32, ptr %idx, align 4 + %cmp = icmp ne i32 %2, 10 + br i1 %cmp, label %while.body, label %while.end + +; CHECK: %[[#while_body]] = OpLabel +; CHECK-NEXT: %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#builtin]] Aligned 1 +; CHECK-NEXT: OpStore %[[#idx]] %[[#tmp]] Aligned 4 +; CHECK-NEXT: %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#idx]] Aligned 4 +; CHECK-NEXT: %[[#cmp1:]] = OpIEqual %[[#bool_ty]] %[[#tmp]] %[[#int_0]] +; CHECK: OpBranchConditional %[[#cmp1]] %[[#new_end]] %[[#if_end:]] +while.body: + %3 = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %1) ] + store i32 %3, ptr %idx, align 4 + %4 = load i32, ptr %idx, align 4 + %cmp1 = icmp eq i32 %4, 0 + br i1 %cmp1, label %if.then, label %if.end + +; CHECK: %[[#if_then:]] = OpLabel +; CHECK: OpBranch %[[#while_end:]] +if.then: + br label %while.end + +; CHECK: %[[#if_end]] = OpLabel +; CHECK: OpBranch %[[#while_cond]] +if.end: + br label %while.cond + +; CHECK: %[[#while_end_loopexit:]] = OpLabel +; CHECK: OpBranch %[[#while_end]] + +; CHECK: %[[#while_end]] = OpLabel +; CHECK: OpReturn +while.end: + ret void + +; CHECK: %[[#new_end]] = OpLabel +; CHECK: %[[#route:]] = OpPhi %[[#int_ty]] %[[#int_1]] %[[#while_cond]] %[[#int_0]] %[[#while_body]] +; CHECK: OpSwitch %[[#route]] %[[#while_end_loopexit]] 0 %[[#if_then]] +} + +declare token @llvm.experimental.convergence.entry() #2 +declare token @llvm.experimental.convergence.loop() #2 +declare i32 @__hlsl_wave_get_lane_index() #3 + +attributes #0 = { convergent noinline norecurse nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" } +attributes #1 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" } +attributes #2 = { convergent nocallback nofree nosync nounwind willreturn memory(none) } +attributes #3 = { convergent } + +!llvm.module.flags = !{!0, !1} + +!0 = !{i32 1, !"wchar_size", i32 4} +!1 = !{i32 4, !"dx.disable_optimizations", i32 1} diff --git a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll new file mode 100644 index 0000000000000..a67c58fdd5749 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll @@ -0,0 +1,94 @@ +; RUN: llc -mtriple=spirv-unknown-vulkan-compute -O0 %s -o - | FileCheck %s --match-full-lines + +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-G1" +target triple = "spirv-unknown-vulkan-compute" + +define internal spir_func void @main() #0 { + +; CHECK: OpDecorate %[[#builtin:]] BuiltIn SubgroupLocalInvocationId +; CHECK-DAG: %[[#int_ty:]] = OpTypeInt 32 0 +; CHECK-DAG: %[[#pint_ty:]] = OpTypePointer Function %[[#int_ty]] +; CHECK-DAG: %[[#bool_ty:]] = OpTypeBool +; CHECK-DAG: %[[#int_0:]] = OpConstant %[[#int_ty]] 0 +; CHECK-DAG: %[[#int_1:]] = OpConstant %[[#int_ty]] 1 +; CHECK-DAG: %[[#int_10:]] = OpConstant %[[#int_ty]] 10 + +; CHECK: %[[#entry:]] = OpLabel +; CHECK: %[[#idx:]] = OpVariable %[[#pint_ty]] Function +; CHECK: OpStore %[[#idx]] %[[#int_0]] Aligned 4 +; CHECK: OpBranch %[[#while_cond:]] +entry: + %0 = call token @llvm.experimental.convergence.entry() + %idx = alloca i32, align 4 + store i32 0, ptr %idx, align 4 + br label %while.cond + +; CHECK: %[[#while_cond]] = OpLabel +; CHECK: %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#idx]] Aligned 4 +; CHECK: %[[#cmp:]] = OpINotEqual %[[#bool_ty]] %[[#tmp]] %[[#int_10]] +; CHECK: OpBranchConditional %[[#cmp]] %[[#while_body:]] %[[#new_end:]] +while.cond: + %1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %0) ] + %2 = load i32, ptr %idx, align 4 + %cmp = icmp ne i32 %2, 10 + br i1 %cmp, label %while.body, label %while.end + +; CHECK: %[[#while_body]] = OpLabel +; CHECK-NEXT: %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#builtin]] Aligned 1 +; CHECK-NEXT: OpStore %[[#idx]] %[[#tmp]] Aligned 4 +; CHECK-NEXT: %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#idx]] Aligned 4 +; CHECK-NEXT: %[[#cmp1:]] = OpIEqual %[[#bool_ty]] %[[#tmp]] %[[#int_0]] +; CHECK: OpBranchConditional %[[#cmp1]] %[[#if_then:]] %[[#if_end:]] +while.body: + %3 = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %1) ] + store i32 %3, ptr %idx, align 4 + %4 = load i32, ptr %idx, align 4 + %cmp1 = icmp eq i32 %4, 0 + br i1 %cmp1, label %if.then, label %if.end + +; CHECK: %[[#if_then:]] = OpLabel +; CHECK-NEXT: OpBranch %[[#tail:]] +if.then: + br label %tail + +; CHECK: %[[#tail:]] = OpLabel +; CHECK-NEXT: %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#builtin]] Aligned 1 +; CHECK-NEXT: OpStore %[[#idx]] %[[#tmp]] Aligned 4 +; CHECK: OpBranch %[[#new_end:]] +tail: + %5 = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %1) ] + store i32 %5, ptr %idx, align 4 + br label %while.end + +; CHECK: %[[#if_end]] = OpLabel +; CHECK: OpBranch %[[#while_cond]] +if.end: + br label %while.cond + +; CHECK: %[[#while_end_loopexit:]] = OpLabel +; CHECK: OpBranch %[[#while_end:]] + +; CHECK: %[[#while_end]] = OpLabel +; CHECK: OpReturn +while.end: + ret void + +; CHECK: %[[#new_end]] = OpLabel +; CHECK: %[[#route:]] = OpPhi %[[#int_ty]] %[[#int_0]] %[[#while_cond]] %[[#int_1]] %[[#tail]] +; CHECK: OpSwitch %[[#route]] %[[#while_end]] 0 %[[#while_end_loopexit]] +} + +declare token @llvm.experimental.convergence.entry() #2 +declare token @llvm.experimental.convergence.loop() #2 +declare i32 @__hlsl_wave_get_lane_index() #3 + +attributes #0 = { convergent noinline norecurse nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" } +attributes #1 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" } +attributes #2 = { convergent nocallback nofree nosync nounwind willreturn memory(none) } +attributes #3 = { convergent } + +!llvm.module.flags = !{!0, !1} + +!0 = !{i32 1, !"wchar_size", i32 4} +!1 = !{i32 4, !"dx.disable_optimizations", i32 1} + diff --git a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll new file mode 100644 index 0000000000000..32a97553df05e --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll @@ -0,0 +1,103 @@ +; RUN: llc -mtriple=spirv-unknown-vulkan-compute -O0 %s -o - | FileCheck %s --match-full-lines + +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-G1" +target triple = "spirv-unknown-vulkan-compute" + +define internal spir_func void @main() #0 { + +; CHECK: OpDecorate %[[#builtin:]] BuiltIn SubgroupLocalInvocationId +; CHECK-DAG: %[[#int_ty:]] = OpTypeInt 32 0 +; CHECK-DAG: %[[#pint_ty:]] = OpTypePointer Function %[[#int_ty]] +; CHECK-DAG: %[[#bool_ty:]] = OpTypeBool +; CHECK-DAG: %[[#int_0:]] = OpConstant %[[#int_ty]] 0 +; CHECK-DAG: %[[#int_1:]] = OpConstant %[[#int_ty]] 1 +; CHECK-DAG: %[[#int_2:]] = OpConstant %[[#int_ty]] 2 +; CHECK-DAG: %[[#int_10:]] = OpConstant %[[#int_ty]] 10 + +; CHECK: %[[#entry:]] = OpLabel +; CHECK: %[[#idx:]] = OpVariable %[[#pint_ty]] Function +; CHECK: OpStore %[[#idx]] %[[#int_0]] Aligned 4 +; CHECK: OpBranch %[[#while_cond:]] +entry: + %0 = call token @llvm.experimental.convergence.entry() + %idx = alloca i32, align 4 + store i32 0, ptr %idx, align 4 + br label %while.cond + +; CHECK: %[[#while_cond]] = OpLabel +; CHECK: %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#idx]] Aligned 4 +; CHECK: %[[#cmp:]] = OpINotEqual %[[#bool_ty]] %[[#tmp]] %[[#int_10]] +; CHECK: OpBranchConditional %[[#cmp]] %[[#while_body:]] %[[#new_end:]] +while.cond: + %1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %0) ] + %2 = load i32, ptr %idx, align 4 + %cmp = icmp ne i32 %2, 10 + br i1 %cmp, label %while.body, label %while.end + +; CHECK: %[[#while_body]] = OpLabel +; CHECK-NEXT: %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#builtin]] Aligned 1 +; CHECK-NEXT: OpStore %[[#idx]] %[[#tmp]] Aligned 4 +; CHECK-NEXT: %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#idx]] Aligned 4 +; CHECK-NEXT: %[[#cmp1:]] = OpIEqual %[[#bool_ty]] %[[#tmp]] %[[#int_0]] +; CHECK: OpBranchConditional %[[#cmp1]] %[[#new_end]] %[[#if_end:]] +while.body: + %3 = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %1) ] + store i32 %3, ptr %idx, align 4 + %4 = load i32, ptr %idx, align 4 + %cmp1 = icmp eq i32 %4, 0 + br i1 %cmp1, label %if.then, label %if.end + +; CHECK: %[[#if_then:]] = OpLabel +; CHECK: OpBranch %[[#while_end:]] +if.then: + br label %while.end + +; CHECK: %[[#if_end]] = OpLabel +; CHECK-NEXT: %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#builtin]] Aligned 1 +; CHECK-NEXT: OpStore %[[#idx]] %[[#tmp]] Aligned 4 +; CHECK-NEXT: %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#idx]] Aligned 4 +; CHECK-NEXT: %[[#cmp2:]] = OpIEqual %[[#bool_ty]] %[[#tmp]] %[[#int_0]] +; CHECK: OpBranchConditional %[[#cmp2]] %[[#new_end]] %[[#if_end2:]] +if.end: + %5 = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %1) ] + store i32 %5, ptr %idx, align 4 + %6 = load i32, ptr %idx, align 4 + %cmp2 = icmp eq i32 %6, 0 + br i1 %cmp2, label %if.then2, label %if.end2 + +; CHECK: %[[#if_then2:]] = OpLabel +; CHECK: OpBranch %[[#while_end:]] +if.then2: + br label %while.end + +; CHECK: %[[#if_end2]] = OpLabel +; CHECK: OpBranch %[[#while_cond:]] +if.end2: + br label %while.cond + +; CHECK: %[[#while_end_loopexit:]] = OpLabel +; CHECK: OpBranch %[[#while_end]] + +; CHECK: %[[#while_end]] = OpLabel +; CHECK: OpReturn +while.end: + ret void + +; CHECK: %[[#new_end]] = OpLabel +; CHECK: %[[#route:]] = OpPhi %[[#int_ty]] %[[#int_2]] %[[#while_cond]] %[[#int_0]] %[[#while_body]] %[[#int_1]] %[[#if_end]] +; CHECK: OpSwitch %[[#route]] %[[#while_end_loopexit]] 1 %[[#if_then2]] 0 %[[#if_then]] +} + +declare token @llvm.experimental.convergence.entry() #2 +declare token @llvm.experimental.convergence.loop() #2 +declare i32 @__hlsl_wave_get_lane_index() #3 + +attributes #0 = { convergent noinline norecurse nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" } +attributes #1 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" } +attributes #2 = { convergent nocallback nofree nosync nounwind willreturn memory(none) } +attributes #3 = { convergent } + +!llvm.module.flags = !{!0, !1} + +!0 = !{i32 1, !"wchar_size", i32 4} +!1 = !{i32 4, !"dx.disable_optimizations", i32 1} diff --git a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-simple-white-identity.ll b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-simple-white-identity.ll new file mode 100644 index 0000000000000..a8bf4fb0db989 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-simple-white-identity.ll @@ -0,0 +1,49 @@ +; RUN: llc -mtriple=spirv-unknown-vulkan-compute -O0 %s -o - | FileCheck %s + +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-G1" +target triple = "spirv-unknown-vulkan-compute" + +define internal spir_func void @main() #0 { + +; CHECK: %[[#entry:]] = OpLabel +; CHECK: OpBranch %[[#while_cond:]] +entry: + %0 = call token @llvm.experimental.convergence.entry() + %idx = alloca i32, align 4 + store i32 -1, ptr %idx, align 4 + br label %while.cond + +; CHECK: %[[#while_cond]] = OpLabel +; CHECK: OpBranchConditional %[[#cond:]] %[[#while_body:]] %[[#while_end:]] +while.cond: + %1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %0) ] + %2 = load i32, ptr %idx, align 4 + %cmp = icmp ne i32 %2, 0 + br i1 %cmp, label %while.body, label %while.end + +; CHECK: %[[#while_body]] = OpLabel +; CHECK: OpBranch %[[#while_cond]] +while.body: + %3 = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %1) ] + store i32 %3, ptr %idx, align 4 + br label %while.cond + + ; CHECK: %[[#while_end]] = OpLabel +; CHECK-NEXT: OpReturn +while.end: + ret void +} + +declare token @llvm.experimental.convergence.entry() #2 +declare token @llvm.experimental.convergence.loop() #2 +declare i32 @__hlsl_wave_get_lane_index() #3 + +attributes #0 = { convergent noinline norecurse nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" } +attributes #1 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" } +attributes #2 = { convergent nocallback nofree nosync nounwind willreturn memory(none) } +attributes #3 = { convergent } + +!llvm.module.flags = !{!0, !1} + +!0 = !{i32 1, !"wchar_size", i32 4} +!1 = !{i32 4, !"dx.disable_optimizations", i32 1} From a44183bd5b84a5bb752ff394b0dfa55c1dce0860 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= Date: Tue, 21 May 2024 11:28:34 +0200 Subject: [PATCH 2/5] pr-feedback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Nathan Gauër --- llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 12 ++++-------- .../lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp | 6 ++---- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp index 057cdd7a3ee2c..052f114faae23 100644 --- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp @@ -1084,10 +1084,6 @@ void SPIRVEmitIntrinsics::processGlobalValue(GlobalVariable &GV, void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I, IRBuilder<> &B) { - // Don't assign types to LLVM tokens. - if (isConvergenceIntrinsic(I)) - return; - reportFatalOnTokenType(I); if (!isPointerTy(I->getType()) || !requireAssignType(I) || isa(I)) @@ -1106,10 +1102,6 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I, void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I, IRBuilder<> &B) { - // Don't assign types to LLVM tokens. - if (isConvergenceIntrinsic(I)) - return; - reportFatalOnTokenType(I); Type *Ty = I->getType(); if (!Ty->isVoidTy() && !isPointerTy(Ty) && requireAssignType(I)) { @@ -1319,6 +1311,10 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) { Worklist.push_back(&I); for (auto &I : Worklist) { + // Don't emit intrinsincs for convergence intrinsics. + if (isConvergenceIntrinsic(I)) + continue; + insertAssignPtrTypeIntrs(I, B); insertAssignTypeIntrs(I, B); insertPtrCastOrAssignTypeInstr(I, B); diff --git a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp index 13781e24f0d42..c04bbda155960 100644 --- a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp @@ -61,9 +61,8 @@ class SPIRVMergeRegionExitTargets : public FunctionPass { if (auto *SI = dyn_cast(T)) { output.insert(SI->getDefaultDest()); - for (auto &Case : SI->cases()) { + for (auto &Case : SI->cases()) output.insert(Case.getCaseSuccessor()); - } return output; } @@ -80,9 +79,8 @@ class SPIRVMergeRegionExitTargets : public FunctionPass { BasicBlock *BB, const std::unordered_map &TargetToValue) { auto *T = BB->getTerminator(); - if (auto *RI = dyn_cast(T)) { + if (auto *RI = dyn_cast(T)) return nullptr; - } IRBuilder<> Builder(BB); Builder.SetInsertPoint(T); From afade8ee4233c163089e910ba87adda2c5386b69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= Date: Tue, 21 May 2024 14:02:35 +0200 Subject: [PATCH 3/5] remove F.dump() call --- llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp index c04bbda155960..dc2e0e41ea1c1 100644 --- a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp @@ -255,7 +255,6 @@ class SPIRVMergeRegionExitTargets : public FunctionPass { modified = true; } - F.dump(); #if !NDEBUG validateRegionExits(TopLevelRegion); #endif From 894a35fba5c0f07ac91a0d51b45d1b1d2110edee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= Date: Mon, 27 May 2024 17:46:00 +0200 Subject: [PATCH 4/5] pr feedback --- llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp | 2 +- llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp index dc2e0e41ea1c1..f1182cf0506b7 100644 --- a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp @@ -255,7 +255,7 @@ class SPIRVMergeRegionExitTargets : public FunctionPass { modified = true; } -#if !NDEBUG +#if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS) validateRegionExits(TopLevelRegion); #endif return modified; diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp index d0e51caf46e73..a6823a8ba3230 100644 --- a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp @@ -164,6 +164,10 @@ void SPIRVPassConfig::addIRPasses() { // - all loop exits are dominated by the loop pre-header. // - loops have a single back-edge. addPass(createLoopSimplifyPass()); + + // 2. Merge the convergence region exit nodes into one. After this step, + // regions are single-entry, single-exit. This will help determine the + // correct merge block. addPass(createSPIRVMergeRegionExitTargetsPass()); } From b3f682e898684f47fbad07eb4906873cfb448e7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= Date: Mon, 27 May 2024 17:47:26 +0200 Subject: [PATCH 5/5] pr feedback, merge namespaces --- llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp index f1182cf0506b7..2cdeb32579038 100644 --- a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp @@ -34,9 +34,6 @@ using namespace llvm; namespace llvm { void initializeSPIRVMergeRegionExitTargetsPass(PassRegistry &); -} // namespace llvm - -namespace llvm { class SPIRVMergeRegionExitTargets : public FunctionPass { public: