diff --git a/include/llvm/Frontend/OpenMP/OMPConstants.h b/include/llvm/Frontend/OpenMP/OMPConstants.h index a95628c056e..7cda3197473 100644 --- a/include/llvm/Frontend/OpenMP/OMPConstants.h +++ b/include/llvm/Frontend/OpenMP/OMPConstants.h @@ -20,6 +20,7 @@ namespace llvm { class Type; class Module; +class ArrayType; class StructType; class PointerType; class FunctionType; @@ -95,6 +96,9 @@ StringRef getOpenMPDirectiveName(Directive D); namespace types { #define OMP_TYPE(VarName, InitValue) extern Type *VarName; +#define OMP_ARRAY_TYPE(VarName, ElemTy, ArraySize) \ + extern ArrayType *VarName##Ty; \ + extern PointerType *VarName##PtrTy; #define OMP_FUNCTION_TYPE(VarName, IsVarArg, ReturnType, ...) \ extern FunctionType *VarName; \ extern PointerType *VarName##Ptr; diff --git a/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index d0e9e40370b..6979cacdf78 100644 --- a/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -14,9 +14,10 @@ #ifndef LLVM_OPENMP_IR_IRBUILDER_H #define LLVM_OPENMP_IR_IRBUILDER_H +#include "llvm/Frontend/OpenMP/OMPConstants.h" #include "llvm/IR/DebugLoc.h" #include "llvm/IR/IRBuilder.h" -#include "llvm/Frontend/OpenMP/OMPConstants.h" +#include "llvm/Support/Allocator.h" namespace llvm { @@ -149,9 +150,8 @@ public: /// \param CanceledDirective The kind of directive that is cancled. /// /// \returns The insertion point after the barrier. - InsertPointTy CreateCancel(const LocationDescription &Loc, - Value *IfCondition, - omp::Directive CanceledDirective); + InsertPointTy CreateCancel(const LocationDescription &Loc, Value *IfCondition, + omp::Directive CanceledDirective); /// Generator for '#omp parallel' /// @@ -171,7 +171,6 @@ public: Value *IfCondition, Value *NumThreads, omp::ProcBindKind ProcBind, bool IsCancellable); - /// Generator for '#omp flush' /// /// \param Loc The location where the flush directive was encountered @@ -180,16 +179,15 @@ public: /// Generator for '#omp taskwait' /// /// \param Loc The location where the taskwait directive was encountered. - void CreateTaskwait(const LocationDescription& Loc); + void CreateTaskwait(const LocationDescription &Loc); /// Generator for '#omp taskyield' /// /// \param Loc The location where the taskyield directive was encountered. - void CreateTaskyield(const LocationDescription& Loc); + void CreateTaskyield(const LocationDescription &Loc); ///} - private: /// Update the internal location to \p Loc. bool updateToLocation(const LocationDescription &Loc) { @@ -292,6 +290,119 @@ private: /// Add a new region that will be outlined later. void addOutlineInfo(OutlineInfo &&OI) { OutlineInfos.emplace_back(OI); } + + /// An ordered map of auto-generated variables to their unique names. + /// It stores variables with the following names: 1) ".gomp_critical_user_" + + /// + ".var" for "omp critical" directives; 2) + /// + ".cache." for cache for threadprivate + /// variables. + StringMap, BumpPtrAllocator> InternalVars; + +public: + /// Generator for '#omp master' + /// + /// \param Loc The insert and source location description. + /// \param BodyGenCB Callback that will generate the region code. + /// \param FiniCB Callback to finalize variable copies. + /// + /// \returns The insertion position *after* the master. + InsertPointTy CreateMaster(const LocationDescription &Loc, + BodyGenCallbackTy BodyGenCB, + FinalizeCallbackTy FiniCB); + + /// Generator for '#omp master' + /// + /// \param Loc The insert and source location description. + /// \param BodyGenCB Callback that will generate the region body code. + /// \param FiniCB Callback to finalize variable copies. + /// \param CriticalName name of the lock used by the critical directive + /// \param HintInst Hint Instruction for hint clause associated with critical + /// + /// \returns The insertion position *after* the master. + InsertPointTy CreateCritical(const LocationDescription &Loc, + BodyGenCallbackTy BodyGenCB, + FinalizeCallbackTy FiniCB, + StringRef CriticalName, Value *HintInst); + +private: + /// Common interface for generating entry calls for OMP Directives. + /// if the directive has a region/body, It will set the insertion + /// point to the body + /// + /// \param OMPD Directive to generate entry blocks for + /// \param EntryCall Call to the entry OMP Runtime Function + /// \param ExitBB block where the region ends. + /// \param Conditional indicate if the entry call result will be used + /// to evaluate a conditional of whether a thread will execute + /// body code or not. + /// + /// \return The insertion position in exit block + InsertPointTy emitCommonDirectiveEntry(omp::Directive OMPD, Value *EntryCall, + BasicBlock *ExitBB, + bool Conditional = false); + + /// Common interface to finalize the region + /// + /// \param OMPD Directive to generate exiting code for + /// \param FinIP Insertion point for emitting Finalization code and exit call + /// \param ExitCall Call to the ending OMP Runtime Function + /// \param HasFinalize indicate if the directive will require finalization + /// and has a finalization callback in the stack that + /// should be called. + /// + /// \return The insertion position in exit block + InsertPointTy emitCommonDirectiveExit(omp::Directive OMPD, + InsertPointTy FinIP, + Instruction *ExitCall, + bool HasFinalize = true); + + /// Common Interface to generate OMP inlined regions + /// + /// \param OMPD Directive to generate inlined region for + /// \param EntryCall Call to the entry OMP Runtime Function + /// \param ExitCall Call to the ending OMP Runtime Function + /// \param BodyGenCB Body code generation callback. + /// \param FiniCB Finalization Callback. Will be called when finalizing region + /// \param Conditional indicate if the entry call result will be used + /// to evaluate a conditional of whether a thread will execute + /// body code or not. + /// \param HasFinalize indicate if the directive will require finalization + /// and has a finalization callback in the stack that + /// should be called. + /// + /// \return The insertion point after the region + + InsertPointTy + EmitOMPInlinedRegion(omp::Directive OMPD, Instruction *EntryCall, + Instruction *ExitCall, BodyGenCallbackTy BodyGenCB, + FinalizeCallbackTy FiniCB, bool Conditional = false, + bool HasFinalize = true); + + /// Get the platform-specific name separator. + /// \param Parts different parts of the final name that needs separation + /// \param FirstSeparator First separator used between the initial two + /// parts of the name. + /// \param Separator separator used between all of the rest consecutinve + /// parts of the name + static std::string getNameWithSeparators(ArrayRef Parts, + StringRef FirstSeparator, + StringRef Separator); + + /// Gets (if variable with the given name already exist) or creates + /// internal global variable with the specified Name. The created variable has + /// linkage CommonLinkage by default and is initialized by null value. + /// \param Ty Type of the global variable. If it is exist already the type + /// must be the same. + /// \param Name Name of the variable. + Constant *getOrCreateOMPInternalVariable(Type *Ty, const Twine &Name, + unsigned AddressSpace = 0); + + /// Returns corresponding lock object for the specified critical region + /// name. If the lock object does not exist it is created, otherwise the + /// reference to the existing copy is returned. + /// \param CriticalName Name of the critical region. + /// + Value *getOMPCriticalRegionLock(StringRef CriticalName); }; } // end namespace llvm diff --git a/include/llvm/Frontend/OpenMP/OMPKinds.def b/include/llvm/Frontend/OpenMP/OMPKinds.def index b235a1e0574..43baf14e9bc 100644 --- a/include/llvm/Frontend/OpenMP/OMPKinds.def +++ b/include/llvm/Frontend/OpenMP/OMPKinds.def @@ -122,6 +122,24 @@ __OMP_TYPE(Int32Ptr) ///} +/// array types +/// +///{ + +#ifndef OMP_ARRAY_TYPE +#define OMP_ARRAY_TYPE(VarName, ElemTy, ArraySize) +#endif + +#define __OMP_ARRAY_TYPE(VarName, ElemTy, ArraySize) \ + OMP_ARRAY_TYPE(VarName, ElemTy, ArraySize) + +__OMP_ARRAY_TYPE(KmpCriticalName, Int32, 8) + +#undef __OMP_ARRAY_TYPE +#undef OMP_ARRAY_TYPE + +///} + /// Struct and function types /// ///{ @@ -209,6 +227,12 @@ __OMP_RTL(omp_set_max_active_levels, false, Void, Int32) __OMP_RTL(__last, false, Void, ) +__OMP_RTL(__kmpc_master, false, Int32, IdentPtr, Int32) +__OMP_RTL(__kmpc_end_master, false, Void, IdentPtr, Int32) +__OMP_RTL(__kmpc_critical, false, Void, IdentPtr, Int32, KmpCriticalNamePtrTy) +__OMP_RTL(__kmpc_critical_with_hint, false, Void, IdentPtr, Int32, KmpCriticalNamePtrTy, Int32) +__OMP_RTL(__kmpc_end_critical, false, Void, IdentPtr, Int32, KmpCriticalNamePtrTy) + #undef __OMP_RTL #undef OMP_RTL diff --git a/lib/Frontend/OpenMP/OMPConstants.cpp b/lib/Frontend/OpenMP/OMPConstants.cpp index ec0733903e9..6ee44958d1c 100644 --- a/lib/Frontend/OpenMP/OMPConstants.cpp +++ b/lib/Frontend/OpenMP/OMPConstants.cpp @@ -36,14 +36,16 @@ StringRef llvm::omp::getOpenMPDirectiveName(Directive Kind) { llvm_unreachable("Invalid OpenMP directive kind"); } -/// Declarations for LLVM-IR types (simple, function and structure) are +/// Declarations for LLVM-IR types (simple, array, function and structure) are /// generated below. Their names are defined and used in OpenMPKinds.def. Here /// we provide the declarations, the initializeTypes function will provide the /// values. /// ///{ - #define OMP_TYPE(VarName, InitValue) Type *llvm::omp::types::VarName = nullptr; +#define OMP_ARRAY_TYPE(VarName, ElemTy, ArraySize) \ + ArrayType *llvm::omp::types::VarName##Ty = nullptr; \ + PointerType *llvm::omp::types::VarName##PtrTy = nullptr; #define OMP_FUNCTION_TYPE(VarName, IsVarArg, ReturnType, ...) \ FunctionType *llvm::omp::types::VarName = nullptr; \ PointerType *llvm::omp::types::VarName##Ptr = nullptr; @@ -63,6 +65,9 @@ void llvm::omp::types::initializeTypes(Module &M) { // the llvm::PointerTypes of them for easy access later. StructType *T; #define OMP_TYPE(VarName, InitValue) VarName = InitValue; +#define OMP_ARRAY_TYPE(VarName, ElemTy, ArraySize) \ + VarName##Ty = ArrayType::get(ElemTy, ArraySize); \ + VarName##PtrTy = PointerType::getUnqual(VarName##Ty); #define OMP_FUNCTION_TYPE(VarName, IsVarArg, ReturnType, ...) \ VarName = FunctionType::get(ReturnType, {__VA_ARGS__}, IsVarArg); \ VarName##Ptr = PointerType::getUnqual(VarName); diff --git a/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/lib/Frontend/OpenMP/OMPIRBuilder.cpp index b011a3ee9b9..5706e7c7527 100644 --- a/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -18,8 +18,8 @@ #include "llvm/ADT/StringSwitch.h" #include "llvm/IR/CFG.h" #include "llvm/IR/DebugInfo.h" -#include "llvm/IR/MDBuilder.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Error.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -676,8 +676,7 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::CreateParallel( return AfterIP; } -void OpenMPIRBuilder::emitFlush(const LocationDescription &Loc) -{ +void OpenMPIRBuilder::emitFlush(const LocationDescription &Loc) { // Build call void __kmpc_flush(ident_t *loc) Constant *SrcLocStr = getOrCreateSrcLocStr(Loc); Value *Args[] = {getOrCreateIdent(SrcLocStr)}; @@ -685,10 +684,9 @@ void OpenMPIRBuilder::emitFlush(const LocationDescription &Loc) Builder.CreateCall(getOrCreateRuntimeFunction(OMPRTL___kmpc_flush), Args); } -void OpenMPIRBuilder::CreateFlush(const LocationDescription &Loc) -{ +void OpenMPIRBuilder::CreateFlush(const LocationDescription &Loc) { if (!updateToLocation(Loc)) - return; + return; emitFlush(Loc); } @@ -726,3 +724,239 @@ void OpenMPIRBuilder::CreateTaskyield(const LocationDescription &Loc) { return; emitTaskyieldImpl(Loc); } + +OpenMPIRBuilder::InsertPointTy +OpenMPIRBuilder::CreateMaster(const LocationDescription &Loc, + BodyGenCallbackTy BodyGenCB, + FinalizeCallbackTy FiniCB) { + + if (!updateToLocation(Loc)) + return Loc.IP; + + Directive OMPD = Directive::OMPD_master; + Constant *SrcLocStr = getOrCreateSrcLocStr(Loc); + Value *Ident = getOrCreateIdent(SrcLocStr); + Value *ThreadId = getOrCreateThreadID(Ident); + Value *Args[] = {Ident, ThreadId}; + + Function *EntryRTLFn = getOrCreateRuntimeFunction(OMPRTL___kmpc_master); + Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args); + + Function *ExitRTLFn = getOrCreateRuntimeFunction(OMPRTL___kmpc_end_master); + Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args); + + return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB, + /*Conditional*/ true, /*hasFinalize*/ true); +} + +OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::CreateCritical( + const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB, + FinalizeCallbackTy FiniCB, StringRef CriticalName, Value *HintInst) { + + if (!updateToLocation(Loc)) + return Loc.IP; + + Directive OMPD = Directive::OMPD_critical; + Constant *SrcLocStr = getOrCreateSrcLocStr(Loc); + Value *Ident = getOrCreateIdent(SrcLocStr); + Value *ThreadId = getOrCreateThreadID(Ident); + Value *LockVar = getOMPCriticalRegionLock(CriticalName); + Value *Args[] = {Ident, ThreadId, LockVar}; + + SmallVector EnterArgs(std::begin(Args), std::end(Args)); + Function *RTFn = nullptr; + if (HintInst) { + // Add Hint to entry Args and create call + EnterArgs.push_back(HintInst); + RTFn = getOrCreateRuntimeFunction(OMPRTL___kmpc_critical_with_hint); + } else { + RTFn = getOrCreateRuntimeFunction(OMPRTL___kmpc_critical); + } + Instruction *EntryCall = Builder.CreateCall(RTFn, EnterArgs); + + Function *ExitRTLFn = getOrCreateRuntimeFunction(OMPRTL___kmpc_end_critical); + Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args); + + return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB, + /*Conditional*/ false, /*hasFinalize*/ true); +} + +OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::EmitOMPInlinedRegion( + Directive OMPD, Instruction *EntryCall, Instruction *ExitCall, + BodyGenCallbackTy BodyGenCB, FinalizeCallbackTy FiniCB, bool Conditional, + bool HasFinalize) { + + if (HasFinalize) + FinalizationStack.push_back({FiniCB, OMPD, /*IsCancellable*/ false}); + + // Create inlined region's entry and body blocks, in preparation + // for conditional creation + BasicBlock *EntryBB = Builder.GetInsertBlock(); + Instruction *SplitPos = EntryBB->getTerminator(); + if (!isa_and_nonnull(SplitPos)) + SplitPos = new UnreachableInst(Builder.getContext(), EntryBB); + BasicBlock *ExitBB = EntryBB->splitBasicBlock(SplitPos, "omp_region.end"); + BasicBlock *FiniBB = + EntryBB->splitBasicBlock(EntryBB->getTerminator(), "omp_region.finalize"); + + Builder.SetInsertPoint(EntryBB->getTerminator()); + emitCommonDirectiveEntry(OMPD, EntryCall, ExitBB, Conditional); + + // generate body + BodyGenCB(/* AllocaIP */ InsertPointTy(), + /* CodeGenIP */ Builder.saveIP(), *FiniBB); + + // If we didn't emit a branch to FiniBB during body generation, it means + // FiniBB is unreachable (e.g. while(1);). stop generating all the + // unreachable blocks, and remove anything we are not going to use. + auto SkipEmittingRegion = FiniBB->hasNPredecessors(0); + if (SkipEmittingRegion) { + FiniBB->eraseFromParent(); + ExitCall->eraseFromParent(); + // Discard finalization if we have it. + if (HasFinalize) { + assert(!FinalizationStack.empty() && + "Unexpected finalization stack state!"); + FinalizationStack.pop_back(); + } + } else { + // emit exit call and do any needed finalization. + auto FinIP = InsertPointTy(FiniBB, FiniBB->getFirstInsertionPt()); + assert(FiniBB->getTerminator()->getNumSuccessors() == 1 && + FiniBB->getTerminator()->getSuccessor(0) == ExitBB && + "Unexpected control flow graph state!!"); + emitCommonDirectiveExit(OMPD, FinIP, ExitCall, HasFinalize); + assert(FiniBB->getUniquePredecessor()->getUniqueSuccessor() == FiniBB && + "Unexpected Control Flow State!"); + MergeBlockIntoPredecessor(FiniBB); + } + + // If we are skipping the region of a non conditional, remove the exit + // block, and clear the builder's insertion point. + assert(SplitPos->getParent() == ExitBB && + "Unexpected Insertion point location!"); + if (!Conditional && SkipEmittingRegion) { + ExitBB->eraseFromParent(); + Builder.ClearInsertionPoint(); + } else { + auto merged = MergeBlockIntoPredecessor(ExitBB); + BasicBlock *ExitPredBB = SplitPos->getParent(); + auto InsertBB = merged ? ExitPredBB : ExitBB; + if (!isa_and_nonnull(SplitPos)) + SplitPos->eraseFromParent(); + Builder.SetInsertPoint(InsertBB); + } + + return Builder.saveIP(); +} + +OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveEntry( + Directive OMPD, Value *EntryCall, BasicBlock *ExitBB, bool Conditional) { + + // if nothing to do, Return current insertion point. + if (!Conditional) + return Builder.saveIP(); + + BasicBlock *EntryBB = Builder.GetInsertBlock(); + Value *CallBool = Builder.CreateIsNotNull(EntryCall); + auto *ThenBB = BasicBlock::Create(M.getContext(), "omp_region.body"); + auto *UI = new UnreachableInst(Builder.getContext(), ThenBB); + + // Emit thenBB and set the Builder's insertion point there for + // body generation next. Place the block after the current block. + Function *CurFn = EntryBB->getParent(); + CurFn->getBasicBlockList().insertAfter(EntryBB->getIterator(), ThenBB); + + // Move Entry branch to end of ThenBB, and replace with conditional + // branch (If-stmt) + Instruction *EntryBBTI = EntryBB->getTerminator(); + Builder.CreateCondBr(CallBool, ThenBB, ExitBB); + EntryBBTI->removeFromParent(); + Builder.SetInsertPoint(UI); + Builder.Insert(EntryBBTI); + UI->eraseFromParent(); + Builder.SetInsertPoint(ThenBB->getTerminator()); + + // return an insertion point to ExitBB. + return IRBuilder<>::InsertPoint(ExitBB, ExitBB->getFirstInsertionPt()); +} + +OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveExit( + omp::Directive OMPD, InsertPointTy FinIP, Instruction *ExitCall, + bool HasFinalize) { + + Builder.restoreIP(FinIP); + + // If there is finalization to do, emit it before the exit call + if (HasFinalize) { + assert(!FinalizationStack.empty() && + "Unexpected finalization stack state!"); + + FinalizationInfo Fi = FinalizationStack.pop_back_val(); + assert(Fi.DK == OMPD && "Unexpected Directive for Finalization call!"); + + Fi.FiniCB(FinIP); + + BasicBlock *FiniBB = FinIP.getBlock(); + Instruction *FiniBBTI = FiniBB->getTerminator(); + + // set Builder IP for call creation + Builder.SetInsertPoint(FiniBBTI); + } + + // place the Exitcall as last instruction before Finalization block terminator + ExitCall->removeFromParent(); + Builder.Insert(ExitCall); + + return IRBuilder<>::InsertPoint(ExitCall->getParent(), + ExitCall->getIterator()); +} + +std::string OpenMPIRBuilder::getNameWithSeparators(ArrayRef Parts, + StringRef FirstSeparator, + StringRef Separator) { + SmallString<128> Buffer; + llvm::raw_svector_ostream OS(Buffer); + StringRef Sep = FirstSeparator; + for (StringRef Part : Parts) { + OS << Sep << Part; + Sep = Separator; + } + return OS.str().str(); +} + +Constant *OpenMPIRBuilder::getOrCreateOMPInternalVariable( + llvm::Type *Ty, const llvm::Twine &Name, unsigned AddressSpace) { + // TODO: Replace the twine arg with stringref to get rid of the conversion + // logic. However This is taken from current implementation in clang as is. + // Since this method is used in many places exclusively for OMP internal use + // we will keep it as is for temporarily until we move all users to the + // builder and then, if possible, fix it everywhere in one go. + SmallString<256> Buffer; + llvm::raw_svector_ostream Out(Buffer); + Out << Name; + StringRef RuntimeName = Out.str(); + auto &Elem = *InternalVars.try_emplace(RuntimeName, nullptr).first; + if (Elem.second) { + assert(Elem.second->getType()->getPointerElementType() == Ty && + "OMP internal variable has different type than requested"); + } else { + // TODO: investigate the appropriate linkage type used for the global + // variable for possibly changing that to internal or private, or maybe + // create different versions of the function for different OMP internal + // variables. + Elem.second = new llvm::GlobalVariable( + M, Ty, /*IsConstant*/ false, llvm::GlobalValue::CommonLinkage, + llvm::Constant::getNullValue(Ty), Elem.first(), + /*InsertBefore=*/nullptr, llvm::GlobalValue::NotThreadLocal, + AddressSpace); + } + + return Elem.second; +} + +Value *OpenMPIRBuilder::getOMPCriticalRegionLock(StringRef CriticalName) { + std::string Prefix = Twine("gomp_critical_user_", CriticalName).str(); + std::string Name = getNameWithSeparators({Prefix, "var"}, ".", "."); + return getOrCreateOMPInternalVariable(KmpCriticalNameTy, Name); +} diff --git a/unittests/Frontend/OpenMPIRBuilderTest.cpp b/unittests/Frontend/OpenMPIRBuilderTest.cpp index 339e1f97759..b356f6d66e1 100644 --- a/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -622,4 +622,161 @@ TEST_F(OpenMPIRBuilderTest, ParallelCancelBarrier) { } } +TEST_F(OpenMPIRBuilderTest, MasterDirective) { + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.initialize(); + F->setName("func"); + IRBuilder<> Builder(BB); + + OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); + + AllocaInst *PrivAI = nullptr; + + BasicBlock *EntryBB = nullptr; + BasicBlock *ExitBB = nullptr; + BasicBlock *ThenBB = nullptr; + + auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, + BasicBlock &FiniBB) { + if (AllocaIP.isSet()) + Builder.restoreIP(AllocaIP); + else + Builder.SetInsertPoint(&*(F->getEntryBlock().getFirstInsertionPt())); + PrivAI = Builder.CreateAlloca(F->arg_begin()->getType()); + Builder.CreateStore(F->arg_begin(), PrivAI); + + llvm::BasicBlock *CodeGenIPBB = CodeGenIP.getBlock(); + llvm::Instruction *CodeGenIPInst = &*CodeGenIP.getPoint(); + EXPECT_EQ(CodeGenIPBB->getTerminator(), CodeGenIPInst); + + Builder.restoreIP(CodeGenIP); + + // collect some info for checks later + ExitBB = FiniBB.getUniqueSuccessor(); + ThenBB = Builder.GetInsertBlock(); + EntryBB = ThenBB->getUniquePredecessor(); + + // simple instructions for body + Value *PrivLoad = Builder.CreateLoad(PrivAI, "local.use"); + Builder.CreateICmpNE(F->arg_begin(), PrivLoad); + }; + + auto FiniCB = [&](InsertPointTy IP) { + BasicBlock *IPBB = IP.getBlock(); + EXPECT_NE(IPBB->end(), IP.getPoint()); + }; + + Builder.restoreIP(OMPBuilder.CreateMaster(Builder, BodyGenCB, FiniCB)); + Value *EntryBBTI = EntryBB->getTerminator(); + EXPECT_NE(EntryBBTI, nullptr); + EXPECT_TRUE(isa(EntryBBTI)); + BranchInst *EntryBr = cast(EntryBB->getTerminator()); + EXPECT_TRUE(EntryBr->isConditional()); + EXPECT_EQ(EntryBr->getSuccessor(0), ThenBB); + EXPECT_EQ(ThenBB->getUniqueSuccessor(), ExitBB); + EXPECT_EQ(EntryBr->getSuccessor(1), ExitBB); + + CmpInst *CondInst = cast(EntryBr->getCondition()); + EXPECT_TRUE(isa(CondInst->getOperand(0))); + + CallInst *MasterEntryCI = cast(CondInst->getOperand(0)); + EXPECT_EQ(MasterEntryCI->getNumArgOperands(), 2U); + EXPECT_EQ(MasterEntryCI->getCalledFunction()->getName(), "__kmpc_master"); + EXPECT_TRUE(isa(MasterEntryCI->getArgOperand(0))); + + CallInst *MasterEndCI = nullptr; + for (auto &FI : *ThenBB) { + Instruction *cur = &FI; + if (isa(cur)) { + MasterEndCI = cast(cur); + if (MasterEndCI->getCalledFunction()->getName() == "__kmpc_end_master") + break; + MasterEndCI = nullptr; + } + } + EXPECT_NE(MasterEndCI, nullptr); + EXPECT_EQ(MasterEndCI->getNumArgOperands(), 2U); + EXPECT_TRUE(isa(MasterEndCI->getArgOperand(0))); + EXPECT_EQ(MasterEndCI->getArgOperand(1), MasterEntryCI->getArgOperand(1)); +} + +TEST_F(OpenMPIRBuilderTest, CriticalDirective) { + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.initialize(); + F->setName("func"); + IRBuilder<> Builder(BB); + + OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); + + AllocaInst *PrivAI = Builder.CreateAlloca(F->arg_begin()->getType()); + + BasicBlock *EntryBB = nullptr; + + auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, + BasicBlock &FiniBB) { + // collect some info for checks later + EntryBB = FiniBB.getUniquePredecessor(); + + // actual start for bodyCB + llvm::BasicBlock *CodeGenIPBB = CodeGenIP.getBlock(); + llvm::Instruction *CodeGenIPInst = &*CodeGenIP.getPoint(); + EXPECT_EQ(CodeGenIPBB->getTerminator(), CodeGenIPInst); + EXPECT_EQ(EntryBB, CodeGenIPBB); + + // body begin + Builder.restoreIP(CodeGenIP); + Builder.CreateStore(F->arg_begin(), PrivAI); + Value *PrivLoad = Builder.CreateLoad(PrivAI, "local.use"); + Builder.CreateICmpNE(F->arg_begin(), PrivLoad); + }; + + auto FiniCB = [&](InsertPointTy IP) { + BasicBlock *IPBB = IP.getBlock(); + EXPECT_NE(IPBB->end(), IP.getPoint()); + }; + + Builder.restoreIP(OMPBuilder.CreateCritical(Builder, BodyGenCB, FiniCB, + "testCRT", nullptr)); + + Value *EntryBBTI = EntryBB->getTerminator(); + EXPECT_EQ(EntryBBTI, nullptr); + + CallInst *CriticalEntryCI = nullptr; + for (auto &EI : *EntryBB) { + Instruction *cur = &EI; + if (isa(cur)) { + CriticalEntryCI = cast(cur); + if (CriticalEntryCI->getCalledFunction()->getName() == "__kmpc_critical") + break; + CriticalEntryCI = nullptr; + } + } + EXPECT_NE(CriticalEntryCI, nullptr); + EXPECT_EQ(CriticalEntryCI->getNumArgOperands(), 3U); + EXPECT_EQ(CriticalEntryCI->getCalledFunction()->getName(), "__kmpc_critical"); + EXPECT_TRUE(isa(CriticalEntryCI->getArgOperand(0))); + + CallInst *CriticalEndCI = nullptr; + for (auto &FI : *EntryBB) { + Instruction *cur = &FI; + if (isa(cur)) { + CriticalEndCI = cast(cur); + if (CriticalEndCI->getCalledFunction()->getName() == + "__kmpc_end_critical") + break; + CriticalEndCI = nullptr; + } + } + EXPECT_NE(CriticalEndCI, nullptr); + EXPECT_EQ(CriticalEndCI->getNumArgOperands(), 3U); + EXPECT_TRUE(isa(CriticalEndCI->getArgOperand(0))); + EXPECT_EQ(CriticalEndCI->getArgOperand(1), CriticalEntryCI->getArgOperand(1)); + PointerType *CriticalNamePtrTy = + PointerType::getUnqual(ArrayType::get(Type::getInt32Ty(Ctx), 8)); + EXPECT_EQ(CriticalEndCI->getArgOperand(2), CriticalEntryCI->getArgOperand(2)); + EXPECT_EQ(CriticalEndCI->getArgOperand(2)->getType(), CriticalNamePtrTy); +} + } // namespace