From f68a31778f2b7c503d583c4645de00d66cb82334 Mon Sep 17 00:00:00 2001 From: Sergey Dmitriev Date: Thu, 30 Apr 2020 14:08:35 -0700 Subject: [PATCH] [AbstractCallSite] Look though constant cast expression when checking for callee use Summary: That makes AbstractCallSite::isCallee(const Use *) behavior consistent with AbstractCallSite constructor. Reviewers: jdoerfert Reviewed By: jdoerfert Subscribers: mgorny, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D79188 --- include/llvm/IR/AbstractCallSite.h | 6 +++ unittests/IR/AbstractCallSiteTest.cpp | 55 +++++++++++++++++++++++++++ unittests/IR/CMakeLists.txt | 1 + 3 files changed, 62 insertions(+) create mode 100644 unittests/IR/AbstractCallSiteTest.cpp diff --git a/include/llvm/IR/AbstractCallSite.h b/include/llvm/IR/AbstractCallSite.h index 18c0db8fa49..559da857851 100644 --- a/include/llvm/IR/AbstractCallSite.h +++ b/include/llvm/IR/AbstractCallSite.h @@ -141,6 +141,12 @@ public: assert(!CI.ParameterEncoding.empty() && "Callback without parameter encoding!"); + // If the use is actually in a constant cast expression which itself + // has only one use, we look through the constant cast expression. + if (auto *CE = dyn_cast(U->getUser())) + if (CE->getNumUses() == 1 && CE->isCast()) + U = &*CE->use_begin(); + return (int)CB->getArgOperandNo(U) == CI.ParameterEncoding[0]; } diff --git a/unittests/IR/AbstractCallSiteTest.cpp b/unittests/IR/AbstractCallSiteTest.cpp new file mode 100644 index 00000000000..ddb10911ad0 --- /dev/null +++ b/unittests/IR/AbstractCallSiteTest.cpp @@ -0,0 +1,55 @@ +//===----- AbstractCallSiteTest.cpp - AbstractCallSite Unittests ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/AsmParser/Parser.h" +#include "llvm/IR/AbstractCallSite.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/SourceMgr.h" +#include "gtest/gtest.h" + +using namespace llvm; + +static std::unique_ptr parseIR(LLVMContext &C, const char *IR) { + SMDiagnostic Err; + std::unique_ptr Mod = parseAssemblyString(IR, Err, C); + if (!Mod) + Err.print("AbstractCallSiteTests", errs()); + return Mod; +} + +TEST(AbstractCallSite, CallbackCall) { + LLVMContext C; + + const char *IR = + "define void @callback(i8* %X, i32* %A) {\n" + " ret void\n" + "}\n" + "define void @foo(i32* %A) {\n" + " call void (i32, void (i8*, ...)*, ...) @broker(i32 1, void (i8*, ...)* bitcast (void (i8*, i32*)* @callback to void (i8*, ...)*), i32* %A)\n" + " ret void\n" + "}\n" + "declare !callback !0 void @broker(i32, void (i8*, ...)*, ...)\n" + "!0 = !{!1}\n" + "!1 = !{i64 1, i64 -1, i1 true}"; + + std::unique_ptr M = parseIR(C, IR); + ASSERT_TRUE(M); + + Function *Callback = M->getFunction("callback"); + ASSERT_NE(Callback, nullptr); + + const Use *CallbackUse = Callback->getSingleUndroppableUse(); + ASSERT_NE(CallbackUse, nullptr); + + AbstractCallSite ACS(CallbackUse); + EXPECT_TRUE(ACS); + EXPECT_TRUE(ACS.isCallbackCall()); + EXPECT_TRUE(ACS.isCallee(CallbackUse)); + EXPECT_EQ(ACS.getCalledFunction(), Callback); +} diff --git a/unittests/IR/CMakeLists.txt b/unittests/IR/CMakeLists.txt index 4241851dfad..4634bf89059 100644 --- a/unittests/IR/CMakeLists.txt +++ b/unittests/IR/CMakeLists.txt @@ -8,6 +8,7 @@ set(LLVM_LINK_COMPONENTS ) add_llvm_unittest(IRTests + AbstractCallSiteTest.cpp AsmWriterTest.cpp AttributesTest.cpp BasicBlockTest.cpp