1
0
mirror of https://github.com/RPCS3/llvm-mirror.git synced 2024-11-22 10:42:39 +01:00

[ORC] Add wrapper-function support methods to ExecutorProcessControl.

Adds support for both synchronous and asynchronous calls to wrapper functions
using SPS (Simple Packed Serialization). Also adds support for wrapping
functions on the JIT side in SPS-based wrappers that can be called from the
executor.

These new methods simplify calls between the JIT and Executor, and will be used
in upcoming ORC runtime patches to enable communication between ORC and the
runtime.
This commit is contained in:
Lang Hames 2021-06-19 17:36:47 +10:00
parent c930f37268
commit 6567b76038
8 changed files with 488 additions and 48 deletions

View File

@ -216,6 +216,18 @@ public:
add(Name, Flags);
}
/// Construct a SymbolLookupSet from DenseMap keys.
template <typename KeyT>
static SymbolLookupSet
fromMapKeys(const DenseMap<SymbolStringPtr, KeyT> &M,
SymbolLookupFlags Flags = SymbolLookupFlags::RequiredSymbol) {
SymbolLookupSet Result;
Result.Symbols.reserve(M.size());
for (const auto &KV : M)
Result.add(KV.first, Flags);
return Result;
}
/// Add an element to the set. The client is responsible for checking that
/// duplicates are not added.
SymbolLookupSet &

View File

