diff --git a/include/llvm/ExecutionEngine/JITSymbol.h b/include/llvm/ExecutionEngine/JITSymbol.h index 0e33f014c7f..b26ab4698b8 100644 --- a/include/llvm/ExecutionEngine/JITSymbol.h +++ b/include/llvm/ExecutionEngine/JITSymbol.h @@ -298,7 +298,6 @@ class JITSymbolResolver { public: using LookupSet = std::set; using LookupResult = std::map; - using LookupFlagsResult = std::map; virtual ~JITSymbolResolver() = default; @@ -309,11 +308,11 @@ public: /// resolved, or if the resolution process itself triggers an error. virtual Expected lookup(const LookupSet &Symbols) = 0; - /// Returns the symbol flags for each of the given symbols. - /// - /// This method does NOT return an error if any of the given symbols is - /// missing. Instead, that symbol will be left out of the result map. - virtual Expected lookupFlags(const LookupSet &Symbols) = 0; + /// Returns the subset of the given symbols that should be materialized by + /// the caller. Only weak/common symbols should be looked up, as strong + /// definitions are implicitly always part of the caller's responsibility. + virtual Expected + getResponsibilitySet(const LookupSet &Symbols) = 0; private: virtual void anchor(); @@ -329,7 +328,7 @@ public: /// Performs flags lookup by calling findSymbolInLogicalDylib and /// returning the flags value for that symbol. - Expected lookupFlags(const LookupSet &Symbols) final; + Expected getResponsibilitySet(const LookupSet &Symbols) final; /// This method returns the address of the specified symbol if it exists /// within the logical dynamic library represented by this JITSymbolResolver. diff --git a/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h b/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h index cd5f1640ac2..20382f5086d 100644 --- a/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h +++ b/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h @@ -500,28 +500,29 @@ private: auto GVsResolver = createSymbolResolver( [&LD, LegacyLookup](const SymbolNameSet &Symbols) { - auto SymbolFlags = lookupFlagsWithLegacyFn(Symbols, LegacyLookup); + auto RS = getResponsibilitySetWithLegacyFn(Symbols, LegacyLookup); - if (!SymbolFlags) { - logAllUnhandledErrors(SymbolFlags.takeError(), errs(), - "CODLayer/GVsResolver flags lookup failed: "); - return SymbolFlagsMap(); + if (!RS) { + logAllUnhandledErrors( + RS.takeError(), errs(), + "CODLayer/GVsResolver responsibility set lookup failed: "); + return SymbolNameSet(); } - if (SymbolFlags->size() == Symbols.size()) - return *SymbolFlags; + if (RS->size() == Symbols.size()) + return *RS; SymbolNameSet NotFoundViaLegacyLookup; for (auto &S : Symbols) - if (!SymbolFlags->count(S)) + if (!RS->count(S)) NotFoundViaLegacyLookup.insert(S); - auto SymbolFlags2 = - LD.BackingResolver->lookupFlags(NotFoundViaLegacyLookup); + auto RS2 = + LD.BackingResolver->getResponsibilitySet(NotFoundViaLegacyLookup); - for (auto &KV : SymbolFlags2) - (*SymbolFlags)[KV.first] = std::move(KV.second); + for (auto &S : RS2) + (*RS).insert(S); - return *SymbolFlags; + return *RS; }, [this, &LD, LegacyLookup](std::shared_ptr Query, @@ -669,28 +670,29 @@ private: // Create memory manager and symbol resolver. auto Resolver = createSymbolResolver( [&LD, LegacyLookup](const SymbolNameSet &Symbols) { - auto SymbolFlags = lookupFlagsWithLegacyFn(Symbols, LegacyLookup); - if (!SymbolFlags) { - logAllUnhandledErrors(SymbolFlags.takeError(), errs(), - "CODLayer/SubResolver flags lookup failed: "); - return SymbolFlagsMap(); + auto RS = getResponsibilitySetWithLegacyFn(Symbols, LegacyLookup); + if (!RS) { + logAllUnhandledErrors( + RS.takeError(), errs(), + "CODLayer/SubResolver responsibility set lookup failed: "); + return SymbolNameSet(); } - if (SymbolFlags->size() == Symbols.size()) - return *SymbolFlags; + if (RS->size() == Symbols.size()) + return *RS; SymbolNameSet NotFoundViaLegacyLookup; for (auto &S : Symbols) - if (!SymbolFlags->count(S)) + if (!RS->count(S)) NotFoundViaLegacyLookup.insert(S); - auto SymbolFlags2 = - LD.BackingResolver->lookupFlags(NotFoundViaLegacyLookup); + auto RS2 = + LD.BackingResolver->getResponsibilitySet(NotFoundViaLegacyLookup); - for (auto &KV : SymbolFlags2) - (*SymbolFlags)[KV.first] = std::move(KV.second); + for (auto &S : RS2) + (*RS).insert(S); - return *SymbolFlags; + return *RS; }, [this, &LD, LegacyLookup](std::shared_ptr Q, SymbolNameSet Symbols) { diff --git a/include/llvm/ExecutionEngine/Orc/Legacy.h b/include/llvm/ExecutionEngine/Orc/Legacy.h index 52c8c162ff0..b8730496229 100644 --- a/include/llvm/ExecutionEngine/Orc/Legacy.h +++ b/include/llvm/ExecutionEngine/Orc/Legacy.h @@ -31,12 +31,12 @@ class SymbolResolver { public: virtual ~SymbolResolver() = default; - /// Returns the flags for each symbol in Symbols that can be found, - /// along with the set of symbol that could not be found. - virtual SymbolFlagsMap lookupFlags(const SymbolNameSet &Symbols) = 0; + /// Returns the subset of the given symbols that the caller is responsible for + /// materializing. + virtual SymbolNameSet getResponsibilitySet(const SymbolNameSet &Symbols) = 0; /// For each symbol in Symbols that can be found, assigns that symbols - /// value in Query. Returns the set of symbols that could not be found. + /// value in Query. Returns the set of symbols that could not be found. virtual SymbolNameSet lookup(std::shared_ptr Query, SymbolNameSet Symbols) = 0; @@ -46,16 +46,18 @@ private: /// Implements SymbolResolver with a pair of supplied function objects /// for convenience. See createSymbolResolver. -template +template class LambdaSymbolResolver final : public SymbolResolver { public: - template - LambdaSymbolResolver(LookupFlagsFnRef &&LookupFlags, LookupFnRef &&Lookup) - : LookupFlags(std::forward(LookupFlags)), + template + LambdaSymbolResolver(GetResponsibilitySetFnRef &&GetResponsibilitySet, + LookupFnRef &&Lookup) + : GetResponsibilitySet( + std::forward(GetResponsibilitySet)), Lookup(std::forward(Lookup)) {} - SymbolFlagsMap lookupFlags(const SymbolNameSet &Symbols) final { - return LookupFlags(Symbols); + SymbolNameSet getResponsibilitySet(const SymbolNameSet &Symbols) final { + return GetResponsibilitySet(Symbols); } SymbolNameSet lookup(std::shared_ptr Query, @@ -64,33 +66,35 @@ public: } private: - LookupFlagsFn LookupFlags; + GetResponsibilitySetFn GetResponsibilitySet; LookupFn Lookup; }; /// Creates a SymbolResolver implementation from the pair of supplied /// function objects. -template +template std::unique_ptr::type>::type, + typename std::remove_reference::type>::type, typename std::remove_cv< typename std::remove_reference::type>::type>> -createSymbolResolver(LookupFlagsFn &&LookupFlags, LookupFn &&Lookup) { +createSymbolResolver(GetResponsibilitySetFn &&GetResponsibilitySet, + LookupFn &&Lookup) { using LambdaSymbolResolverImpl = LambdaSymbolResolver< typename std::remove_cv< - typename std::remove_reference::type>::type, + typename std::remove_reference::type>::type, typename std::remove_cv< typename std::remove_reference::type>::type>; return llvm::make_unique( - std::forward(LookupFlags), std::forward(Lookup)); + std::forward(GetResponsibilitySet), + std::forward(Lookup)); } class JITSymbolResolverAdapter : public JITSymbolResolver { public: JITSymbolResolverAdapter(ExecutionSession &ES, SymbolResolver &R, MaterializationResponsibility *MR); - Expected lookupFlags(const LookupSet &Symbols) override; + Expected getResponsibilitySet(const LookupSet &Symbols) override; Expected lookup(const LookupSet &Symbols) override; private: @@ -100,27 +104,29 @@ private: MaterializationResponsibility *MR; }; -/// Use the given legacy-style FindSymbol function (i.e. a function that -/// takes a const std::string& or StringRef and returns a JITSymbol) to -/// find the flags for each symbol in Symbols and store their flags in -/// SymbolFlags. If any JITSymbol returned by FindSymbol is in an error -/// state the function returns immediately with that error, otherwise it -/// returns the set of symbols not found. +/// Use the given legacy-style FindSymbol function (i.e. a function that takes +/// a const std::string& or StringRef and returns a JITSymbol) to get the +/// subset of symbols that the caller is responsible for materializing. If any +/// JITSymbol returned by FindSymbol is in an error state the function returns +/// immediately with that error. /// -/// Useful for implementing lookupFlags bodies that query legacy resolvers. +/// Useful for implementing getResponsibilitySet bodies that query legacy +/// resolvers. template -Expected lookupFlagsWithLegacyFn(const SymbolNameSet &Symbols, - FindSymbolFn FindSymbol) { - SymbolFlagsMap SymbolFlags; +Expected +getResponsibilitySetWithLegacyFn(const SymbolNameSet &Symbols, + FindSymbolFn FindSymbol) { + SymbolNameSet Result; for (auto &S : Symbols) { - if (JITSymbol Sym = FindSymbol(*S)) - SymbolFlags[S] = Sym.getFlags(); - else if (auto Err = Sym.takeError()) + if (JITSymbol Sym = FindSymbol(*S)) { + if (!Sym.getFlags().isStrong()) + Result.insert(S); + } else if (auto Err = Sym.takeError()) return std::move(Err); } - return SymbolFlags; + return Result; } /// Use the given legacy-style FindSymbol function (i.e. a function that @@ -177,12 +183,13 @@ public: : ES(ES), LegacyLookup(std::move(LegacyLookup)), ReportError(std::move(ReportError)) {} - SymbolFlagsMap lookupFlags(const SymbolNameSet &Symbols) final { - if (auto SymbolFlags = lookupFlagsWithLegacyFn(Symbols, LegacyLookup)) - return std::move(*SymbolFlags); + SymbolNameSet getResponsibilitySet(const SymbolNameSet &Symbols) final { + if (auto ResponsibilitySet = + getResponsibilitySetWithLegacyFn(Symbols, LegacyLookup)) + return std::move(*ResponsibilitySet); else { - ReportError(SymbolFlags.takeError()); - return SymbolFlagsMap(); + ReportError(ResponsibilitySet.takeError()); + return SymbolNameSet(); } } diff --git a/include/llvm/ExecutionEngine/Orc/NullResolver.h b/include/llvm/ExecutionEngine/Orc/NullResolver.h index 3dd3cfe05b8..03fefb69a92 100644 --- a/include/llvm/ExecutionEngine/Orc/NullResolver.h +++ b/include/llvm/ExecutionEngine/Orc/NullResolver.h @@ -23,10 +23,10 @@ namespace orc { class NullResolver : public SymbolResolver { public: - SymbolFlagsMap lookupFlags(const SymbolNameSet &Symbols) override; + SymbolNameSet getResponsibilitySet(const SymbolNameSet &Symbols) final; SymbolNameSet lookup(std::shared_ptr Query, - SymbolNameSet Symbols) override; + SymbolNameSet Symbols) final; }; /// SymbolResolver impliementation that rejects all resolution requests. diff --git a/lib/ExecutionEngine/Orc/Legacy.cpp b/lib/ExecutionEngine/Orc/Legacy.cpp index 18be9a042f7..517176e5f42 100644 --- a/lib/ExecutionEngine/Orc/Legacy.cpp +++ b/lib/ExecutionEngine/Orc/Legacy.cpp @@ -48,17 +48,17 @@ JITSymbolResolverAdapter::lookup(const LookupSet &Symbols) { return Result; } -Expected -JITSymbolResolverAdapter::lookupFlags(const LookupSet &Symbols) { +Expected +JITSymbolResolverAdapter::getResponsibilitySet(const LookupSet &Symbols) { SymbolNameSet InternedSymbols; for (auto &S : Symbols) InternedSymbols.insert(ES.getSymbolStringPool().intern(S)); - SymbolFlagsMap SymbolFlags = R.lookupFlags(InternedSymbols); - LookupFlagsResult Result; - for (auto &KV : SymbolFlags) { - ResolvedStrings.insert(KV.first); - Result[*KV.first] = KV.second; + auto InternedResult = R.getResponsibilitySet(InternedSymbols); + LookupSet Result; + for (auto &S : InternedResult) { + ResolvedStrings.insert(S); + Result.insert(*S); } return Result; diff --git a/lib/ExecutionEngine/Orc/NullResolver.cpp b/lib/ExecutionEngine/Orc/NullResolver.cpp index 3796e3d37bc..922fc6f021c 100644 --- a/lib/ExecutionEngine/Orc/NullResolver.cpp +++ b/lib/ExecutionEngine/Orc/NullResolver.cpp @@ -14,8 +14,8 @@ namespace llvm { namespace orc { -SymbolFlagsMap NullResolver::lookupFlags(const SymbolNameSet &Symbols) { - return SymbolFlagsMap(); +SymbolNameSet NullResolver::getResponsibilitySet(const SymbolNameSet &Symbols) { + return Symbols; } SymbolNameSet diff --git a/lib/ExecutionEngine/Orc/OrcCBindingsStack.h b/lib/ExecutionEngine/Orc/OrcCBindingsStack.h index b9f8a370d2f..a67215f659a 100644 --- a/lib/ExecutionEngine/Orc/OrcCBindingsStack.h +++ b/lib/ExecutionEngine/Orc/OrcCBindingsStack.h @@ -129,20 +129,21 @@ private: : Stack(Stack), ExternalResolver(std::move(ExternalResolver)), ExternalResolverCtx(std::move(ExternalResolverCtx)) {} - orc::SymbolFlagsMap - lookupFlags(const orc::SymbolNameSet &Symbols) override { - orc::SymbolFlagsMap SymbolFlags; + orc::SymbolNameSet + getResponsibilitySet(const orc::SymbolNameSet &Symbols) override { + orc::SymbolNameSet Result; for (auto &S : Symbols) { - if (auto Sym = findSymbol(*S)) - SymbolFlags[S] = Sym.getFlags(); - else if (auto Err = Sym.takeError()) { + if (auto Sym = findSymbol(*S)) { + if (!Sym.getFlags().isStrong()) + Result.insert(S); + } else if (auto Err = Sym.takeError()) { Stack.reportError(std::move(Err)); - return orc::SymbolFlagsMap(); + return orc::SymbolNameSet(); } } - return SymbolFlags; + return Result; } orc::SymbolNameSet diff --git a/lib/ExecutionEngine/Orc/OrcMCJITReplacement.h b/lib/ExecutionEngine/Orc/OrcMCJITReplacement.h index abe89ce70af..6515c0da360 100644 --- a/lib/ExecutionEngine/Orc/OrcMCJITReplacement.h +++ b/lib/ExecutionEngine/Orc/OrcMCJITReplacement.h @@ -144,26 +144,29 @@ class OrcMCJITReplacement : public ExecutionEngine { public: LinkingORCResolver(OrcMCJITReplacement &M) : M(M) {} - SymbolFlagsMap lookupFlags(const SymbolNameSet &Symbols) override { - SymbolFlagsMap SymbolFlags; + SymbolNameSet getResponsibilitySet(const SymbolNameSet &Symbols) override { + SymbolNameSet Result; for (auto &S : Symbols) { if (auto Sym = M.findMangledSymbol(*S)) { - SymbolFlags[S] = Sym.getFlags(); + if (!Sym.getFlags().isStrong()) + Result.insert(S); } else if (auto Err = Sym.takeError()) { M.reportError(std::move(Err)); - return SymbolFlagsMap(); + return SymbolNameSet(); } else { if (auto Sym2 = M.ClientResolver->findSymbolInLogicalDylib(*S)) { - SymbolFlags[S] = Sym2.getFlags(); + if (!Sym2.getFlags().isStrong()) + Result.insert(S); } else if (auto Err = Sym2.takeError()) { M.reportError(std::move(Err)); - return SymbolFlagsMap(); - } + return SymbolNameSet(); + } else + Result.insert(S); } } - return SymbolFlags; + return Result; } SymbolNameSet lookup(std::shared_ptr Query, diff --git a/lib/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.cpp b/lib/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.cpp index 1dfa90acb76..f82f5ecfed5 100644 --- a/lib/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.cpp +++ b/lib/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.cpp @@ -44,27 +44,13 @@ public: return Result; } - Expected lookupFlags(const LookupSet &Symbols) { - auto &ES = MR.getTargetJITDylib().getExecutionSession(); + Expected getResponsibilitySet(const LookupSet &Symbols) { + LookupSet Result; - SymbolNameSet InternedSymbols; - - for (auto &S : Symbols) - InternedSymbols.insert(ES.getSymbolStringPool().intern(S)); - - SymbolFlagsMap InternedResult; - MR.getTargetJITDylib().withSearchOrderDo([&](const JITDylibList &JDs) { - // An empty search order is pathalogical, but allowed. - if (JDs.empty()) - return; - - assert(JDs.front() && "VSOList entry can not be null"); - InternedResult = JDs.front()->lookupFlags(InternedSymbols); - }); - - LookupFlagsResult Result; - for (auto &KV : InternedResult) - Result[*KV.first] = std::move(KV.second); + for (auto &KV : MR.getSymbols()) { + if (Symbols.count(*KV.first)) + Result.insert(*KV.first); + } return Result; } diff --git a/lib/ExecutionEngine/RuntimeDyld/JITSymbol.cpp b/lib/ExecutionEngine/RuntimeDyld/JITSymbol.cpp index 88b0cd0e336..d865216cf31 100644 --- a/lib/ExecutionEngine/RuntimeDyld/JITSymbol.cpp +++ b/lib/ExecutionEngine/RuntimeDyld/JITSymbol.cpp @@ -94,16 +94,24 @@ LegacyJITSymbolResolver::lookup(const LookupSet &Symbols) { /// Performs flags lookup by calling findSymbolInLogicalDylib and /// returning the flags value for that symbol. -Expected -LegacyJITSymbolResolver::lookupFlags(const LookupSet &Symbols) { - JITSymbolResolver::LookupFlagsResult Result; +Expected +LegacyJITSymbolResolver::getResponsibilitySet(const LookupSet &Symbols) { + JITSymbolResolver::LookupSet Result; for (auto &Symbol : Symbols) { std::string SymName = Symbol.str(); - if (auto Sym = findSymbolInLogicalDylib(SymName)) - Result[Symbol] = Sym.getFlags(); - else if (auto Err = Sym.takeError()) + if (auto Sym = findSymbolInLogicalDylib(SymName)) { + // If there's an existing def but it is not strong, then the caller is + // responsible for it. + if (!Sym.getFlags().isStrong()) + Result.insert(Symbol); + } else if (auto Err = Sym.takeError()) return std::move(Err); + else { + // If there is no existing definition then the caller is responsible for + // it. + Result.insert(Symbol); + } } return std::move(Result); diff --git a/lib/ExecutionEngine/RuntimeDyld/RuntimeDyld.cpp b/lib/ExecutionEngine/RuntimeDyld/RuntimeDyld.cpp index df4bd5728a6..43c99b52c34 100644 --- a/lib/ExecutionEngine/RuntimeDyld/RuntimeDyld.cpp +++ b/lib/ExecutionEngine/RuntimeDyld/RuntimeDyld.cpp @@ -204,7 +204,7 @@ RuntimeDyldImpl::loadObjectImpl(const object::ObjectFile &Obj) { // First, collect all weak and common symbols. We need to know if stronger // definitions occur elsewhere. - JITSymbolResolver::LookupFlagsResult SymbolFlags; + JITSymbolResolver::LookupSet ResponsibilitySet; { JITSymbolResolver::LookupSet Symbols; for (auto &Sym : Obj.symbols()) { @@ -218,10 +218,10 @@ RuntimeDyldImpl::loadObjectImpl(const object::ObjectFile &Obj) { } } - if (auto FlagsResultOrErr = Resolver.lookupFlags(Symbols)) - SymbolFlags = std::move(*FlagsResultOrErr); + if (auto ResultOrErr = Resolver.getResponsibilitySet(Symbols)) + ResponsibilitySet = std::move(*ResultOrErr); else - return FlagsResultOrErr.takeError(); + return ResultOrErr.takeError(); } // Parse symbols @@ -259,29 +259,26 @@ RuntimeDyldImpl::loadObjectImpl(const object::ObjectFile &Obj) { // strong. if (JITSymFlags->isWeak() || JITSymFlags->isCommon()) { // First check whether there's already a definition in this instance. - // FIXME: Override existing weak definitions with strong ones. if (GlobalSymbolTable.count(Name)) continue; - // Then check whether we found flags for an existing symbol during the - // flags lookup earlier. - auto FlagsI = SymbolFlags.find(Name); - if (FlagsI == SymbolFlags.end() || - (JITSymFlags->isWeak() && !FlagsI->second.isStrong()) || - (JITSymFlags->isCommon() && FlagsI->second.isCommon())) { - if (JITSymFlags->isWeak()) - *JITSymFlags &= ~JITSymbolFlags::Weak; - if (JITSymFlags->isCommon()) { - *JITSymFlags &= ~JITSymbolFlags::Common; - uint32_t Align = I->getAlignment(); - uint64_t Size = I->getCommonSize(); - if (!CommonAlign) - CommonAlign = Align; - CommonSize = alignTo(CommonSize, Align) + Size; - CommonSymbolsToAllocate.push_back(*I); - } - } else + // If we're not responsible for this symbol, skip it. + if (!ResponsibilitySet.count(Name)) continue; + + // Otherwise update the flags on the symbol to make this definition + // strong. + if (JITSymFlags->isWeak()) + *JITSymFlags &= ~JITSymbolFlags::Weak; + if (JITSymFlags->isCommon()) { + *JITSymFlags &= ~JITSymbolFlags::Common; + uint32_t Align = I->getAlignment(); + uint64_t Size = I->getCommonSize(); + if (!CommonAlign) + CommonAlign = Align; + CommonSize = alignTo(CommonSize, Align) + Size; + CommonSymbolsToAllocate.push_back(*I); + } } if (Flags & SymbolRef::SF_Absolute && diff --git a/test/ExecutionEngine/OrcLazy/Inputs/obj-weak-non-materialization-1.ll b/test/ExecutionEngine/OrcLazy/Inputs/obj-weak-non-materialization-1.ll new file mode 100644 index 00000000000..24594390f45 --- /dev/null +++ b/test/ExecutionEngine/OrcLazy/Inputs/obj-weak-non-materialization-1.ll @@ -0,0 +1 @@ +@X = weak global i32 0, align 4 diff --git a/test/ExecutionEngine/OrcLazy/Inputs/obj-weak-non-materialization-2.ll b/test/ExecutionEngine/OrcLazy/Inputs/obj-weak-non-materialization-2.ll new file mode 100644 index 00000000000..3262017012c --- /dev/null +++ b/test/ExecutionEngine/OrcLazy/Inputs/obj-weak-non-materialization-2.ll @@ -0,0 +1,7 @@ +@X = weak global i32 1, align 4 + +define void @foo() { +entry: + ret void +} + diff --git a/test/ExecutionEngine/OrcLazy/weak-non-materialization.ll b/test/ExecutionEngine/OrcLazy/weak-non-materialization.ll new file mode 100644 index 00000000000..c20de813123 --- /dev/null +++ b/test/ExecutionEngine/OrcLazy/weak-non-materialization.ll @@ -0,0 +1,17 @@ +; RUN: llc -filetype=obj -o %t1.o %p/Inputs/obj-weak-non-materialization-1.ll +; RUN: llc -filetype=obj -o %t2.o %p/Inputs/obj-weak-non-materialization-2.ll +; RUN: lli -jit-kind=orc-lazy -extra-object %t1.o -extra-object %t2.o %s +; +; Check that %t1.o's version of the weak symbol X is used, even though %t2.o is +; materialized first. + +@X = external global i32 + +declare void @foo() + +define i32 @main(i32 %argc, i8** %argv) { +entry: + call void @foo() + %0 = load i32, i32* @X + ret i32 %0 +} diff --git a/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp b/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp index 97b1ab5aee8..24c5378e2d1 100644 --- a/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp +++ b/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp @@ -531,6 +531,57 @@ TEST_F(CoreAPIsStandardTest, AddAndMaterializeLazySymbol) { EXPECT_TRUE(OnReadyRun) << "OnReady was not run"; } +TEST_F(CoreAPIsStandardTest, TestBasicWeakSymbolMaterialization) { + // Test that weak symbols are materialized correctly when we look them up. + BarSym.setFlags(static_cast(BarSym.getFlags() | + JITSymbolFlags::Weak)); + + bool BarMaterialized = false; + auto MU1 = llvm::make_unique( + SymbolFlagsMap({{Foo, FooSym.getFlags()}, {Bar, BarSym.getFlags()}}), + [&](MaterializationResponsibility R) { + R.resolve(SymbolMap({{Foo, FooSym}, {Bar, BarSym}})), R.emit(); + BarMaterialized = true; + }); + + bool DuplicateBarDiscarded = false; + auto MU2 = llvm::make_unique( + SymbolFlagsMap({{Bar, BarSym.getFlags()}}), + [&](MaterializationResponsibility R) { + ADD_FAILURE() << "Attempt to materialize Bar from the wrong unit"; + R.failMaterialization(); + }, + [&](const JITDylib &JD, SymbolStringPtr Name) { + EXPECT_EQ(Name, Bar) << "Expected \"Bar\" to be discarded"; + DuplicateBarDiscarded = true; + }); + + cantFail(JD.define(MU1)); + cantFail(JD.define(MU2)); + + bool OnResolvedRun = false; + bool OnReadyRun = false; + + auto OnResolution = [&](Expected Result) { + cantFail(std::move(Result)); + OnResolvedRun = true; + }; + + auto OnReady = [&](Error Err) { + cantFail(std::move(Err)); + OnReadyRun = true; + }; + + ES.lookup({&JD}, {Bar}, std::move(OnResolution), std::move(OnReady), + NoDependenciesToRegister); + + EXPECT_TRUE(OnResolvedRun) << "OnResolved not run"; + EXPECT_TRUE(OnReadyRun) << "OnReady not run"; + EXPECT_TRUE(BarMaterialized) << "Bar was not materialized at all"; + EXPECT_TRUE(DuplicateBarDiscarded) + << "Duplicate bar definition not discarded"; +} + TEST_F(CoreAPIsStandardTest, DefineMaterializingSymbol) { bool ExpectNoMoreMaterialization = false; ES.setDispatchMaterialization( diff --git a/unittests/ExecutionEngine/Orc/LegacyAPIInteropTest.cpp b/unittests/ExecutionEngine/Orc/LegacyAPIInteropTest.cpp index 05fe921ee96..ec8ef5641f7 100644 --- a/unittests/ExecutionEngine/Orc/LegacyAPIInteropTest.cpp +++ b/unittests/ExecutionEngine/Orc/LegacyAPIInteropTest.cpp @@ -19,26 +19,31 @@ class LegacyAPIsStandardTest : public CoreAPIsBasedStandardTest {}; namespace { TEST_F(LegacyAPIsStandardTest, TestLambdaSymbolResolver) { + BarSym.setFlags(static_cast(BarSym.getFlags() | + JITSymbolFlags::Weak)); + cantFail(JD.define(absoluteSymbols({{Foo, FooSym}, {Bar, BarSym}}))); auto Resolver = createSymbolResolver( - [&](const SymbolNameSet &Symbols) { return JD.lookupFlags(Symbols); }, + [&](const SymbolNameSet &Symbols) { + auto FlagsMap = JD.lookupFlags(Symbols); + llvm::dbgs() << "FlagsMap is " << FlagsMap << "\n"; + SymbolNameSet Result; + for (auto &KV : FlagsMap) + if (!KV.second.isStrong()) + Result.insert(KV.first); + return Result; + }, [&](std::shared_ptr Q, SymbolNameSet Symbols) { return JD.legacyLookup(std::move(Q), Symbols); }); - SymbolNameSet Symbols({Foo, Bar, Baz}); + auto RS = Resolver->getResponsibilitySet(SymbolNameSet({Bar, Baz})); - SymbolFlagsMap SymbolFlags = Resolver->lookupFlags(Symbols); - - EXPECT_EQ(SymbolFlags.size(), 2U) - << "lookupFlags returned the wrong number of results"; - EXPECT_EQ(SymbolFlags.count(Foo), 1U) << "Missing lookupFlags result for foo"; - EXPECT_EQ(SymbolFlags.count(Bar), 1U) << "Missing lookupFlags result for bar"; - EXPECT_EQ(SymbolFlags[Foo], FooSym.getFlags()) - << "Incorrect lookupFlags result for Foo"; - EXPECT_EQ(SymbolFlags[Bar], BarSym.getFlags()) - << "Incorrect lookupFlags result for Bar"; + EXPECT_EQ(RS.size(), 1U) + << "getResponsibilitySet returned the wrong number of results"; + EXPECT_EQ(RS.count(Bar), 1U) + << "getResponsibilitySet result incorrect. Should be {'bar'}"; bool OnResolvedRun = false; @@ -59,68 +64,22 @@ TEST_F(LegacyAPIsStandardTest, TestLambdaSymbolResolver) { auto Q = std::make_shared(SymbolNameSet({Foo, Bar}), OnResolved, OnReady); - auto Unresolved = Resolver->lookup(std::move(Q), Symbols); + auto Unresolved = + Resolver->lookup(std::move(Q), SymbolNameSet({Foo, Bar, Baz})); EXPECT_EQ(Unresolved.size(), 1U) << "Expected one unresolved symbol"; EXPECT_EQ(Unresolved.count(Baz), 1U) << "Expected baz to not be resolved"; EXPECT_TRUE(OnResolvedRun) << "OnResolved was never run"; } -TEST(LegacyAPIInteropTest, QueryAgainstJITDylib) { - - ExecutionSession ES(std::make_shared()); - auto Foo = ES.getSymbolStringPool().intern("foo"); - - auto &JD = ES.createJITDylib("JD"); - JITEvaluatedSymbol FooSym(0xdeadbeef, JITSymbolFlags::Exported); - cantFail(JD.define(absoluteSymbols({{Foo, FooSym}}))); - - auto LookupFlags = [&](const SymbolNameSet &Names) { - return JD.lookupFlags(Names); - }; - - auto Lookup = [&](std::shared_ptr Query, - SymbolNameSet Symbols) { - return JD.legacyLookup(std::move(Query), Symbols); - }; - - auto UnderlyingResolver = - createSymbolResolver(std::move(LookupFlags), std::move(Lookup)); - JITSymbolResolverAdapter Resolver(ES, *UnderlyingResolver, nullptr); - - JITSymbolResolver::LookupSet Names{StringRef("foo")}; - - auto LFR = Resolver.lookupFlags(Names); - EXPECT_TRUE(!!LFR) << "lookupFlags failed"; - EXPECT_EQ(LFR->size(), 1U) - << "lookupFlags returned the wrong number of results"; - EXPECT_EQ(LFR->count(*Foo), 1U) - << "lookupFlags did not contain a result for 'foo'"; - EXPECT_EQ((*LFR)[*Foo], FooSym.getFlags()) - << "lookupFlags contained the wrong result for 'foo'"; - - auto LR = Resolver.lookup(Names); - EXPECT_TRUE(!!LR) << "lookup failed"; - EXPECT_EQ(LR->size(), 1U) << "lookup returned the wrong number of results"; - EXPECT_EQ(LR->count(*Foo), 1U) << "lookup did not contain a result for 'foo'"; - EXPECT_EQ((*LR)[*Foo].getFlags(), FooSym.getFlags()) - << "lookup returned the wrong result for flags of 'foo'"; - EXPECT_EQ((*LR)[*Foo].getAddress(), FooSym.getAddress()) - << "lookup returned the wrong result for address of 'foo'"; -} - -TEST(LegacyAPIInteropTset, LegacyLookupHelpersFn) { - constexpr JITTargetAddress FooAddr = 0xdeadbeef; - JITSymbolFlags FooFlags = JITSymbolFlags::Exported; - +TEST_F(LegacyAPIsStandardTest, LegacyLookupHelpersFn) { bool BarMaterialized = false; - constexpr JITTargetAddress BarAddr = 0xcafef00d; - JITSymbolFlags BarFlags = static_cast( - JITSymbolFlags::Exported | JITSymbolFlags::Weak); + BarSym.setFlags(static_cast(BarSym.getFlags() | + JITSymbolFlags::Weak)); auto LegacyLookup = [&](const std::string &Name) -> JITSymbol { if (Name == "foo") - return {FooAddr, FooFlags}; + return FooSym; if (Name == "bar") { auto BarMaterializer = [&]() -> Expected { @@ -128,27 +87,18 @@ TEST(LegacyAPIInteropTset, LegacyLookupHelpersFn) { return BarAddr; }; - return {BarMaterializer, BarFlags}; + return {BarMaterializer, BarSym.getFlags()}; } return nullptr; }; - ExecutionSession ES; - auto Foo = ES.getSymbolStringPool().intern("foo"); - auto Bar = ES.getSymbolStringPool().intern("bar"); - auto Baz = ES.getSymbolStringPool().intern("baz"); + auto RS = + getResponsibilitySetWithLegacyFn(SymbolNameSet({Bar, Baz}), LegacyLookup); - SymbolNameSet Symbols({Foo, Bar, Baz}); - - auto SymbolFlags = lookupFlagsWithLegacyFn(Symbols, LegacyLookup); - - EXPECT_TRUE(!!SymbolFlags) << "Expected lookupFlagsWithLegacyFn to succeed"; - EXPECT_EQ(SymbolFlags->size(), 2U) << "Wrong number of flags returned"; - EXPECT_EQ(SymbolFlags->count(Foo), 1U) << "Flags for foo missing"; - EXPECT_EQ(SymbolFlags->count(Bar), 1U) << "Flags for foo missing"; - EXPECT_EQ((*SymbolFlags)[Foo], FooFlags) << "Wrong flags for foo"; - EXPECT_EQ((*SymbolFlags)[Bar], BarFlags) << "Wrong flags for foo"; + EXPECT_TRUE(!!RS) << "Expected getResponsibilitySetWithLegacyFn to succeed"; + EXPECT_EQ(RS->size(), 1U) << "Wrong number of symbols returned"; + EXPECT_EQ(RS->count(Bar), 1U) << "Incorrect responsibility set returned"; EXPECT_FALSE(BarMaterialized) << "lookupFlags should not have materialized bar"; @@ -162,9 +112,11 @@ TEST(LegacyAPIInteropTset, LegacyLookupHelpersFn) { EXPECT_EQ(Result->count(Foo), 1U) << "Result for foo missing"; EXPECT_EQ(Result->count(Bar), 1U) << "Result for bar missing"; EXPECT_EQ((*Result)[Foo].getAddress(), FooAddr) << "Wrong address for foo"; - EXPECT_EQ((*Result)[Foo].getFlags(), FooFlags) << "Wrong flags for foo"; + EXPECT_EQ((*Result)[Foo].getFlags(), FooSym.getFlags()) + << "Wrong flags for foo"; EXPECT_EQ((*Result)[Bar].getAddress(), BarAddr) << "Wrong address for bar"; - EXPECT_EQ((*Result)[Bar].getFlags(), BarFlags) << "Wrong flags for bar"; + EXPECT_EQ((*Result)[Bar].getFlags(), BarSym.getFlags()) + << "Wrong flags for bar"; }; auto OnReady = [&](Error Err) { EXPECT_FALSE(!!Err) << "Finalization unexpectedly failed"; @@ -172,7 +124,8 @@ TEST(LegacyAPIInteropTset, LegacyLookupHelpersFn) { }; AsynchronousSymbolQuery Q({Foo, Bar}, OnResolved, OnReady); - auto Unresolved = lookupWithLegacyFn(ES, Q, Symbols, LegacyLookup); + auto Unresolved = + lookupWithLegacyFn(ES, Q, SymbolNameSet({Foo, Bar, Baz}), LegacyLookup); EXPECT_TRUE(OnResolvedRun) << "OnResolved was not run"; EXPECT_TRUE(OnReadyRun) << "OnReady was not run"; diff --git a/unittests/ExecutionEngine/Orc/RTDyldObjectLinkingLayerTest.cpp b/unittests/ExecutionEngine/Orc/RTDyldObjectLinkingLayerTest.cpp index 420631c36ad..94b771f2640 100644 --- a/unittests/ExecutionEngine/Orc/RTDyldObjectLinkingLayerTest.cpp +++ b/unittests/ExecutionEngine/Orc/RTDyldObjectLinkingLayerTest.cpp @@ -185,7 +185,8 @@ TEST_F(RTDyldObjectLinkingLayerExecutionTest, NoDuplicateFinalization) { Resolvers[K2] = createSymbolResolver( [&](const SymbolNameSet &Symbols) { - return cantFail(lookupFlagsWithLegacyFn(Symbols, LegacyLookup)); + return cantFail( + getResponsibilitySetWithLegacyFn(Symbols, LegacyLookup)); }, [&](std::shared_ptr Query, const SymbolNameSet &Symbols) {