//===- llvm/unittest/CodeGen/PassManager.cpp - PassManager tests ----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "llvm/ADT/Triple.h" #include "llvm/Analysis/CGSCCPassManager.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/AsmParser/Parser.h" #include "llvm/CodeGen/MachineModuleInfo.h" #include "llvm/CodeGen/MachinePassManager.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/Passes/PassBuilder.h" #include "llvm/Support/Host.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/TargetRegistry.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" #include "gtest/gtest.h" using namespace llvm; namespace { class TestFunctionAnalysis : public AnalysisInfoMixin { public: struct Result { Result(int Count) : InstructionCount(Count) {} int InstructionCount; }; /// Run the analysis pass over the function and return a result. Result run(Function &F, FunctionAnalysisManager &AM) { int Count = 0; for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI) for (BasicBlock::iterator II = BBI->begin(), IE = BBI->end(); II != IE; ++II) ++Count; return Result(Count); } private: friend AnalysisInfoMixin; static AnalysisKey Key; }; AnalysisKey TestFunctionAnalysis::Key; class TestMachineFunctionAnalysis : public AnalysisInfoMixin { public: struct Result { Result(int Count) : InstructionCount(Count) {} int InstructionCount; }; /// Run the analysis pass over the machine function and return a result. Result run(MachineFunction &MF, MachineFunctionAnalysisManager::Base &AM) { auto &MFAM = static_cast(AM); // Query function analysis result. TestFunctionAnalysis::Result &FAR = MFAM.getResult(MF.getFunction()); // + 5 return FAR.InstructionCount; } private: friend AnalysisInfoMixin; static AnalysisKey Key; }; AnalysisKey TestMachineFunctionAnalysis::Key; const std::string DoInitErrMsg = "doInitialization failed"; const std::string DoFinalErrMsg = "doFinalization failed"; struct TestMachineFunctionPass : public PassInfoMixin { TestMachineFunctionPass(int &Count, std::vector &BeforeInitialization, std::vector &BeforeFinalization, std::vector &MachineFunctionPassCount) : Count(Count), BeforeInitialization(BeforeInitialization), BeforeFinalization(BeforeFinalization), MachineFunctionPassCount(MachineFunctionPassCount) {} Error doInitialization(Module &M, MachineFunctionAnalysisManager &MFAM) { // Force doInitialization fail by starting with big `Count`. if (Count > 10000) return make_error(DoInitErrMsg, inconvertibleErrorCode()); // + 1 ++Count; BeforeInitialization.push_back(Count); return Error::success(); } Error doFinalization(Module &M, MachineFunctionAnalysisManager &MFAM) { // Force doFinalization fail by starting with big `Count`. if (Count > 1000) return make_error(DoFinalErrMsg, inconvertibleErrorCode()); // + 1 ++Count; BeforeFinalization.push_back(Count); return Error::success(); } PreservedAnalyses run(MachineFunction &MF, MachineFunctionAnalysisManager &MFAM) { // Query function analysis result. TestFunctionAnalysis::Result &FAR = MFAM.getResult(MF.getFunction()); // 3 + 1 + 1 = 5 Count += FAR.InstructionCount; // Query module analysis result. MachineModuleInfo &MMI = MFAM.getResult(*MF.getFunction().getParent()); // 1 + 1 + 1 = 3 Count += (MMI.getModule() == MF.getFunction().getParent()); // Query machine function analysis result. TestMachineFunctionAnalysis::Result &MFAR = MFAM.getResult(MF); // 3 + 1 + 1 = 5 Count += MFAR.InstructionCount; MachineFunctionPassCount.push_back(Count); return PreservedAnalyses::none(); } int &Count; std::vector &BeforeInitialization; std::vector &BeforeFinalization; std::vector &MachineFunctionPassCount; }; struct TestMachineModulePass : public PassInfoMixin { TestMachineModulePass(int &Count, std::vector &MachineModulePassCount) : Count(Count), MachineModulePassCount(MachineModulePassCount) {} Error run(Module &M, MachineFunctionAnalysisManager &MFAM) { MachineModuleInfo &MMI = MFAM.getResult(M); // + 1 Count += (MMI.getModule() == &M); MachineModulePassCount.push_back(Count); return Error::success(); } PreservedAnalyses run(MachineFunction &MF, MachineFunctionAnalysisManager &AM) { llvm_unreachable( "This should never be reached because this is machine module pass"); } int &Count; std::vector &MachineModulePassCount; }; std::unique_ptr parseIR(LLVMContext &Context, const char *IR) { SMDiagnostic Err; return parseAssemblyString(IR, Err, Context); } class PassManagerTest : public ::testing::Test { protected: LLVMContext Context; std::unique_ptr M; std::unique_ptr TM; public: PassManagerTest() : M(parseIR(Context, "define void @f() {\n" "entry:\n" " call void @g()\n" " call void @h()\n" " ret void\n" "}\n" "define void @g() {\n" " ret void\n" "}\n" "define void @h() {\n" " ret void\n" "}\n")) { // MachineModuleAnalysis needs a TargetMachine instance. llvm::InitializeAllTargets(); std::string TripleName = Triple::normalize(sys::getDefaultTargetTriple()); std::string Error; const Target *TheTarget = TargetRegistry::lookupTarget(TripleName, Error); if (!TheTarget) return; TargetOptions Options; TM.reset(TheTarget->createTargetMachine(TripleName, "", "", Options, None)); } }; TEST_F(PassManagerTest, Basic) { if (!TM) GTEST_SKIP(); LLVMTargetMachine *LLVMTM = static_cast(TM.get()); M->setDataLayout(TM->createDataLayout()); LoopAnalysisManager LAM; FunctionAnalysisManager FAM; CGSCCAnalysisManager CGAM; ModuleAnalysisManager MAM; PassBuilder PB(TM.get()); PB.registerModuleAnalyses(MAM); PB.registerFunctionAnalyses(FAM); PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); FAM.registerPass([&] { return TestFunctionAnalysis(); }); FAM.registerPass([&] { return PassInstrumentationAnalysis(); }); MAM.registerPass([&] { return MachineModuleAnalysis(LLVMTM); }); MAM.registerPass([&] { return PassInstrumentationAnalysis(); }); MachineFunctionAnalysisManager MFAM; { // Test move assignment. MachineFunctionAnalysisManager NestedMFAM(FAM, MAM); NestedMFAM.registerPass([&] { return PassInstrumentationAnalysis(); }); NestedMFAM.registerPass([&] { return TestMachineFunctionAnalysis(); }); MFAM = std::move(NestedMFAM); } int Count = 0; std::vector BeforeInitialization[2]; std::vector BeforeFinalization[2]; std::vector TestMachineFunctionCount[2]; std::vector TestMachineModuleCount[2]; MachineFunctionPassManager MFPM; { // Test move assignment. MachineFunctionPassManager NestedMFPM; NestedMFPM.addPass(TestMachineModulePass(Count, TestMachineModuleCount[0])); NestedMFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[0], BeforeFinalization[0], TestMachineFunctionCount[0])); NestedMFPM.addPass(TestMachineModulePass(Count, TestMachineModuleCount[1])); NestedMFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[1], BeforeFinalization[1], TestMachineFunctionCount[1])); MFPM = std::move(NestedMFPM); } ASSERT_FALSE(errorToBool(MFPM.run(*M, MFAM))); // Check first machine module pass EXPECT_EQ(1u, TestMachineModuleCount[0].size()); EXPECT_EQ(3, TestMachineModuleCount[0][0]); // Check first machine function pass EXPECT_EQ(1u, BeforeInitialization[0].size()); EXPECT_EQ(1, BeforeInitialization[0][0]); EXPECT_EQ(3u, TestMachineFunctionCount[0].size()); EXPECT_EQ(10, TestMachineFunctionCount[0][0]); EXPECT_EQ(13, TestMachineFunctionCount[0][1]); EXPECT_EQ(16, TestMachineFunctionCount[0][2]); EXPECT_EQ(1u, BeforeFinalization[0].size()); EXPECT_EQ(31, BeforeFinalization[0][0]); // Check second machine module pass EXPECT_EQ(1u, TestMachineModuleCount[1].size()); EXPECT_EQ(17, TestMachineModuleCount[1][0]); // Check second machine function pass EXPECT_EQ(1u, BeforeInitialization[1].size()); EXPECT_EQ(2, BeforeInitialization[1][0]); EXPECT_EQ(3u, TestMachineFunctionCount[1].size()); EXPECT_EQ(24, TestMachineFunctionCount[1][0]); EXPECT_EQ(27, TestMachineFunctionCount[1][1]); EXPECT_EQ(30, TestMachineFunctionCount[1][2]); EXPECT_EQ(1u, BeforeFinalization[1].size()); EXPECT_EQ(32, BeforeFinalization[1][0]); EXPECT_EQ(32, Count); // doInitialization returns error Count = 10000; MFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[1], BeforeFinalization[1], TestMachineFunctionCount[1])); std::string Message; llvm::handleAllErrors(MFPM.run(*M, MFAM), [&](llvm::StringError &Error) { Message = Error.getMessage(); }); EXPECT_EQ(Message, DoInitErrMsg); // doFinalization returns error Count = 1000; MFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[1], BeforeFinalization[1], TestMachineFunctionCount[1])); llvm::handleAllErrors(MFPM.run(*M, MFAM), [&](llvm::StringError &Error) { Message = Error.getMessage(); }); EXPECT_EQ(Message, DoFinalErrMsg); } } // namespace