diff --git a/include/llvm/Transforms/IPO/OpenMPOpt.h b/include/llvm/Transforms/IPO/OpenMPOpt.h index d96187b73f9..9b72ee0afd2 100644 --- a/include/llvm/Transforms/IPO/OpenMPOpt.h +++ b/include/llvm/Transforms/IPO/OpenMPOpt.h @@ -33,6 +33,11 @@ struct OpenMPInModule { bool isKnown() { return Value != OpenMP::UNKNOWN; } operator bool() { return Value != OpenMP::NOT_FOUND; } + /// Does this function \p F contain any OpenMP runtime calls? + bool containsOMPRuntimeCalls(Function *F) const { + return FuncsWithOMPRuntimeCalls.contains(F); + } + /// Return the known kernels (=GPU entry points) in the module. SmallPtrSetImpl &getKernels() { return Kernels; } @@ -42,6 +47,11 @@ struct OpenMPInModule { private: enum class OpenMP { FOUND, NOT_FOUND, UNKNOWN } Value = OpenMP::UNKNOWN; + friend bool containsOpenMP(Module &M, OpenMPInModule &OMPInModule); + + /// In which functions are OpenMP runtime calls present? + SmallPtrSet FuncsWithOMPRuntimeCalls; + /// Collection of known kernels (=GPU entry points) in the module. SmallPtrSet Kernels; }; diff --git a/lib/Transforms/IPO/OpenMPOpt.cpp b/lib/Transforms/IPO/OpenMPOpt.cpp index f664a241737..93f1e5392eb 100644 --- a/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/lib/Transforms/IPO/OpenMPOpt.cpp @@ -1339,10 +1339,21 @@ PreservedAnalyses OpenMPOptPass::run(LazyCallGraph::SCC &C, return PreservedAnalyses::all(); SmallVector SCC; - for (LazyCallGraph::Node &N : C) - SCC.push_back(&N.getFunction()); + // If there are kernels in the module, we have to run on all SCC's. + bool SCCIsInteresting = !OMPInModule.getKernels().empty(); + for (LazyCallGraph::Node &N : C) { + Function *Fn = &N.getFunction(); + SCC.push_back(Fn); - if (SCC.empty()) + // Do we already know that the SCC contains kernels, + // or that OpenMP functions are called from this SCC? + if (SCCIsInteresting) + continue; + // If not, let's check that. + SCCIsInteresting |= OMPInModule.containsOMPRuntimeCalls(Fn); + } + + if (!SCCIsInteresting || SCC.empty()) return PreservedAnalyses::all(); FunctionAnalysisManager &FAM = @@ -1401,12 +1412,23 @@ struct OpenMPOptLegacyPass : public CallGraphSCCPass { return false; SmallVector SCC; - for (CallGraphNode *CGN : CGSCC) - if (Function *Fn = CGN->getFunction()) - if (!Fn->isDeclaration()) - SCC.push_back(Fn); + // If there are kernels in the module, we have to run on all SCC's. + bool SCCIsInteresting = !OMPInModule.getKernels().empty(); + for (CallGraphNode *CGN : CGSCC) { + Function *Fn = CGN->getFunction(); + if (!Fn || Fn->isDeclaration()) + continue; + SCC.push_back(Fn); - if (SCC.empty()) + // Do we already know that the SCC contains kernels, + // or that OpenMP functions are called from this SCC? + if (SCCIsInteresting) + continue; + // If not, let's check that. + SCCIsInteresting |= OMPInModule.containsOMPRuntimeCalls(Fn); + } + + if (!SCCIsInteresting || SCC.empty()) return false; CallGraph &CG = getAnalysis().getCallGraph(); @@ -1468,13 +1490,19 @@ bool llvm::omp::containsOpenMP(Module &M, OpenMPInModule &OMPInModule) { if (OMPInModule.isKnown()) return OMPInModule; + auto RecordFunctionsContainingUsesOf = [&](Function *F) { + for (User *U : F->users()) + if (auto *I = dyn_cast(U)) + OMPInModule.FuncsWithOMPRuntimeCalls.insert(I->getFunction()); + }; + // MSVC doesn't like long if-else chains for some reason and instead just // issues an error. Work around it.. do { #define OMP_RTL(_Enum, _Name, ...) \ - if (M.getFunction(_Name)) { \ + if (Function *F = M.getFunction(_Name)) { \ + RecordFunctionsContainingUsesOf(F); \ OMPInModule = true; \ - break; \ } #include "llvm/Frontend/OpenMP/OMPKinds.def" } while (false);