diff --git a/include/llvm/ExecutionEngine/Orc/Core.h b/include/llvm/ExecutionEngine/Orc/Core.h index ae826912d62..42bcffd36b2 100644 --- a/include/llvm/ExecutionEngine/Orc/Core.h +++ b/include/llvm/ExecutionEngine/Orc/Core.h @@ -216,6 +216,18 @@ public: add(Name, Flags); } + /// Construct a SymbolLookupSet from DenseMap keys. + template + static SymbolLookupSet + fromMapKeys(const DenseMap &M, + SymbolLookupFlags Flags = SymbolLookupFlags::RequiredSymbol) { + SymbolLookupSet Result; + Result.Symbols.reserve(M.size()); + for (const auto &KV : M) + Result.add(KV.first, Flags); + return Result; + } + /// Add an element to the set. The client is responsible for checking that /// duplicates are not added. SymbolLookupSet & diff --git a/include/llvm/ExecutionEngine/Orc/ExecutorProcessControl.h b/include/llvm/ExecutionEngine/Orc/ExecutorProcessControl.h index 7969a8398c9..566637e1044 100644 --- a/include/llvm/ExecutionEngine/Orc/ExecutorProcessControl.h +++ b/include/llvm/ExecutionEngine/Orc/ExecutorProcessControl.h @@ -24,6 +24,7 @@ #include "llvm/Support/MSVCErrorWorkarounds.h" #include +#include #include namespace llvm { @@ -32,6 +33,19 @@ namespace orc { /// ExecutorProcessControl supports interaction with a JIT target process. class ExecutorProcessControl { public: + /// Sender to return the result of a WrapperFunction executed in the JIT. + using SendResultFunction = + unique_function; + + /// An asynchronous wrapper-function. + using AsyncWrapperFunction = unique_function; + + /// A map associating tag names with asynchronous wrapper function + /// implementations in the JIT. + using WrapperFunctionAssociationMap = + DenseMap; + /// APIs for manipulating memory in the target process. class MemoryAccess { public: @@ -138,14 +152,91 @@ public: virtual Expected runAsMain(JITTargetAddress MainFnAddr, ArrayRef Args) = 0; - /// Run a wrapper function in the executor. + /// Run a wrapper function in the executor (async version). + /// + /// The wrapper function should be callable as: /// /// \code{.cpp} /// CWrapperFunctionResult fn(uint8_t *Data, uint64_t Size); /// \endcode{.cpp} /// - virtual Expected - runWrapper(JITTargetAddress WrapperFnAddr, ArrayRef ArgBuffer) = 0; + /// The given OnComplete function will be called to return the result. + virtual void runWrapperAsync(SendResultFunction OnComplete, + JITTargetAddress WrapperFnAddr, + ArrayRef ArgBuffer) = 0; + + /// Run a wrapper function in the executor. The wrapper function should be + /// callable as: + /// + /// \code{.cpp} + /// CWrapperFunctionResult fn(uint8_t *Data, uint64_t Size); + /// \endcode{.cpp} + shared::WrapperFunctionResult runWrapper(JITTargetAddress WrapperFnAddr, + ArrayRef ArgBuffer) { + std::promise RP; + auto RF = RP.get_future(); + runWrapperAsync( + [&](shared::WrapperFunctionResult R) { RP.set_value(std::move(R)); }, + WrapperFnAddr, ArgBuffer); + return RF.get(); + } + + /// Run a wrapper function using SPS to serialize the arguments and + /// deserialize the results. + template + void runSPSWrapperAsync(SendResultT &&SendResult, + JITTargetAddress WrapperFnAddr, + const ArgTs &...Args) { + shared::WrapperFunction::callAsync( + [this, WrapperFnAddr](SendResultFunction SendResult, + const char *ArgData, size_t ArgSize) { + runWrapperAsync(std::move(SendResult), WrapperFnAddr, + ArrayRef(ArgData, ArgSize)); + }, + std::move(SendResult), Args...); + } + + /// Run a wrapper function using SPS to serialize the arguments and + /// deserialize the results. + template + Error runSPSWrapper(JITTargetAddress WrapperFnAddr, RetT &RetVal, + const ArgTs &...Args) { + return shared::WrapperFunction::call( + [this, WrapperFnAddr](const char *ArgData, size_t ArgSize) { + return runWrapper(WrapperFnAddr, ArrayRef(ArgData, ArgSize)); + }, + RetVal, Args...); + } + + /// Wrap a handler that takes concrete argument types (and a sender for a + /// concrete return type) to produce an AsyncWrapperFunction. Uses SPS to + /// unpack the arguments and pack the result. + /// + /// This function is usually used when building association maps. + template + static AsyncWrapperFunction wrapAsyncWithSPS(HandlerT &&H) { + return [H = std::forward(H)](SendResultFunction SendResult, + const char *ArgData, + size_t ArgSize) mutable { + shared::WrapperFunction::handleAsync(ArgData, ArgSize, H, + std::move(SendResult)); + }; + } + + /// For each symbol name, associate the AsyncWrapperFunction implementation + /// value with the address of that symbol. + /// + /// Symbols will be looked up using LookupKind::Static, + /// JITDylibLookupFlags::MatchAllSymbols (hidden tags will be found), and + /// LookupFlags::WeaklyReferencedSymbol (missing tags will not cause an + /// error, the implementations will simply be dropped). + Error associateJITSideWrapperFunctions(JITDylib &JD, + WrapperFunctionAssociationMap WFs); + + /// Run a registered jit-side wrapper function. + void runJITSideWrapperFunction(SendResultFunction SendResult, + JITTargetAddress TagAddr, + ArrayRef ArgBuffer); /// Disconnect from the target process. /// @@ -161,6 +252,9 @@ protected: unsigned PageSize = 0; MemoryAccess *MemAccess = nullptr; jitlink::JITLinkMemoryManager *MemMgr = nullptr; + + std::mutex TagToFuncMapMutex; + DenseMap> TagToFunc; }; /// Call a wrapper function via ExecutorProcessControl::runWrapper. @@ -168,8 +262,8 @@ class EPCCaller { public: EPCCaller(ExecutorProcessControl &EPC, JITTargetAddress WrapperFnAddr) : EPC(EPC), WrapperFnAddr(WrapperFnAddr) {} - Expected operator()(const char *ArgData, - size_t ArgSize) const { + shared::WrapperFunctionResult operator()(const char *ArgData, + size_t ArgSize) const { return EPC.runWrapper(WrapperFnAddr, ArrayRef(ArgData, ArgSize)); } @@ -202,8 +296,9 @@ public: Expected runAsMain(JITTargetAddress MainFnAddr, ArrayRef Args) override; - Expected - runWrapper(JITTargetAddress WrapperFnAddr, ArrayRef ArgBuffer) override; + void runWrapperAsync(SendResultFunction OnComplete, + JITTargetAddress WrapperFnAddr, + ArrayRef ArgBuffer) override; Error disconnect() override; diff --git a/include/llvm/ExecutionEngine/Orc/OrcRPCExecutorProcessControl.h b/include/llvm/ExecutionEngine/Orc/OrcRPCExecutorProcessControl.h index 0b5ee262bb7..69e37f9af9e 100644 --- a/include/llvm/ExecutionEngine/Orc/OrcRPCExecutorProcessControl.h +++ b/include/llvm/ExecutionEngine/Orc/OrcRPCExecutorProcessControl.h @@ -354,9 +354,9 @@ public: return Result; } - Expected - runWrapper(JITTargetAddress WrapperFnAddr, - ArrayRef ArgBuffer) override { + void runWrapperAsync(SendResultFunction OnComplete, + JITTargetAddress WrapperFnAddr, + ArrayRef ArgBuffer) override { DEBUG_WITH_TYPE("orc", { dbgs() << "Running as wrapper function " << formatv("{0:x16}", WrapperFnAddr) << " with " @@ -366,7 +366,11 @@ public: WrapperFnAddr, ArrayRef(reinterpret_cast(ArgBuffer.data()), ArgBuffer.size())); - return Result; + + if (!Result) + OnComplete(shared::WrapperFunctionResult::createOutOfBandError( + toString(Result.takeError()))); + OnComplete(std::move(*Result)); } Error closeConnection(OnCloseConnectionFunction OnCloseConnection) { diff --git a/include/llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h b/include/llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h index 0fc8af77023..ceaea1d2b20 100644 --- a/include/llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h +++ b/include/llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h @@ -172,17 +172,16 @@ private: namespace detail { template -Expected +WrapperFunctionResult serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) { WrapperFunctionResult Result; char *DataPtr = WrapperFunctionResult::allocate(Result, SPSArgListT::size(Args...)); SPSOutputBuffer OB(DataPtr, Result.size()); if (!SPSArgListT::serialize(OB, Args...)) - return make_error( - "Error serializing arguments to blob in call", - inconvertibleErrorCode()); - return std::move(Result); + return WrapperFunctionResult::createOutOfBandError( + "Error serializing arguments to blob in call"); + return Result; } template class WrapperFunctionHandlerCaller { @@ -230,12 +229,8 @@ public: auto HandlerResult = WrapperFunctionHandlerCaller::call( std::forward(H), Args, ArgIndices{}); - if (auto Result = ResultSerializer::serialize( - std::move(HandlerResult))) - return std::move(*Result); - else - return WrapperFunctionResult::createOutOfBandError( - toString(Result.takeError())); + return ResultSerializer::serialize( + std::move(HandlerResult)); } private: @@ -247,10 +242,10 @@ private: } }; -// Map function references to function types. +// Map function pointers to function types. template class ResultSerializer, typename... SPSTagTs> -class WrapperFunctionHandlerHelper : public WrapperFunctionHandlerHelper {}; @@ -271,9 +266,87 @@ class WrapperFunctionHandlerHelper {}; +template class ResultSerializer, typename... SPSTagTs> +class WrapperFunctionAsyncHandlerHelper + : public WrapperFunctionAsyncHandlerHelper< + decltype(&std::remove_reference_t::operator()), + ResultSerializer, SPSTagTs...> {}; + +template class ResultSerializer, typename... SPSTagTs> +class WrapperFunctionAsyncHandlerHelper { +public: + using ArgTuple = std::tuple...>; + using ArgIndices = std::make_index_sequence::value>; + + template + static void applyAsync(HandlerT &&H, + SendWrapperFunctionResultT &&SendWrapperFunctionResult, + const char *ArgData, size_t ArgSize) { + ArgTuple Args; + if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) { + SendWrapperFunctionResult(WrapperFunctionResult::createOutOfBandError( + "Could not deserialize arguments for wrapper function call")); + return; + } + + auto SendResult = + [SendWFR = std::move(SendWrapperFunctionResult)](auto Result) mutable { + using ResultT = decltype(Result); + SendWFR(ResultSerializer::serialize(std::move(Result))); + }; + + callAsync(std::forward(H), std::move(SendResult), Args, + ArgIndices{}); + } + +private: + template + static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args, + std::index_sequence) { + SPSInputBuffer IB(ArgData, ArgSize); + return SPSArgList::deserialize(IB, std::get(Args)...); + } + + template + static void callAsync(HandlerT &&H, + SerializeAndSendResultT &&SerializeAndSendResult, + ArgTupleT &Args, std::index_sequence) { + return std::forward(H)(std::move(SerializeAndSendResult), + std::get(Args)...); + } +}; + +// Map function pointers to function types. +template class ResultSerializer, typename... SPSTagTs> +class WrapperFunctionAsyncHandlerHelper + : public WrapperFunctionAsyncHandlerHelper {}; + +// Map non-const member function types to function types. +template class ResultSerializer, typename... SPSTagTs> +class WrapperFunctionAsyncHandlerHelper + : public WrapperFunctionAsyncHandlerHelper {}; + +// Map const member function types to function types. +template class ResultSerializer, typename... SPSTagTs> +class WrapperFunctionAsyncHandlerHelper + : public WrapperFunctionAsyncHandlerHelper {}; + template class ResultSerializer { public: - static Expected serialize(RetT Result) { + static WrapperFunctionResult serialize(RetT Result) { return serializeViaSPSToWrapperFunctionResult>( Result); } @@ -281,7 +354,7 @@ public: template class ResultSerializer { public: - static Expected serialize(Error Err) { + static WrapperFunctionResult serialize(Error Err) { return serializeViaSPSToWrapperFunctionResult>( toSPSSerializable(std::move(Err))); } @@ -290,7 +363,7 @@ public: template class ResultSerializer> { public: - static Expected serialize(Expected E) { + static WrapperFunctionResult serialize(Expected E) { return serializeViaSPSToWrapperFunctionResult>( toSPSSerializable(std::move(E))); } @@ -298,6 +371,7 @@ public: template class ResultDeserializer { public: + static RetT makeValue() { return RetT(); } static void makeSafe(RetT &Result) {} static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) { @@ -312,6 +386,7 @@ public: template <> class ResultDeserializer { public: + static Error makeValue() { return Error::success(); } static void makeSafe(Error &Err) { cantFail(std::move(Err)); } static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) { @@ -329,6 +404,7 @@ public: template class ResultDeserializer, Expected> { public: + static Expected makeValue() { return T(); } static void makeSafe(Expected &E) { cantFail(E.takeError()); } static Error deserialize(Expected &E, const char *ArgData, @@ -344,6 +420,10 @@ public: } }; +template class AsyncCallResultHelper { + // Did you forget to use Error / Expected in your handler? +}; + } // end namespace detail template class WrapperFunction; @@ -355,7 +435,7 @@ private: using ResultSerializer = detail::ResultSerializer; public: - /// Call a wrapper function. Callere should be callable as + /// Call a wrapper function. Caller should be callable as /// WrapperFunctionResult Fn(const char *ArgData, size_t ArgSize); template static Error call(const CallerFn &Caller, RetT &Result, @@ -369,18 +449,56 @@ public: auto ArgBuffer = detail::serializeViaSPSToWrapperFunctionResult>( Args...); - if (!ArgBuffer) - return ArgBuffer.takeError(); + if (const char *ErrMsg = ArgBuffer.getOutOfBandError()) + return make_error(ErrMsg, inconvertibleErrorCode()); - Expected ResultBuffer = - Caller(ArgBuffer->data(), ArgBuffer->size()); - if (!ResultBuffer) - return ResultBuffer.takeError(); - if (auto ErrMsg = ResultBuffer->getOutOfBandError()) + WrapperFunctionResult ResultBuffer = + Caller(ArgBuffer.data(), ArgBuffer.size()); + if (auto ErrMsg = ResultBuffer.getOutOfBandError()) return make_error(ErrMsg, inconvertibleErrorCode()); return detail::ResultDeserializer::deserialize( - Result, ResultBuffer->data(), ResultBuffer->size()); + Result, ResultBuffer.data(), ResultBuffer.size()); + } + + /// Call an async wrapper function. + /// Caller should be callable as + /// void Fn(unique_function SendResult, + /// WrapperFunctionResult ArgBuffer); + template + static void callAsync(AsyncCallerFn &&Caller, + SendDeserializedResultFn &&SendDeserializedResult, + const ArgTs &...Args) { + using RetT = typename std::tuple_element< + 1, typename detail::WrapperFunctionHandlerHelper< + std::remove_reference_t, + ResultSerializer, SPSRetTagT>::ArgTuple>::type; + + auto ArgBuffer = + detail::serializeViaSPSToWrapperFunctionResult>( + Args...); + if (auto *ErrMsg = ArgBuffer.getOutOfBandError()) { + SendDeserializedResult( + make_error(ErrMsg, inconvertibleErrorCode()), + detail::ResultDeserializer::makeValue()); + return; + } + + auto SendSerializedResult = [SDR = std::move(SendDeserializedResult)]( + WrapperFunctionResult R) { + RetT RetVal = detail::ResultDeserializer::makeValue(); + detail::ResultDeserializer::makeSafe(RetVal); + + SPSInputBuffer IB(R.data(), R.size()); + if (auto Err = detail::ResultDeserializer::deserialize( + RetVal, R.data(), R.size())) + SDR(std::move(Err), std::move(RetVal)); + + SDR(Error::success(), std::move(RetVal)); + }; + + Caller(std::move(SendSerializedResult), ArgBuffer.data(), ArgBuffer.size()); } /// Handle a call to a wrapper function. @@ -388,11 +506,21 @@ public: static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize, HandlerT &&Handler) { using WFHH = - detail::WrapperFunctionHandlerHelper; + detail::WrapperFunctionHandlerHelper, + ResultSerializer, SPSTagTs...>; return WFHH::apply(std::forward(Handler), ArgData, ArgSize); } + /// Handle a call to an async wrapper function. + template + static void handleAsync(const char *ArgData, size_t ArgSize, + HandlerT &&Handler, SendResultT &&SendResult) { + using WFAHH = detail::WrapperFunctionAsyncHandlerHelper< + std::remove_reference_t, ResultSerializer, SPSTagTs...>; + WFAHH::applyAsync(std::forward(Handler), + std::forward(SendResult), ArgData, ArgSize); + } + private: template static const T &makeSerializable(const T &Value) { return Value; @@ -411,6 +539,7 @@ private: template class WrapperFunction : private WrapperFunction { + public: template static Error call(const CallerFn &Caller, const ArgTs &...Args) { @@ -419,6 +548,7 @@ public: } using WrapperFunction::handle; + using WrapperFunction::handleAsync; }; } // end namespace shared diff --git a/lib/ExecutionEngine/Orc/ExecutorProcessControl.cpp b/lib/ExecutionEngine/Orc/ExecutorProcessControl.cpp index f8bd74eabc9..12fa42ccdef 100644 --- a/lib/ExecutionEngine/Orc/ExecutorProcessControl.cpp +++ b/lib/ExecutionEngine/Orc/ExecutorProcessControl.cpp @@ -10,11 +10,10 @@ #include "llvm/ExecutionEngine/Orc/Core.h" #include "llvm/ExecutionEngine/Orc/TargetProcess/TargetExecutionUtils.h" +#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/Host.h" #include "llvm/Support/Process.h" -#include - namespace llvm { namespace orc { @@ -22,6 +21,56 @@ ExecutorProcessControl::MemoryAccess::~MemoryAccess() {} ExecutorProcessControl::~ExecutorProcessControl() {} +Error ExecutorProcessControl::associateJITSideWrapperFunctions( + JITDylib &JD, WrapperFunctionAssociationMap WFs) { + + // Look up tag addresses. + auto &ES = JD.getExecutionSession(); + auto TagAddrs = + ES.lookup({{&JD, JITDylibLookupFlags::MatchAllSymbols}}, + SymbolLookupSet::fromMapKeys( + WFs, SymbolLookupFlags::WeaklyReferencedSymbol)); + if (!TagAddrs) + return TagAddrs.takeError(); + + // Associate tag addresses with implementations. + std::lock_guard Lock(TagToFuncMapMutex); + for (auto &KV : *TagAddrs) { + auto TagAddr = KV.second.getAddress(); + if (TagToFunc.count(TagAddr)) + return make_error("Tag " + formatv("{0:x16}", TagAddr) + + " (for " + *KV.first + + ") already registered", + inconvertibleErrorCode()); + auto I = WFs.find(KV.first); + assert(I != WFs.end() && I->second && + "AsyncWrapperFunction implementation missing"); + TagToFunc[KV.second.getAddress()] = + std::make_shared(std::move(I->second)); + } + return Error::success(); +} + +void ExecutorProcessControl::runJITSideWrapperFunction( + SendResultFunction SendResult, JITTargetAddress TagAddr, + ArrayRef ArgBuffer) { + + std::shared_ptr F; + { + std::lock_guard Lock(TagToFuncMapMutex); + auto I = TagToFunc.find(TagAddr); + if (I != TagToFunc.end()) + F = I->second; + } + + if (F) + (*F)(std::move(SendResult), ArgBuffer.data(), ArgBuffer.size()); + else + SendResult(shared::WrapperFunctionResult::createOutOfBandError( + ("No function registered for tag " + formatv("{0:x16}", TagAddr)) + .str())); +} + SelfExecutorProcessControl::SelfExecutorProcessControl( std::shared_ptr SSP, Triple TargetTriple, unsigned PageSize, std::unique_ptr MemMgr) @@ -102,13 +151,13 @@ SelfExecutorProcessControl::runAsMain(JITTargetAddress MainFnAddr, return orc::runAsMain(jitTargetAddressToFunction(MainFnAddr), Args); } -Expected -SelfExecutorProcessControl::runWrapper(JITTargetAddress WrapperFnAddr, - ArrayRef ArgBuffer) { - using WrapperFnTy = shared::detail::CWrapperFunctionResult (*)( - const char *Data, uint64_t Size); +void SelfExecutorProcessControl::runWrapperAsync(SendResultFunction SendResult, + JITTargetAddress WrapperFnAddr, + ArrayRef ArgBuffer) { + using WrapperFnTy = + shared::detail::CWrapperFunctionResult (*)(const char *Data, size_t Size); auto *WrapperFn = jitTargetAddressToFunction(WrapperFnAddr); - return WrapperFn(ArgBuffer.data(), ArgBuffer.size()); + SendResult(WrapperFn(ArgBuffer.data(), ArgBuffer.size())); } Error SelfExecutorProcessControl::disconnect() { return Error::success(); } diff --git a/unittests/ExecutionEngine/Orc/CMakeLists.txt b/unittests/ExecutionEngine/Orc/CMakeLists.txt index b1cfd18e5d4..b544cfa1864 100644 --- a/unittests/ExecutionEngine/Orc/CMakeLists.txt +++ b/unittests/ExecutionEngine/Orc/CMakeLists.txt @@ -16,6 +16,7 @@ set(LLVM_LINK_COMPONENTS add_llvm_unittest(OrcJITTests CoreAPIsTest.cpp + ExecutorProcessControlTest.cpp IndirectionUtilsTest.cpp JITTargetMachineBuilderTest.cpp LazyCallThroughAndReexportsTest.cpp diff --git a/unittests/ExecutionEngine/Orc/ExecutorProcessControlTest.cpp b/unittests/ExecutionEngine/Orc/ExecutorProcessControlTest.cpp new file mode 100644 index 00000000000..23096c86f4d --- /dev/null +++ b/unittests/ExecutionEngine/Orc/ExecutorProcessControlTest.cpp @@ -0,0 +1,105 @@ +//===- ExecutorProcessControlTest.cpp - Test ExecutorProcessControl utils -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/Support/MSVCErrorWorkarounds.h" +#include "llvm/Testing/Support/Error.h" +#include "gtest/gtest.h" + +#include + +using namespace llvm; +using namespace llvm::orc; +using namespace llvm::orc::shared; + +static llvm::orc::shared::detail::CWrapperFunctionResult +addWrapper(const char *ArgData, size_t ArgSize) { + return WrapperFunction::handle( + ArgData, ArgSize, [](int32_t X, int32_t Y) { return X + Y; }) + .release(); +} + +static void addAsyncWrapper(unique_function SendResult, + int32_t X, int32_t Y) { + SendResult(X + Y); +} + +TEST(ExecutorProcessControl, RunWrapperTemplate) { + auto EPC = cantFail( + SelfExecutorProcessControl::Create(std::make_shared())); + + int32_t Result; + EXPECT_THAT_ERROR(EPC->runSPSWrapper( + pointerToJITTargetAddress(addWrapper), Result, 2, 3), + Succeeded()); + EXPECT_EQ(Result, 5); +} + +TEST(ExecutorProcessControl, RunWrapperAsyncTemplate) { + auto EPC = cantFail( + SelfExecutorProcessControl::Create(std::make_shared())); + + std::promise> RP; + using Sig = int32_t(int32_t, int32_t); + EPC->runSPSWrapperAsync( + [&](Error SerializationErr, int32_t R) { + if (SerializationErr) + RP.set_value(std::move(SerializationErr)); + RP.set_value(std::move(R)); + }, + pointerToJITTargetAddress(addWrapper), 2, 3); + Expected Result = RP.get_future().get(); + EXPECT_THAT_EXPECTED(Result, HasValue(5)); +} + +TEST(ExecutorProcessControl, RegisterAsyncHandlerAndRun) { + + constexpr JITTargetAddress AddAsyncTagAddr = 0x01; + + auto EPC = cantFail( + SelfExecutorProcessControl::Create(std::make_shared())); + ExecutionSession ES(EPC->getSymbolStringPool()); + auto &JD = ES.createBareJITDylib("JD"); + + auto AddAsyncTag = ES.intern("addAsync_tag"); + cantFail(JD.define(absoluteSymbols( + {{AddAsyncTag, + JITEvaluatedSymbol(AddAsyncTagAddr, JITSymbolFlags::Exported)}}))); + + ExecutorProcessControl::WrapperFunctionAssociationMap Associations; + + Associations[AddAsyncTag] = + EPC->wrapAsyncWithSPS(addAsyncWrapper); + + cantFail(EPC->associateJITSideWrapperFunctions(JD, std::move(Associations))); + + std::promise RP; + auto RF = RP.get_future(); + + using ArgSerialization = SPSArgList; + size_t ArgBufferSize = ArgSerialization::size(1, 2); + WrapperFunctionResult ArgBuffer; + char *ArgBufferData = + WrapperFunctionResult::allocate(ArgBuffer, ArgBufferSize); + SPSOutputBuffer OB(ArgBufferData, ArgBufferSize); + EXPECT_TRUE(ArgSerialization::serialize(OB, 1, 2)); + + EPC->runJITSideWrapperFunction( + [&](WrapperFunctionResult ResultBuffer) { + int32_t Result; + SPSInputBuffer IB(ResultBuffer.data(), ResultBuffer.size()); + EXPECT_TRUE(SPSArgList::deserialize(IB, Result)); + RP.set_value(Result); + }, + AddAsyncTagAddr, ArrayRef(ArgBuffer.data(), ArgBuffer.size())); + + EXPECT_EQ(RF.get(), (int32_t)3); + + cantFail(ES.endSession()); +} diff --git a/unittests/ExecutionEngine/Orc/WrapperFunctionUtilsTest.cpp b/unittests/ExecutionEngine/Orc/WrapperFunctionUtilsTest.cpp index 1f177b4c2d1..42051836506 100644 --- a/unittests/ExecutionEngine/Orc/WrapperFunctionUtilsTest.cpp +++ b/unittests/ExecutionEngine/Orc/WrapperFunctionUtilsTest.cpp @@ -7,8 +7,11 @@ //===----------------------------------------------------------------------===// #include "llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h" +#include "llvm/ADT/FunctionExtras.h" #include "gtest/gtest.h" +#include + using namespace llvm; using namespace llvm::orc::shared; @@ -65,13 +68,54 @@ static WrapperFunctionResult addWrapper(const char *ArgData, size_t ArgSize) { ArgData, ArgSize, [](int32_t X, int32_t Y) -> int32_t { return X + Y; }); } -TEST(WrapperFunctionUtilsTest, WrapperFunctionCallVoidNoopAndHandle) { +TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAndHandleVoid) { EXPECT_FALSE(!!WrapperFunction::call(voidNoopWrapper)); } -TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAndHandle) { +TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAndHandleRet) { int32_t Result; EXPECT_FALSE(!!WrapperFunction::call( addWrapper, Result, 1, 2)); EXPECT_EQ(Result, (int32_t)3); } + +static void voidNoopAsync(unique_function SendResult) { + SendResult(SPSEmpty()); +} + +static WrapperFunctionResult voidNoopAsyncWrapper(const char *ArgData, + size_t ArgSize) { + std::promise RP; + auto RF = RP.get_future(); + + WrapperFunction::handleAsync( + ArgData, ArgSize, voidNoopAsync, + [&](WrapperFunctionResult R) { RP.set_value(std::move(R)); }); + + return RF.get(); +} + +static WrapperFunctionResult addAsyncWrapper(const char *ArgData, + size_t ArgSize) { + std::promise RP; + auto RF = RP.get_future(); + + WrapperFunction::handleAsync( + ArgData, ArgSize, + [](unique_function SendResult, int32_t X, int32_t Y) { + SendResult(X + Y); + }, + [&](WrapperFunctionResult R) { RP.set_value(std::move(R)); }); + return RF.get(); +} + +TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAndHandleAsyncVoid) { + EXPECT_FALSE(!!WrapperFunction::call(voidNoopAsyncWrapper)); +} + +TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAndHandleAsyncRet) { + int32_t Result; + EXPECT_FALSE(!!WrapperFunction::call( + addAsyncWrapper, Result, 1, 2)); + EXPECT_EQ(Result, (int32_t)3); +}