diff --git a/llvm/docs/Coroutines.rst b/llvm/docs/Coroutines.rst index 60e32dc467d27..f64029547e648 100644 --- a/llvm/docs/Coroutines.rst +++ b/llvm/docs/Coroutines.rst @@ -2121,10 +2121,11 @@ Coroutine Transformation Passes =============================== CoroEarly --------- -The pass CoroEarly lowers coroutine intrinsics that hide the details of the -structure of the coroutine frame, but, otherwise not needed to be preserved to -help later coroutine passes. This pass lowers `coro.frame`_, `coro.done`_, -and `coro.promise`_ intrinsics. +The CoroEarly pass ensures later middle-end passes correctly interpret coroutine +semantics and lowers coroutine intrinsics that do not need to be preserved to +help later coroutine passes. This pass lowers `coro.promise`_, `coro.frame`_ and +`coro.done`_ intrinsics. Afterwards, it replaces uses of the promise alloca with +the `coro.promise`_ intrinsic. .. _CoroSplit: diff --git a/llvm/include/llvm/Transforms/Coroutines/CoroShape.h b/llvm/include/llvm/Transforms/Coroutines/CoroShape.h index ea93ced1ce29e..891774b446571 100644 --- a/llvm/include/llvm/Transforms/Coroutines/CoroShape.h +++ b/llvm/include/llvm/Transforms/Coroutines/CoroShape.h @@ -79,7 +79,8 @@ struct Shape { // Scan the function and collect the above intrinsics for later processing void analyze(Function &F, SmallVectorImpl &CoroFrames, - SmallVectorImpl &UnusedCoroSaves); + SmallVectorImpl &UnusedCoroSaves, + CoroPromiseInst *&CoroPromise); // If for some reason, we were not able to find coro.begin, bailout. void invalidateCoroutine(Function &F, SmallVectorImpl &CoroFrames); @@ -87,7 +88,8 @@ struct Shape { void initABI(); // Remove orphaned and unnecessary intrinsics void cleanCoroutine(SmallVectorImpl &CoroFrames, - SmallVectorImpl &UnusedCoroSaves); + SmallVectorImpl &UnusedCoroSaves, + CoroPromiseInst *CoroPromise); // Field indexes for special fields in the switch lowering. 
struct SwitchFieldIndex { @@ -265,13 +267,14 @@ struct Shape { explicit Shape(Function &F) { SmallVector CoroFrames; SmallVector UnusedCoroSaves; + CoroPromiseInst *CoroPromise = nullptr; - analyze(F, CoroFrames, UnusedCoroSaves); + analyze(F, CoroFrames, UnusedCoroSaves, CoroPromise); if (!CoroBegin) { invalidateCoroutine(F, CoroFrames); return; } - cleanCoroutine(CoroFrames, UnusedCoroSaves); + cleanCoroutine(CoroFrames, UnusedCoroSaves, CoroPromise); } }; diff --git a/llvm/lib/Transforms/Coroutines/CoroEarly.cpp b/llvm/lib/Transforms/Coroutines/CoroEarly.cpp index 5375448d2d2e2..eea6dfba14e37 100644 --- a/llvm/lib/Transforms/Coroutines/CoroEarly.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroEarly.cpp @@ -30,6 +30,7 @@ class Lowerer : public coro::LowererBase { void lowerCoroPromise(CoroPromiseInst *Intrin); void lowerCoroDone(IntrinsicInst *II); void lowerCoroNoop(IntrinsicInst *II); + void hidePromiseAlloca(CoroIdInst *CoroId, CoroBeginInst *CoroBegin); public: Lowerer(Module &M) @@ -153,6 +154,28 @@ void Lowerer::lowerCoroNoop(IntrinsicInst *II) { II->eraseFromParent(); } +// Later middle-end passes will assume promise alloca dead after coroutine +// suspend, leading to misoptimizations. We hide promise alloca using +// coro.promise and will lower it back to alloca at CoroSplit. 
+void Lowerer::hidePromiseAlloca(CoroIdInst *CoroId, CoroBeginInst *CoroBegin) { + auto *PA = CoroId->getPromise(); + if (!PA || !CoroBegin) + return; + Builder.SetInsertPoint(*CoroBegin->getInsertionPointAfterDef()); + + auto *Alignment = Builder.getInt32(PA->getAlign().value()); + auto *FromPromise = Builder.getInt1(false); + SmallVector Arg{CoroBegin, Alignment, FromPromise}; + auto *PI = Builder.CreateIntrinsic( + Builder.getPtrTy(), Intrinsic::coro_promise, Arg, {}, "promise.addr"); + PI->setCannotDuplicate(); + PA->replaceUsesWithIf(PI, [CoroId](Use &U) { + bool IsBitcast = U == U.getUser()->stripPointerCasts(); + bool IsCoroId = U.getUser() == CoroId; + return !IsBitcast && !IsCoroId; + }); +} + // Prior to CoroSplit, calls to coro.begin needs to be marked as NoDuplicate, // as CoroSplit assumes there is exactly one coro.begin. After CoroSplit, // NoDuplicate attribute will be removed from coro.begin otherwise, it will @@ -165,6 +188,7 @@ static void setCannotDuplicate(CoroIdInst *CoroId) { void Lowerer::lowerEarlyIntrinsics(Function &F) { CoroIdInst *CoroId = nullptr; + CoroBeginInst *CoroBegin = nullptr; SmallVector CoroFrees; bool HasCoroSuspend = false; for (Instruction &I : llvm::make_early_inc_range(instructions(F))) { @@ -175,6 +199,13 @@ void Lowerer::lowerEarlyIntrinsics(Function &F) { switch (CB->getIntrinsicID()) { default: continue; + case Intrinsic::coro_begin: + case Intrinsic::coro_begin_custom_abi: + if (CoroBegin) + report_fatal_error( + "coroutine should have exactly one defining @llvm.coro.begin"); + CoroBegin = cast(&I); + break; case Intrinsic::coro_free: CoroFrees.push_back(cast(&I)); break; @@ -227,13 +258,16 @@ void Lowerer::lowerEarlyIntrinsics(Function &F) { } } - // Make sure that all CoroFree reference the coro.id intrinsic. - // Token type is not exposed through coroutine C/C++ builtins to plain C, so - // we allow specifying none and fixing it up here. 
- if (CoroId) + if (CoroId) { + // Make sure that all CoroFree reference the coro.id intrinsic. + // Token type is not exposed through coroutine C/C++ builtins to plain C, so + // we allow specifying none and fixing it up here. for (CoroFreeInst *CF : CoroFrees) CF->setArgOperand(0, CoroId); + hidePromiseAlloca(CoroId, CoroBegin); + } + // Coroutine suspention could potentially lead to any argument modified // outside of the function, hence arguments should not have noalias // attributes. diff --git a/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/llvm/lib/Transforms/Coroutines/Coroutines.cpp index 7b59c39283ded..02500ff778b80 100644 --- a/llvm/lib/Transforms/Coroutines/Coroutines.cpp +++ b/llvm/lib/Transforms/Coroutines/Coroutines.cpp @@ -192,7 +192,8 @@ static CoroSaveInst *createCoroSave(CoroBeginInst *CoroBegin, // Collect "interesting" coroutine intrinsics. void coro::Shape::analyze(Function &F, SmallVectorImpl &CoroFrames, - SmallVectorImpl &UnusedCoroSaves) { + SmallVectorImpl &UnusedCoroSaves, + CoroPromiseInst *&CoroPromise) { clear(); bool HasFinalSuspend = false; @@ -286,6 +287,11 @@ void coro::Shape::analyze(Function &F, } } break; + case Intrinsic::coro_promise: + assert(CoroPromise == nullptr && + "CoroEarly must ensure coro.promise unique"); + CoroPromise = cast(II); + break; } } } @@ -477,7 +483,7 @@ void coro::AnyRetconABI::init() { void coro::Shape::cleanCoroutine( SmallVectorImpl &CoroFrames, - SmallVectorImpl &UnusedCoroSaves) { + SmallVectorImpl &UnusedCoroSaves, CoroPromiseInst *PI) { // The coro.frame intrinsic is always lowered to the result of coro.begin. for (CoroFrameInst *CF : CoroFrames) { CF->replaceAllUsesWith(CoroBegin); @@ -489,6 +495,13 @@ void coro::Shape::cleanCoroutine( for (CoroSaveInst *CoroSave : UnusedCoroSaves) CoroSave->eraseFromParent(); UnusedCoroSaves.clear(); + + if (PI) { + PI->replaceAllUsesWith(PI->isFromPromise() + ? 
cast(CoroBegin) + : cast(getPromiseAlloca())); + PI->eraseFromParent(); + } } static void propagateCallAttrsFromCallee(CallInst *Call, Function *Callee) { diff --git a/llvm/test/Transforms/Coroutines/gh105595.ll b/llvm/test/Transforms/Coroutines/gh105595.ll new file mode 100644 index 0000000000000..0efe21216e998 --- /dev/null +++ b/llvm/test/Transforms/Coroutines/gh105595.ll @@ -0,0 +1,31 @@ +; Test that store-load operation that crosses suspension point will not be eliminated by DSE +; Coro result conversion function that attempts to modify promise shall produce this pattern +; RUN: opt < %s -passes='coro-early,dse' -S | FileCheck %s + +define void @fn() presplitcoroutine { + %__promise = alloca ptr, align 8 + %id = call token @llvm.coro.id(i32 16, ptr %__promise, ptr @fn, ptr null) + %hdl = call ptr @llvm.coro.begin(token %id, ptr null) +; CHECK: %promise.addr = call ptr @llvm.coro.promise(ptr %hdl, i32 8, i1 false) + %save = call token @llvm.coro.save(ptr null) + %sp = call i8 @llvm.coro.suspend(token %save, i1 false) + %flag = icmp ule i8 %sp, 1 + br i1 %flag, label %resume, label %suspend + +resume: +; CHECK: call void @use_value(ptr %promise.addr) + call void @use_value(ptr %__promise) + br label %suspend + +suspend: +; load value when resume +; CHECK: %null = load ptr, ptr %promise.addr, align 8 + %null = load ptr, ptr %__promise, align 8 + call void @use_value(ptr %null) +; store value when suspend +; CHECK: store ptr null, ptr %promise.addr, align 8 + store ptr null, ptr %__promise, align 8 + ret void +} + +declare void @use_value(ptr)