@ -24,6 +24,7 @@
#include "llvm/Support/MSVCErrorWorkarounds.h"
#include <future>
#include <mutex>
#include <vector>
namespace llvm {
@ -32,6 +33,19 @@ namespace orc {
/// ExecutorProcessControl supports interaction with a JIT target process.
class ExecutorProcessControl {
public:
/// Sender to return the result of a WrapperFunction executed in the JIT.
using SendResultFunction =
unique_function<void(shared::WrapperFunctionResult)>;
/// An asynchronous wrapper-function.
using AsyncWrapperFunction = unique_function<void(
SendResultFunction SendResult, const char *ArgData, size_t ArgSize)>;
/// A map associating tag names with asynchronous wrapper function
/// implementations in the JIT.
using WrapperFunctionAssociationMap =
DenseMap<SymbolStringPtr, AsyncWrapperFunction>;
/// APIs for manipulating memory in the target process.
class MemoryAccess {
public:
@ -138,14 +152,91 @@ public:
virtual Expected<int32_t> runAsMain(JITTargetAddress MainFnAddr,
ArrayRef<std::string> Args) = 0;
/// Run a wrapper function in the executor.
/// Run a wrapper function in the executor (async version).
///
/// The wrapper function should be callable as:
///
/// \code{.cpp}
/// CWrapperFunctionResult fn(uint8_t *Data, uint64_t Size);
/// \endcode{.cpp}
///
virtual Expected<shared::WrapperFunctionResult>
runWrapper(JITTargetAddress WrapperFnAddr, ArrayRef<char> ArgBuffer) = 0;
/// The given OnComplete function will be called to return the result.
virtual void runWrapperAsync(SendResultFunction OnComplete,
JITTargetAddress WrapperFnAddr,
ArrayRef<char> ArgBuffer) = 0;
/// Run a wrapper function in the executor. The wrapper function should be
/// callable as:
///
/// \code{.cpp}
/// CWrapperFunctionResult fn(uint8_t *Data, uint64_t Size);
/// \endcode{.cpp}
shared::WrapperFunctionResult runWrapper(JITTargetAddress WrapperFnAddr,
ArrayRef<char> ArgBuffer) {
std::promise<shared::WrapperFunctionResult> RP;
auto RF = RP.get_future();
runWrapperAsync(
[&](shared::WrapperFunctionResult R) { RP.set_value(std::move(R)); },
WrapperFnAddr, ArgBuffer);
return RF.get();
}
/// Run a wrapper function using SPS to serialize the arguments and
/// deserialize the results.
template <typename SPSSignature, typename SendResultT, typename... ArgTs>
void runSPSWrapperAsync(SendResultT &&SendResult,
JITTargetAddress WrapperFnAddr,
const ArgTs &...Args) {
shared::WrapperFunction<SPSSignature>::callAsync(
[this, WrapperFnAddr](SendResultFunction SendResult,
const char *ArgData, size_t ArgSize) {
runWrapperAsync(std::move(SendResult), WrapperFnAddr,
ArrayRef<char>(ArgData, ArgSize));
},
std::move(SendResult), Args...);
}
/// Run a wrapper function using SPS to serialize the arguments and
/// deserialize the results.
template <typename SPSSignature, typename RetT, typename... ArgTs>
Error runSPSWrapper(JITTargetAddress WrapperFnAddr, RetT &RetVal,
const ArgTs &...Args) {
return shared::WrapperFunction<SPSSignature>::call(
[this, WrapperFnAddr](const char *ArgData, size_t ArgSize) {
return runWrapper(WrapperFnAddr, ArrayRef<char>(ArgData, ArgSize));
},
RetVal, Args...);
}
/// Wrap a handler that takes concrete argument types (and a sender for a
/// concrete return type) to produce an AsyncWrapperFunction. Uses SPS to
/// unpack the arguments and pack the result.
///
/// This function is usually used when building association maps.
template <typename SPSSignature, typename HandlerT>
static AsyncWrapperFunction wrapAsyncWithSPS(HandlerT &&H) {
return [H = std::forward<HandlerT>(H)](SendResultFunction SendResult,
const char *ArgData,
size_t ArgSize) mutable {
shared::WrapperFunction<SPSSignature>::handleAsync(ArgData, ArgSize, H,
std::move(SendResult));
};
}
/// For each symbol name, associate the AsyncWrapperFunction implementation
/// value with the address of that symbol.
///
/// Symbols will be looked up using LookupKind::Static,
/// JITDylibLookupFlags::MatchAllSymbols (hidden tags will be found), and
/// LookupFlags::WeaklyReferencedSymbol (missing tags will not cause an
/// error, the implementations will simply be dropped).
Error associateJITSideWrapperFunctions(JITDylib &JD,
WrapperFunctionAssociationMap WFs);
/// Run a registered jit-side wrapper function.
void runJITSideWrapperFunction(SendResultFunction SendResult,
JITTargetAddress TagAddr,
ArrayRef<char> ArgBuffer);
/// Disconnect from the target process.
///
@ -161,6 +252,9 @@ protected:
unsigned PageSize = 0;
MemoryAccess *MemAccess = nullptr;
jitlink::JITLinkMemoryManager *MemMgr = nullptr;
std::mutex TagToFuncMapMutex;
DenseMap<JITTargetAddress, std::shared_ptr<AsyncWrapperFunction>> TagToFunc;
};
/// Call a wrapper function via ExecutorProcessControl::runWrapper.
@ -168,7 +262,7 @@ class EPCCaller {
public:
EPCCaller(ExecutorProcessControl &EPC, JITTargetAddress WrapperFnAddr)
: EPC(EPC), WrapperFnAddr(WrapperFnAddr) {}
Expected<shared::WrapperFunctionResult> operator()(const char *ArgData,
shared::WrapperFunctionResult operator()(const char *ArgData,
size_t ArgSize) const {
return EPC.runWrapper(WrapperFnAddr, ArrayRef<char>(ArgData, ArgSize));
}
@ -202,8 +296,9 @@ public:
Expected<int32_t> runAsMain(JITTargetAddress MainFnAddr,
ArrayRef<std::string> Args) override;
Expected<shared::WrapperFunctionResult>
runWrapper(JITTargetAddress WrapperFnAddr, ArrayRef<char> ArgBuffer) override;
void runWrapperAsync(SendResultFunction OnComplete,
JITTargetAddress WrapperFnAddr,
ArrayRef<char> ArgBuffer) override;
Error disconnect() override;

View File

@ -354,8 +354,8 @@ public:
return Result;
}
Expected<shared::WrapperFunctionResult>
runWrapper(JITTargetAddress WrapperFnAddr,
void runWrapperAsync(SendResultFunction OnComplete,
JITTargetAddress WrapperFnAddr,
ArrayRef<char> ArgBuffer) override {
DEBUG_WITH_TYPE("orc", {
dbgs() << "Running as wrapper function "
@ -366,7 +366,11 @@ public:
WrapperFnAddr,
ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(ArgBuffer.data()),
ArgBuffer.size()));
return Result;
if (!Result)
OnComplete(shared::WrapperFunctionResult::createOutOfBandError(
toString(Result.takeError())));
OnComplete(std::move(*Result));
}
Error closeConnection(OnCloseConnectionFunction OnCloseConnection) {

View File

@ -172,17 +172,16 @@ private:
namespace detail {
template <typename SPSArgListT, typename... ArgTs>
Expected<WrapperFunctionResult>
WrapperFunctionResult
serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) {
WrapperFunctionResult Result;
char *DataPtr =
WrapperFunctionResult::allocate(Result, SPSArgListT::size(Args...));
SPSOutputBuffer OB(DataPtr, Result.size());
if (!SPSArgListT::serialize(OB, Args...))
return make_error<StringError>(
"Error serializing arguments to blob in call",
inconvertibleErrorCode());
return std::move(Result);
return WrapperFunctionResult::createOutOfBandError(
"Error serializing arguments to blob in call");
return Result;
}
template <typename RetT> class WrapperFunctionHandlerCaller {
@ -230,12 +229,8 @@ public:
auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call(
std::forward<HandlerT>(H), Args, ArgIndices{});
if (auto Result = ResultSerializer<decltype(HandlerResult)>::serialize(
std::move(HandlerResult)))
return std::move(*Result);
else
return WrapperFunctionResult::createOutOfBandError(
toString(Result.takeError()));
return ResultSerializer<decltype(HandlerResult)>::serialize(
std::move(HandlerResult));
}
private:
@ -247,10 +242,10 @@ private:
}
};
// Map function references to function types.
// Map function pointers to function types.
template <typename RetT, typename... ArgTs,
template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionHandlerHelper<RetT (&)(ArgTs...), ResultSerializer,
class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer,
SPSTagTs...>
: public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
SPSTagTs...> {};
@ -271,9 +266,87 @@ class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
: public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
SPSTagTs...> {};
template <typename WrapperFunctionImplT,
template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionAsyncHandlerHelper
: public WrapperFunctionAsyncHandlerHelper<
decltype(&std::remove_reference_t<WrapperFunctionImplT>::operator()),
ResultSerializer, SPSTagTs...> {};
template <typename RetT, typename SendResultT, typename... ArgTs,
template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionAsyncHandlerHelper<RetT(SendResultT, ArgTs...),
ResultSerializer, SPSTagTs...> {
public:
using ArgTuple = std::tuple<std::decay_t<ArgTs>...>;
using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>;
template <typename HandlerT, typename SendWrapperFunctionResultT>
static void applyAsync(HandlerT &&H,
SendWrapperFunctionResultT &&SendWrapperFunctionResult,
const char *ArgData, size_t ArgSize) {
ArgTuple Args;
if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) {
SendWrapperFunctionResult(WrapperFunctionResult::createOutOfBandError(
"Could not deserialize arguments for wrapper function call"));
return;
}
auto SendResult =
[SendWFR = std::move(SendWrapperFunctionResult)](auto Result) mutable {
using ResultT = decltype(Result);
SendWFR(ResultSerializer<ResultT>::serialize(std::move(Result)));
};
callAsync(std::forward<HandlerT>(H), std::move(SendResult), Args,
ArgIndices{});
}
private:
template <std::size_t... I>
static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args,
std::index_sequence<I...>) {
SPSInputBuffer IB(ArgData, ArgSize);
return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...);
}
template <typename HandlerT, typename SerializeAndSendResultT,
typename ArgTupleT, std::size_t... I>
static void callAsync(HandlerT &&H,
SerializeAndSendResultT &&SerializeAndSendResult,
ArgTupleT &Args, std::index_sequence<I...>) {
return std::forward<HandlerT>(H)(std::move(SerializeAndSendResult),
std::get<I>(Args)...);
}
};
// Map function pointers to function types.
template <typename RetT, typename... ArgTs,
template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionAsyncHandlerHelper<RetT (*)(ArgTs...), ResultSerializer,
SPSTagTs...>
: public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer,
SPSTagTs...> {};
// Map non-const member function types to function types.
template <typename ClassT, typename RetT, typename... ArgTs,
template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionAsyncHandlerHelper<RetT (ClassT::*)(ArgTs...),
ResultSerializer, SPSTagTs...>
: public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer,
SPSTagTs...> {};
// Map const member function types to function types.
template <typename ClassT, typename RetT, typename... ArgTs,
template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionAsyncHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
ResultSerializer, SPSTagTs...>
: public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer,
SPSTagTs...> {};
template <typename SPSRetTagT, typename RetT> class ResultSerializer {
public:
static Expected<WrapperFunctionResult> serialize(RetT Result) {
static WrapperFunctionResult serialize(RetT Result) {
return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
Result);
}
@ -281,7 +354,7 @@ public:
template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> {
public:
static Expected<WrapperFunctionResult> serialize(Error Err) {
static WrapperFunctionResult serialize(Error Err) {
return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
toSPSSerializable(std::move(Err)));
}
@ -290,7 +363,7 @@ public:
template <typename SPSRetTagT, typename T>
class ResultSerializer<SPSRetTagT, Expected<T>> {
public:
static Expected<WrapperFunctionResult> serialize(Expected<T> E) {
static WrapperFunctionResult serialize(Expected<T> E) {
return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
toSPSSerializable(std::move(E)));
}
@ -298,6 +371,7 @@ public:
template <typename SPSRetTagT, typename RetT> class ResultDeserializer {
public:
static RetT makeValue() { return RetT(); }
static void makeSafe(RetT &Result) {}
static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) {
@ -312,6 +386,7 @@ public:
template <> class ResultDeserializer<SPSError, Error> {
public:
static Error makeValue() { return Error::success(); }
static void makeSafe(Error &Err) { cantFail(std::move(Err)); }
static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) {
@ -329,6 +404,7 @@ public:
template <typename SPSTagT, typename T>
class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> {
public:
static Expected<T> makeValue() { return T(); }
static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); }
static Error deserialize(Expected<T> &E, const char *ArgData,
@ -344,6 +420,10 @@ public:
}
};
template <typename SPSRetTagT, typename RetT> class AsyncCallResultHelper {
// Did you forget to use Error / Expected in your handler?
};
} // end namespace detail
template <typename SPSSignature> class WrapperFunction;
@ -355,7 +435,7 @@ private:
using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>;
public:
/// Call a wrapper function. Callere should be callable as
/// Call a wrapper function. Caller should be callable as
/// WrapperFunctionResult Fn(const char *ArgData, size_t ArgSize);
template <typename CallerFn, typename RetT, typename... ArgTs>
static Error call(const CallerFn &Caller, RetT &Result,
@ -369,18 +449,56 @@ public:
auto ArgBuffer =
detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>(
Args...);
if (!ArgBuffer)
return ArgBuffer.takeError();
if (const char *ErrMsg = ArgBuffer.getOutOfBandError())
return make_error<StringError>(ErrMsg, inconvertibleErrorCode());
Expected<WrapperFunctionResult> ResultBuffer =
Caller(ArgBuffer->data(), ArgBuffer->size());
if (!ResultBuffer)
return ResultBuffer.takeError();
if (auto ErrMsg = ResultBuffer->getOutOfBandError())
WrapperFunctionResult ResultBuffer =
Caller(ArgBuffer.data(), ArgBuffer.size());
if (auto ErrMsg = ResultBuffer.getOutOfBandError())
return make_error<StringError>(ErrMsg, inconvertibleErrorCode());
return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
Result, ResultBuffer->data(), ResultBuffer->size());
Result, ResultBuffer.data(), ResultBuffer.size());
}
/// Call an async wrapper function.
/// Caller should be callable as
/// void Fn(unique_function<void(WrapperFunctionResult)> SendResult,
/// WrapperFunctionResult ArgBuffer);
template <typename AsyncCallerFn, typename SendDeserializedResultFn,
typename... ArgTs>
static void callAsync(AsyncCallerFn &&Caller,
SendDeserializedResultFn &&SendDeserializedResult,
const ArgTs &...Args) {
using RetT = typename std::tuple_element<
1, typename detail::WrapperFunctionHandlerHelper<
std::remove_reference_t<SendDeserializedResultFn>,
ResultSerializer, SPSRetTagT>::ArgTuple>::type;
auto ArgBuffer =
detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>(
Args...);
if (auto *ErrMsg = ArgBuffer.getOutOfBandError()) {
SendDeserializedResult(
make_error<StringError>(ErrMsg, inconvertibleErrorCode()),
detail::ResultDeserializer<SPSRetTagT, RetT>::makeValue());
return;
}
auto SendSerializedResult = [SDR = std::move(SendDeserializedResult)](
WrapperFunctionResult R) {
RetT RetVal = detail::ResultDeserializer<SPSRetTagT, RetT>::makeValue();
detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(RetVal);
SPSInputBuffer IB(R.data(), R.size());
if (auto Err = detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
RetVal, R.data(), R.size()))
SDR(std::move(Err), std::move(RetVal));
SDR(Error::success(), std::move(RetVal));
};
Caller(std::move(SendSerializedResult), ArgBuffer.data(), ArgBuffer.size());
}
/// Handle a call to a wrapper function.
@ -388,11 +506,21 @@ public:
static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize,
HandlerT &&Handler) {
using WFHH =
detail::WrapperFunctionHandlerHelper<HandlerT, ResultSerializer,
SPSTagTs...>;
detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>,
ResultSerializer, SPSTagTs...>;
return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize);
}
/// Handle a call to an async wrapper function.
template <typename HandlerT, typename SendResultT>
static void handleAsync(const char *ArgData, size_t ArgSize,
HandlerT &&Handler, SendResultT &&SendResult) {
using WFAHH = detail::WrapperFunctionAsyncHandlerHelper<
std::remove_reference_t<HandlerT>, ResultSerializer, SPSTagTs...>;
WFAHH::applyAsync(std::forward<HandlerT>(Handler),
std::forward<SendResultT>(SendResult), ArgData, ArgSize);
}
private:
template <typename T> static const T &makeSerializable(const T &Value) {
return Value;
@ -411,6 +539,7 @@ private:
template <typename... SPSTagTs>
class WrapperFunction<void(SPSTagTs...)>
: private WrapperFunction<SPSEmpty(SPSTagTs...)> {
public:
template <typename CallerFn, typename... ArgTs>
static Error call(const CallerFn &Caller, const ArgTs &...Args) {
@ -419,6 +548,7 @@ public:
}
using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle;
using WrapperFunction<SPSEmpty(SPSTagTs...)>::handleAsync;
};
} // end namespace shared

