diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index 436cdbff75669..b42f0ca296fc5 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -2643,13 +2643,14 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // (sign|resign) + (auth|resign) can be folded by omitting the middle // sign+auth component if the key and discriminator match. bool NeedSign = II->getIntrinsicID() == Intrinsic::ptrauth_resign; + Value *Ptr = II->getArgOperand(0); Value *Key = II->getArgOperand(1); Value *Disc = II->getArgOperand(2); // AuthKey will be the key we need to end up authenticating against in // whatever we replace this sequence with. Value *AuthKey = nullptr, *AuthDisc = nullptr, *BasePtr; - if (auto CI = dyn_cast(II->getArgOperand(0))) { + if (const auto *CI = dyn_cast(Ptr)) { BasePtr = CI->getArgOperand(0); if (CI->getIntrinsicID() == Intrinsic::ptrauth_sign) { if (CI->getArgOperand(1) != Key || CI->getArgOperand(2) != Disc) @@ -2661,6 +2662,27 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { AuthDisc = CI->getArgOperand(2); } else break; + } else if (const auto *PtrToInt = dyn_cast(Ptr)) { + // ptrauth constants are equivalent to a call to @llvm.ptrauth.sign for + // our purposes, so check for that too. + const auto *CPA = dyn_cast(PtrToInt->getOperand(0)); + if (!CPA || !CPA->isKnownCompatibleWith(Key, Disc, DL)) + break; + + // resign(ptrauth(p,ks,ds),ks,ds,kr,dr) -> ptrauth(p,kr,dr) + if (NeedSign && isa(II->getArgOperand(4))) { + auto *SignKey = cast(II->getArgOperand(3)); + auto *SignDisc = cast(II->getArgOperand(4)); + auto *SignAddrDisc = ConstantPointerNull::get(Builder.getPtrTy()); + auto *NewCPA = ConstantPtrAuth::get(CPA->getPointer(), SignKey, + SignDisc, SignAddrDisc); + replaceInstUsesWith( + *II, ConstantExpr::getPointerCast(NewCPA, II->getType())); + return eraseInstFromFunction(*II); + } + + // auth(ptrauth(p,k,d),k,d) -> p + BasePtr = Builder.CreatePtrToInt(CPA->getPointer(), II->getType()); } else break; @@ -2677,8 +2699,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } else { // sign(0) + auth(0) = nop replaceInstUsesWith(*II, BasePtr); - eraseInstFromFunction(*II); - return nullptr; + return eraseInstFromFunction(*II); } SmallVector CallArgs; diff --git a/llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll b/llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll index da0f724abfde4..208e162ac9416 100644 --- a/llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll +++ b/llvm/test/Transforms/InstCombine/ptrauth-intrinsics.ll @@ -12,6 +12,27 @@ define i64 @test_ptrauth_nop(ptr %p) { ret i64 %authed } +declare void @foo() +declare void @bar() + +define i64 @test_ptrauth_nop_constant() { +; CHECK-LABEL: @test_ptrauth_nop_constant( +; CHECK-NEXT: ret i64 ptrtoint (ptr @foo to i64) +; + %authed = call i64 @llvm.ptrauth.auth(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234) to i64), i32 1, i64 1234) + ret i64 %authed +} + +define i64 @test_ptrauth_nop_constant_addrdisc() { +; CHECK-LABEL: @test_ptrauth_nop_constant_addrdisc( +; CHECK-NEXT: ret i64 ptrtoint (ptr @foo to i64) +; + %addr = ptrtoint ptr @foo to i64 + %blended = call i64 @llvm.ptrauth.blend(i64 %addr, i64 1234) + %authed = call i64 @llvm.ptrauth.auth(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234, ptr @foo) to i64), i32 1, i64 %blended) + ret i64 %authed +} + define i64 @test_ptrauth_nop_mismatch(ptr %p) { ; CHECK-LABEL: @test_ptrauth_nop_mismatch( ; CHECK-NEXT: [[TMP0:%.*]] = ptrtoint ptr [[P:%.*]] to i64 @@ -87,6 +108,59 @@ define i64 @test_ptrauth_resign_auth_mismatch(ptr %p) { ret i64 %authed } +define i64 @test_ptrauth_nop_constant_mismatch() { +; CHECK-LABEL: @test_ptrauth_nop_constant_mismatch( +; CHECK-NEXT: [[AUTHED:%.*]] = call i64 @llvm.ptrauth.auth(i64 ptrtoint (ptr ptrauth (ptr @foo, i32 1, i64 1234) to i64), i32 1, i64 12) +; CHECK-NEXT: ret i64 [[AUTHED]] +; + %authed = call i64 @llvm.ptrauth.auth(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234) to i64), i32 1, i64 12) + ret i64 %authed +} + +define i64 @test_ptrauth_nop_constant_mismatch_key() { +; CHECK-LABEL: @test_ptrauth_nop_constant_mismatch_key( +; CHECK-NEXT: [[AUTHED:%.*]] = call i64 @llvm.ptrauth.auth(i64 ptrtoint (ptr ptrauth (ptr @foo, i32 1, i64 1234) to i64), i32 0, i64 1234) +; CHECK-NEXT: ret i64 [[AUTHED]] +; + %authed = call i64 @llvm.ptrauth.auth(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234) to i64), i32 0, i64 1234) + ret i64 %authed +} + +define i64 @test_ptrauth_nop_constant_addrdisc_mismatch() { +; CHECK-LABEL: @test_ptrauth_nop_constant_addrdisc_mismatch( +; CHECK-NEXT: [[BLENDED:%.*]] = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @foo to i64), i64 12) +; CHECK-NEXT: [[AUTHED:%.*]] = call i64 @llvm.ptrauth.auth(i64 ptrtoint (ptr ptrauth (ptr @foo, i32 1, i64 1234, ptr @foo) to i64), i32 1, i64 [[BLENDED]]) +; CHECK-NEXT: ret i64 [[AUTHED]] +; + %addr = ptrtoint ptr @foo to i64 + %blended = call i64 @llvm.ptrauth.blend(i64 %addr, i64 12) + %authed = call i64 @llvm.ptrauth.auth(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234, ptr @foo) to i64), i32 1, i64 %blended) + ret i64 %authed +} + +define i64 @test_ptrauth_nop_constant_addrdisc_mismatch2() { +; CHECK-LABEL: @test_ptrauth_nop_constant_addrdisc_mismatch2( +; CHECK-NEXT: [[BLENDED:%.*]] = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @bar to i64), i64 1234) +; CHECK-NEXT: [[AUTHED:%.*]] = call i64 @llvm.ptrauth.auth(i64 ptrtoint (ptr ptrauth (ptr @foo, i32 1, i64 1234, ptr @foo) to i64), i32 1, i64 [[BLENDED]]) +; CHECK-NEXT: ret i64 [[AUTHED]] +; + %addr = ptrtoint ptr @bar to i64 + %blended = call i64 @llvm.ptrauth.blend(i64 %addr, i64 1234) + %authed = call i64 @llvm.ptrauth.auth(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234, ptr @foo) to i64), i32 1, i64 %blended) + ret i64 %authed +} + +define i64 @test_ptrauth_resign_ptrauth_constant(ptr %p) { +; CHECK-LABEL: @test_ptrauth_resign_ptrauth_constant( +; CHECK-NEXT: ret i64 ptrtoint (ptr ptrauth (ptr @foo, i32 0, i64 42) to i64) +; + + %tmp0 = ptrtoint ptr %p to i64 + %authed = call i64 @llvm.ptrauth.resign(i64 ptrtoint(ptr ptrauth(ptr @foo, i32 1, i64 1234) to i64), i32 1, i64 1234, i32 0, i64 42) + ret i64 %authed +} + declare i64 @llvm.ptrauth.auth(i64, i32, i64) declare i64 @llvm.ptrauth.sign(i64, i32, i64) declare i64 @llvm.ptrauth.resign(i64, i32, i64, i32, i64) +declare i64 @llvm.ptrauth.blend(i64, i64)