View File

@ -10,11 +10,10 @@
#include "llvm/ExecutionEngine/Orc/Core.h"
#include "llvm/ExecutionEngine/Orc/TargetProcess/TargetExecutionUtils.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Host.h"
#include "llvm/Support/Process.h"
#include <mutex>
namespace llvm {
namespace orc {
@ -22,6 +21,56 @@ ExecutorProcessControl::MemoryAccess::~MemoryAccess() {}
ExecutorProcessControl::~ExecutorProcessControl() {}
Error ExecutorProcessControl::associateJITSideWrapperFunctions(
JITDylib &JD, WrapperFunctionAssociationMap WFs) {
// Look up tag addresses.
auto &ES = JD.getExecutionSession();
auto TagAddrs =
ES.lookup({{&JD, JITDylibLookupFlags::MatchAllSymbols}},
SymbolLookupSet::fromMapKeys(
WFs, SymbolLookupFlags::WeaklyReferencedSymbol));
if (!TagAddrs)
return TagAddrs.takeError();
// Associate tag addresses with implementations.
std::lock_guard<std::mutex> Lock(TagToFuncMapMutex);
for (auto &KV : *TagAddrs) {
auto TagAddr = KV.second.getAddress();
if (TagToFunc.count(TagAddr))
return make_error<StringError>("Tag " + formatv("{0:x16}", TagAddr) +
" (for " + *KV.first +
") already registered",
inconvertibleErrorCode());
auto I = WFs.find(KV.first);
assert(I != WFs.end() && I->second &&
"AsyncWrapperFunction implementation missing");
TagToFunc[KV.second.getAddress()] =
std::make_shared<AsyncWrapperFunction>(std::move(I->second));
}
return Error::success();
}
void ExecutorProcessControl::runJITSideWrapperFunction(
SendResultFunction SendResult, JITTargetAddress TagAddr,
ArrayRef<char> ArgBuffer) {
std::shared_ptr<AsyncWrapperFunction> F;
{
std::lock_guard<std::mutex> Lock(TagToFuncMapMutex);
auto I = TagToFunc.find(TagAddr);
if (I != TagToFunc.end())
F = I->second;
}
if (F)
(*F)(std::move(SendResult), ArgBuffer.data(), ArgBuffer.size());
else
SendResult(shared::WrapperFunctionResult::createOutOfBandError(
("No function registered for tag " + formatv("{0:x16}", TagAddr))
.str()));
}
SelfExecutorProcessControl::SelfExecutorProcessControl(
std::shared_ptr<SymbolStringPool> SSP, Triple TargetTriple,
unsigned PageSize, std::unique_ptr<jitlink::JITLinkMemoryManager> MemMgr)
@ -102,13 +151,13 @@ SelfExecutorProcessControl::runAsMain(JITTargetAddress MainFnAddr,
return orc::runAsMain(jitTargetAddressToFunction<MainTy>(MainFnAddr), Args);
}
Expected<shared::WrapperFunctionResult>
SelfExecutorProcessControl::runWrapper(JITTargetAddress WrapperFnAddr,
void SelfExecutorProcessControl::runWrapperAsync(SendResultFunction SendResult,
JITTargetAddress WrapperFnAddr,
ArrayRef<char> ArgBuffer) {
using WrapperFnTy = shared::detail::CWrapperFunctionResult (*)(
const char *Data, uint64_t Size);
using WrapperFnTy =
shared::detail::CWrapperFunctionResult (*)(const char *Data, size_t Size);
auto *WrapperFn = jitTargetAddressToFunction<WrapperFnTy>(WrapperFnAddr);
return WrapperFn(ArgBuffer.data(), ArgBuffer.size());
SendResult(WrapperFn(ArgBuffer.data(), ArgBuffer.size()));
}
Error SelfExecutorProcessControl::disconnect() { return Error::success(); }

View File

@ -16,6 +16,7 @@ set(LLVM_LINK_COMPONENTS
add_llvm_unittest(OrcJITTests
CoreAPIsTest.cpp
ExecutorProcessControlTest.cpp
IndirectionUtilsTest.cpp
JITTargetMachineBuilderTest.cpp
LazyCallThroughAndReexportsTest.cpp

View File

@ -0,0 +1,105 @@
//===- ExecutorProcessControlTest.cpp - Test ExecutorProcessControl utils -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h"
#include "llvm/ExecutionEngine/Orc/Core.h"
#include "llvm/Support/MSVCErrorWorkarounds.h"
#include "llvm/Testing/Support/Error.h"
#include "gtest/gtest.h"
#include <future>
using namespace llvm;
using namespace llvm::orc;
using namespace llvm::orc::shared;
static llvm::orc::shared::detail::CWrapperFunctionResult
addWrapper(const char *ArgData, size_t ArgSize) {
return WrapperFunction<int32_t(int32_t, int32_t)>::handle(
ArgData, ArgSize, [](int32_t X, int32_t Y) { return X + Y; })
.release();
}
static void addAsyncWrapper(unique_function<void(int32_t)> SendResult,
int32_t X, int32_t Y) {
SendResult(X + Y);
}
TEST(ExecutorProcessControl, RunWrapperTemplate) {
auto EPC = cantFail(
SelfExecutorProcessControl::Create(std::make_shared<SymbolStringPool>()));
int32_t Result;
EXPECT_THAT_ERROR(EPC->runSPSWrapper<int32_t(int32_t, int32_t)>(
pointerToJITTargetAddress(addWrapper), Result, 2, 3),
Succeeded());
EXPECT_EQ(Result, 5);
}
TEST(ExecutorProcessControl, RunWrapperAsyncTemplate) {
auto EPC = cantFail(
SelfExecutorProcessControl::Create(std::make_shared<SymbolStringPool>()));
std::promise<MSVCPExpected<int32_t>> RP;
using Sig = int32_t(int32_t, int32_t);
EPC->runSPSWrapperAsync<Sig>(
[&](Error SerializationErr, int32_t R) {
if (SerializationErr)
RP.set_value(std::move(SerializationErr));
RP.set_value(std::move(R));
},
pointerToJITTargetAddress(addWrapper), 2, 3);
Expected<int32_t> Result = RP.get_future().get();
EXPECT_THAT_EXPECTED(Result, HasValue(5));
}
TEST(ExecutorProcessControl, RegisterAsyncHandlerAndRun) {
constexpr JITTargetAddress AddAsyncTagAddr = 0x01;
auto EPC = cantFail(
SelfExecutorProcessControl::Create(std::make_shared<SymbolStringPool>()));
ExecutionSession ES(EPC->getSymbolStringPool());
auto &JD = ES.createBareJITDylib("JD");
auto AddAsyncTag = ES.intern("addAsync_tag");
cantFail(JD.define(absoluteSymbols(
{{AddAsyncTag,
JITEvaluatedSymbol(AddAsyncTagAddr, JITSymbolFlags::Exported)}})));
ExecutorProcessControl::WrapperFunctionAssociationMap Associations;
Associations[AddAsyncTag] =
EPC->wrapAsyncWithSPS<int32_t(int32_t, int32_t)>(addAsyncWrapper);
cantFail(EPC->associateJITSideWrapperFunctions(JD, std::move(Associations)));
std::promise<int32_t> RP;
auto RF = RP.get_future();
using ArgSerialization = SPSArgList<int32_t, int32_t>;
size_t ArgBufferSize = ArgSerialization::size(1, 2);
WrapperFunctionResult ArgBuffer;
char *ArgBufferData =
WrapperFunctionResult::allocate(ArgBuffer, ArgBufferSize);
SPSOutputBuffer OB(ArgBufferData, ArgBufferSize);
EXPECT_TRUE(ArgSerialization::serialize(OB, 1, 2));
EPC->runJITSideWrapperFunction(
[&](WrapperFunctionResult ResultBuffer) {
int32_t Result;
SPSInputBuffer IB(ResultBuffer.data(), ResultBuffer.size());
EXPECT_TRUE(SPSArgList<int32_t>::deserialize(IB, Result));
RP.set_value(Result);
},
AddAsyncTagAddr, ArrayRef<char>(ArgBuffer.data(), ArgBuffer.size()));
EXPECT_EQ(RF.get(), (int32_t)3);
cantFail(ES.endSession());
}

View File

@ -7,8 +7,11 @@
//===----------------------------------------------------------------------===//
#include "llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h"
#include "llvm/ADT/FunctionExtras.h"
#include "gtest/gtest.h"
#include <future>
using namespace llvm;
using namespace llvm::orc::shared;
@ -65,13 +68,54 @@ static WrapperFunctionResult addWrapper(const char *ArgData, size_t ArgSize) {
ArgData, ArgSize, [](int32_t X, int32_t Y) -> int32_t { return X + Y; });
}
TEST(WrapperFunctionUtilsTest, WrapperFunctionCallVoidNoopAndHandle) {
TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAndHandleVoid) {
EXPECT_FALSE(!!WrapperFunction<void()>::call(voidNoopWrapper));
}
TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAndHandle) {
TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAndHandleRet) {
int32_t Result;
EXPECT_FALSE(!!WrapperFunction<int32_t(int32_t, int32_t)>::call(
addWrapper, Result, 1, 2));
EXPECT_EQ(Result, (int32_t)3);
}
static void voidNoopAsync(unique_function<void(SPSEmpty)> SendResult) {
SendResult(SPSEmpty());
}
static WrapperFunctionResult voidNoopAsyncWrapper(const char *ArgData,
size_t ArgSize) {
std::promise<WrapperFunctionResult> RP;
auto RF = RP.get_future();
WrapperFunction<void()>::handleAsync(
ArgData, ArgSize, voidNoopAsync,
[&](WrapperFunctionResult R) { RP.set_value(std::move(R)); });
return RF.get();
}
static WrapperFunctionResult addAsyncWrapper(const char *ArgData,
size_t ArgSize) {
std::promise<WrapperFunctionResult> RP;
auto RF = RP.get_future();
WrapperFunction<int32_t(int32_t, int32_t)>::handleAsync(
ArgData, ArgSize,
[](unique_function<void(int32_t)> SendResult, int32_t X, int32_t Y) {
SendResult(X + Y);
},
[&](WrapperFunctionResult R) { RP.set_value(std::move(R)); });
return RF.get();
}
TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAndHandleAsyncVoid) {
EXPECT_FALSE(!!WrapperFunction<void()>::call(voidNoopAsyncWrapper));
}
TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAndHandleAsyncRet) {
int32_t Result;
EXPECT_FALSE(!!WrapperFunction<int32_t(int32_t, int32_t)>::call(
addAsyncWrapper, Result, 1, 2));
EXPECT_EQ(Result, (int32_t)3);
}