diff --git a/include/llvm/ExecutionEngine/Orc/OrcRPCTargetProcessControl.h b/include/llvm/ExecutionEngine/Orc/OrcRPCTargetProcessControl.h index a8aa4279911..6ccc90e60b8 100644 --- a/include/llvm/ExecutionEngine/Orc/OrcRPCTargetProcessControl.h +++ b/include/llvm/ExecutionEngine/Orc/OrcRPCTargetProcessControl.h @@ -354,7 +354,7 @@ public: return Result; } - Expected + Expected runWrapper(JITTargetAddress WrapperFnAddr, ArrayRef ArgBuffer) override { DEBUG_WITH_TYPE("orc", { @@ -364,7 +364,6 @@ public: }); auto Result = EP.template callB(WrapperFnAddr, ArgBuffer); - // dbgs() << "Returned from runWrapper...\n"; return Result; } diff --git a/include/llvm/ExecutionEngine/Orc/Shared/SimplePackedSerialization.h b/include/llvm/ExecutionEngine/Orc/Shared/SimplePackedSerialization.h new file mode 100644 index 00000000000..0db40e2969a --- /dev/null +++ b/include/llvm/ExecutionEngine/Orc/Shared/SimplePackedSerialization.h @@ -0,0 +1,565 @@ +//===---- SimplePackedSerialization.h - simple serialization ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// The behavior of the utilities in this header must be synchronized with the +// behavior of the utilities in +// compiler-rt/lib/orc/simple_packed_serialization.h. +// +// The Simple Packed Serialization (SPS) utilities are used to generate +// argument and return buffers for wrapper functions using the following +// serialization scheme: +// +// Primitives (signed types should be two's complement): +// bool, char, int8_t, uint8_t -- 8-bit (0=false, 1=true) +// int16_t, uint16_t -- 16-bit little endian +// int32_t, uint32_t -- 32-bit little endian +// int64_t, int64_t -- 64-bit little endian +// +// Sequence: +// Serialized as the sequence length (as a uint64_t) followed by the +// serialization of each of the elements without padding. +// +// Tuple: +// Serialized as each of the element types from T1 to TN without padding. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_SIMPLEPACKEDSERIALIZATION_H +#define LLVM_EXECUTIONENGINE_ORC_SHARED_SIMPLEPACKEDSERIALIZATION_H + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/SwapByteOrder.h" + +#include +#include +#include +#include +#include + +namespace llvm { +namespace orc { +namespace shared { + +/// Output char buffer with overflow check. +class SPSOutputBuffer { +public: + SPSOutputBuffer(char *Buffer, size_t Remaining) + : Buffer(Buffer), Remaining(Remaining) {} + bool write(const char *Data, size_t Size) { + if (Size > Remaining) + return false; + memcpy(Buffer, Data, Size); + Buffer += Size; + Remaining -= Size; + return true; + } + +private: + char *Buffer = nullptr; + size_t Remaining = 0; +}; + +/// Input char buffer with underflow check. +class SPSInputBuffer { +public: + SPSInputBuffer() = default; + SPSInputBuffer(const char *Buffer, size_t Remaining) + : Buffer(Buffer), Remaining(Remaining) {} + bool read(char *Data, size_t Size) { + if (Size > Remaining) + return false; + memcpy(Data, Buffer, Size); + Buffer += Size; + Remaining -= Size; + return true; + } + + const char *data() const { return Buffer; } + bool skip(size_t Size) { + if (Size > Remaining) + return false; + Remaining -= Size; + return true; + } + +private: + const char *Buffer = nullptr; + size_t Remaining = 0; +}; + +/// Specialize to describe how to serialize/deserialize to/from the given +/// concrete type. +template +class SPSSerializationTraits; + +/// A utility class for serializing to a blob from a variadic list. +template class SPSArgList; + +// Empty list specialization for SPSArgList. +template <> class SPSArgList<> { +public: + static size_t size() { return 0; } + + static bool serialize(SPSOutputBuffer &OB) { return true; } + static bool deserialize(SPSInputBuffer &IB) { return true; } +}; + +// Non-empty list specialization for SPSArgList. +template +class SPSArgList { +public: + template + static size_t size(const ArgT &Arg, const ArgTs &...Args) { + return SPSSerializationTraits::size(Arg) + + SPSArgList::size(Args...); + } + + template + static bool serialize(SPSOutputBuffer &OB, const ArgT &Arg, + const ArgTs &...Args) { + return SPSSerializationTraits::serialize(OB, Arg) && + SPSArgList::serialize(OB, Args...); + } + + template + static bool deserialize(SPSInputBuffer &IB, ArgT &Arg, ArgTs &...Args) { + return SPSSerializationTraits::deserialize(IB, Arg) && + SPSArgList::deserialize(IB, Args...); + } +}; + +/// SPS serialization for integral types, bool, and char. +template +class SPSSerializationTraits< + SPSTagT, SPSTagT, + std::enable_if_t::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value>> { +public: + static size_t size(const SPSTagT &Value) { return sizeof(SPSTagT); } + + static bool serialize(SPSOutputBuffer &OB, const SPSTagT &Value) { + SPSTagT Tmp = Value; + if (sys::IsBigEndianHost) + sys::swapByteOrder(Tmp); + return OB.write(reinterpret_cast(&Tmp), sizeof(Tmp)); + } + + static bool deserialize(SPSInputBuffer &IB, SPSTagT &Value) { + SPSTagT Tmp; + if (!IB.read(reinterpret_cast(&Tmp), sizeof(Tmp))) + return false; + if (sys::IsBigEndianHost) + sys::swapByteOrder(Tmp); + Value = Tmp; + return true; + } +}; + +// Any empty placeholder suitable as a substitute for void when deserializing +class SPSEmpty {}; + +/// SPS tag type for target addresses. +/// +/// SPSTagTargetAddresses should be serialized as a uint64_t value. +class SPSTagTargetAddress; + +template <> +class SPSSerializationTraits + : public SPSSerializationTraits {}; + +/// SPS tag type for tuples. +/// +/// A blob tuple should be serialized by serializing each of the elements in +/// sequence. +template class SPSTuple { +public: + /// Convenience typedef of the corresponding arg list. + typedef SPSArgList AsArgList; +}; + +/// SPS tag type for sequences. +/// +/// SPSSequences should be serialized as a uint64_t sequence length, +/// followed by the serialization of each of the elements. +template class SPSSequence; + +/// SPS tag type for strings, which are equivalent to sequences of chars. +using SPSString = SPSSequence; + +/// SPS tag type for target addresseses. +class SPSTargetAddress {}; + +template <> +class SPSSerializationTraits + : public SPSSerializationTraits {}; + +/// SPS tag type for maps. +/// +/// SPS maps are just sequences of (Key, Value) tuples. +template +using SPSMap = SPSSequence>; + +/// Serialization for SPSEmpty type. +template <> class SPSSerializationTraits { +public: + static size_t size(const SPSEmpty &EP) { return 0; } + static bool serialize(SPSOutputBuffer &OB, const SPSEmpty &BE) { + return true; + } + static bool deserialize(SPSInputBuffer &IB, SPSEmpty &BE) { return true; } +}; + +/// Specialize this to implement 'trivial' sequence serialization for +/// a concrete sequence type. +/// +/// Trivial sequence serialization uses the sequence's 'size' member to get the +/// length of the sequence, and uses a range-based for loop to iterate over the +/// elements. +/// +/// Specializing this template class means that you do not need to provide a +/// specialization of SPSSerializationTraits for your type. +template +class TrivialSPSSequenceSerialization { +public: + static constexpr bool available = false; +}; + +/// Specialize this to implement 'trivial' sequence deserialization for +/// a concrete sequence type. +/// +/// Trivial deserialization calls a static 'reserve(SequenceT&)' method on your +/// specialization (you must implement this) to reserve space, and then calls +/// a static 'append(SequenceT&, ElementT&) method to append each of the +/// deserialized elements. +/// +/// Specializing this template class means that you do not need to provide a +/// specialization of SPSSerializationTraits for your type. +template +class TrivialSPSSequenceDeserialization { +public: + static constexpr bool available = false; +}; + +/// Trivial std::string -> SPSSequence serialization. +template <> class TrivialSPSSequenceSerialization { +public: + static constexpr bool available = true; +}; + +/// Trivial SPSSequence -> std::string deserialization. +template <> class TrivialSPSSequenceDeserialization { +public: + static constexpr bool available = true; + + using element_type = char; + + static void reserve(std::string &S, uint64_t Size) { S.reserve(Size); } + static bool append(std::string &S, char C) { + S.push_back(C); + return true; + } +}; + +/// Trivial std::vector -> SPSSequence serialization. +template +class TrivialSPSSequenceSerialization> { +public: + static constexpr bool available = true; +}; + +/// Trivial SPSSequence -> std::vector deserialization. +template +class TrivialSPSSequenceDeserialization> { +public: + static constexpr bool available = true; + + using element_type = typename std::vector::value_type; + + static void reserve(std::vector &V, uint64_t Size) { V.reserve(Size); } + static bool append(std::vector &V, T E) { + V.push_back(std::move(E)); + return true; + } +}; + +/// 'Trivial' sequence serialization: Sequence is serialized as a uint64_t size +/// followed by a for-earch loop over the elements of the sequence to serialize +/// each of them. +template +class SPSSerializationTraits, SequenceT, + std::enable_if_t::available>> { +public: + static size_t size(const SequenceT &S) { + size_t Size = SPSArgList::size(static_cast(S.size())); + for (const auto &E : S) + Size += SPSArgList::size(E); + return Size; + } + + static bool serialize(SPSOutputBuffer &OB, const SequenceT &S) { + if (!SPSArgList::serialize(OB, static_cast(S.size()))) + return false; + for (const auto &E : S) + if (!SPSArgList::serialize(OB, E)) + return false; + return true; + } + + static bool deserialize(SPSInputBuffer &IB, SequenceT &S) { + using TBSD = TrivialSPSSequenceDeserialization; + uint64_t Size; + if (!SPSArgList::deserialize(IB, Size)) + return false; + TBSD::reserve(S, Size); + for (size_t I = 0; I != Size; ++I) { + typename TBSD::element_type E; + if (!SPSArgList::deserialize(IB, E)) + return false; + if (!TBSD::append(S, std::move(E))) + return false; + } + return true; + } +}; + +/// SPSTuple serialization for std::pair. +template +class SPSSerializationTraits, std::pair> { +public: + static size_t size(const std::pair &P) { + return SPSArgList::size(P.first) + + SPSArgList::size(P.second); + } + + static bool serialize(SPSOutputBuffer &OB, const std::pair &P) { + return SPSArgList::serialize(OB, P.first) && + SPSArgList::serialize(OB, P.second); + } + + static bool deserialize(SPSInputBuffer &IB, std::pair &P) { + return SPSArgList::deserialize(IB, P.first) && + SPSArgList::deserialize(IB, P.second); + } +}; + +/// Serialization for StringRefs. +/// +/// Serialization is as for regular strings. Deserialization points directly +/// into the blob. +template <> class SPSSerializationTraits { +public: + static size_t size(const StringRef &S) { + return SPSArgList::size(static_cast(S.size())) + + S.size(); + } + + static bool serialize(SPSOutputBuffer &OB, StringRef S) { + if (!SPSArgList::serialize(OB, static_cast(S.size()))) + return false; + return OB.write(S.data(), S.size()); + } + + static bool deserialize(SPSInputBuffer &IB, StringRef &S) { + const char *Data = nullptr; + uint64_t Size; + if (!SPSArgList::deserialize(IB, Size)) + return false; + Data = IB.data(); + if (!IB.skip(Size)) + return false; + S = {Data, Size}; + return true; + } +}; + +/// SPS tag type for errors. +class SPSError; + +/// SPS tag type for expecteds, which are either a T or a string representing +/// an error. +template class SPSExpected; + +namespace detail { + +/// Helper type for serializing Errors. +/// +/// llvm::Errors are move-only, and not inspectable except by consuming them. +/// This makes them unsuitable for direct serialization via +/// SPSSerializationTraits, which needs to inspect values twice (once to +/// determine the amount of space to reserve, and then again to serialize). +/// +/// The SPSSerializableError type is a helper that can be +/// constructed from an llvm::Error, but inspected more than once. +struct SPSSerializableError { + bool HasError = false; + std::string ErrMsg; +}; + +/// Helper type for serializing Expecteds. +/// +/// See SPSSerializableError for more details. +/// +// FIXME: Use std::variant for storage once we have c++17. +template struct SPSSerializableExpected { + bool HasValue = false; + T Value{}; + std::string ErrMsg; +}; + +inline SPSSerializableError toSPSSerializable(Error Err) { + if (Err) + return {true, toString(std::move(Err))}; + return {false, {}}; +} + +inline Error fromSPSSerializable(SPSSerializableError BSE) { + if (BSE.HasError) + return make_error(BSE.ErrMsg, inconvertibleErrorCode()); + return Error::success(); +} + +template +SPSSerializableExpected toSPSSerializable(Expected E) { + if (E) + return {true, std::move(*E), {}}; + else + return {false, {}, toString(E.takeError())}; +} + +template +Expected fromSPSSerializable(SPSSerializableExpected BSE) { + if (BSE.HasValue) + return std::move(BSE.Value); + else + return make_error(BSE.ErrMsg, inconvertibleErrorCode()); +} + +} // end namespace detail + +/// Serialize to a SPSError from a detail::SPSSerializableError. +template <> +class SPSSerializationTraits { +public: + static size_t size(const detail::SPSSerializableError &BSE) { + size_t Size = SPSArgList::size(BSE.HasError); + if (BSE.HasError) + Size += SPSArgList::size(BSE.ErrMsg); + return Size; + } + + static bool serialize(SPSOutputBuffer &OB, + const detail::SPSSerializableError &BSE) { + if (!SPSArgList::serialize(OB, BSE.HasError)) + return false; + if (BSE.HasError) + if (!SPSArgList::serialize(OB, BSE.ErrMsg)) + return false; + return true; + } + + static bool deserialize(SPSInputBuffer &IB, + detail::SPSSerializableError &BSE) { + if (!SPSArgList::deserialize(IB, BSE.HasError)) + return false; + + if (!BSE.HasError) + return true; + + return SPSArgList::deserialize(IB, BSE.ErrMsg); + } +}; + +/// Serialize to a SPSExpected from a +/// detail::SPSSerializableExpected. +template +class SPSSerializationTraits, + detail::SPSSerializableExpected> { +public: + static size_t size(const detail::SPSSerializableExpected &BSE) { + size_t Size = SPSArgList::size(BSE.HasValue); + if (BSE.HasValue) + Size += SPSArgList::size(BSE.Value); + else + Size += SPSArgList::size(BSE.ErrMsg); + return Size; + } + + static bool serialize(SPSOutputBuffer &OB, + const detail::SPSSerializableExpected &BSE) { + if (!SPSArgList::serialize(OB, BSE.HasValue)) + return false; + + if (BSE.HasValue) + return SPSArgList::serialize(OB, BSE.Value); + + return SPSArgList::serialize(OB, BSE.ErrMsg); + } + + static bool deserialize(SPSInputBuffer &IB, + detail::SPSSerializableExpected &BSE) { + if (!SPSArgList::deserialize(IB, BSE.HasValue)) + return false; + + if (BSE.HasValue) + return SPSArgList::deserialize(IB, BSE.Value); + + return SPSArgList::deserialize(IB, BSE.ErrMsg); + } +}; + +/// Serialize to a SPSExpected from a detail::SPSSerializableError. +template +class SPSSerializationTraits, + detail::SPSSerializableError> { +public: + static size_t size(const detail::SPSSerializableError &BSE) { + assert(BSE.HasError && "Cannot serialize expected from a success value"); + return SPSArgList::size(false) + + SPSArgList::size(BSE.ErrMsg); + } + + static bool serialize(SPSOutputBuffer &OB, + const detail::SPSSerializableError &BSE) { + assert(BSE.HasError && "Cannot serialize expected from a success value"); + if (!SPSArgList::serialize(OB, false)) + return false; + return SPSArgList::serialize(OB, BSE.ErrMsg); + } +}; + +/// Serialize to a SPSExpected from a T. +template +class SPSSerializationTraits, T> { +public: + static size_t size(const T &Value) { + return SPSArgList::size(true) + SPSArgList::size(Value); + } + + static bool serialize(SPSOutputBuffer &OB, const T &Value) { + if (!SPSArgList::serialize(OB, true)) + return false; + return SPSArgList::serialize(Value); + } +}; + +} // end namespace shared +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_SHARED_SIMPLEPACKEDSERIALIZATION_H diff --git a/include/llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h b/include/llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h index d01b3ef21f8..a44bcd4c806 100644 --- a/include/llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h +++ b/include/llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h @@ -61,103 +61,6 @@ using DylibHandle = JITTargetAddress; using LookupResult = std::vector; -/// Either a uint8_t array or a uint8_t*. -union CWrapperFunctionResultData { - uint8_t Value[8]; - uint8_t *ValuePtr; -}; - -/// C ABI compatible wrapper function result. -/// -/// This can be safely returned from extern "C" functions, but should be used -/// to construct a WrapperFunctionResult for safety. -struct CWrapperFunctionResult { - uint64_t Size; - CWrapperFunctionResultData Data; - void (*Destroy)(CWrapperFunctionResultData Data, uint64_t Size); -}; - -/// C++ wrapper function result: Same as CWrapperFunctionResult but -/// auto-releases memory. -class WrapperFunctionResult { -public: - /// Create a default WrapperFunctionResult. - WrapperFunctionResult() { zeroInit(R); } - - /// Create a WrapperFunctionResult from a CWrapperFunctionResult. This - /// instance takes ownership of the result object and will automatically - /// call the Destroy member upon destruction. - WrapperFunctionResult(CWrapperFunctionResult R) : R(R) {} - - WrapperFunctionResult(const WrapperFunctionResult &) = delete; - WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete; - - WrapperFunctionResult(WrapperFunctionResult &&Other) { - zeroInit(R); - std::swap(R, Other.R); - } - - WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) { - CWrapperFunctionResult Tmp; - zeroInit(Tmp); - std::swap(Tmp, Other.R); - std::swap(R, Tmp); - return *this; - } - - ~WrapperFunctionResult() { - if (R.Destroy) - R.Destroy(R.Data, R.Size); - } - - /// Relinquish ownership of and return the CWrapperFunctionResult. - CWrapperFunctionResult release() { - CWrapperFunctionResult Tmp; - zeroInit(Tmp); - std::swap(R, Tmp); - return Tmp; - } - - /// Get an ArrayRef covering the data in the result. - ArrayRef getData() const { - if (R.Size <= 8) - return ArrayRef(R.Data.Value, R.Size); - return ArrayRef(R.Data.ValuePtr, R.Size); - } - - /// Create a WrapperFunctionResult from the given integer, provided its - /// size is no greater than 64 bits. - template ::value && - sizeof(T) <= sizeof(uint64_t)>> - static WrapperFunctionResult from(T Value) { - CWrapperFunctionResult R; - R.Size = sizeof(T); - memcpy(&R.Data.Value, Value, R.Size); - R.Destroy = nullptr; - return R; - } - - /// Create a WrapperFunctionResult from the given string. - static WrapperFunctionResult from(StringRef S); - - /// Always free Data.ValuePtr by calling free on it. - static void destroyWithFree(CWrapperFunctionResultData Data, uint64_t Size); - - /// Always free Data.ValuePtr by calling delete[] on it. - static void destroyWithDeleteArray(CWrapperFunctionResultData Data, - uint64_t Size); - -private: - static void zeroInit(CWrapperFunctionResult &R) { - R.Size = 0; - R.Data.ValuePtr = nullptr; - R.Destroy = nullptr; - } - - CWrapperFunctionResult R; -}; - } // end namespace tpctypes } // end namespace orc } // end namespace llvm diff --git a/include/llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h b/include/llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h new file mode 100644 index 00000000000..d975728b108 --- /dev/null +++ b/include/llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h @@ -0,0 +1,426 @@ +//===- WrapperFunctionUtils.h - Utilities for wrapper functions -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// A buffer for serialized results. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_WRAPPERFUNCTIONUTILS_H +#define LLVM_EXECUTIONENGINE_ORC_WRAPPERFUNCTIONUTILS_H + +#include "llvm/ExecutionEngine/Orc/Shared/SimplePackedSerialization.h" +#include "llvm/Support/Error.h" + +#include + +namespace llvm { +namespace orc { +namespace shared { + +namespace detail { + +// DO NOT USE DIRECTLY. +// Must be kept in-sync with compiler-rt/lib/orc/c-api.h. +union CWrapperFunctionResultDataUnion { + const char *ValuePtr; + char Value[sizeof(ValuePtr)]; +}; + +// DO NOT USE DIRECTLY. +// Must be kept in-sync with compiler-rt/lib/orc/c-api.h. +typedef struct { + CWrapperFunctionResultDataUnion Data; + size_t Size; +} CWrapperFunctionResult; + +} // end namespace detail + +/// C++ wrapper function result: Same as CWrapperFunctionResult but +/// auto-releases memory. +class WrapperFunctionResult { +public: + /// Create a default WrapperFunctionResult. + WrapperFunctionResult() { init(R); } + + /// Create a WrapperFunctionResult by taking ownership of a + /// detail::CWrapperFunctionResult. + /// + /// Warning: This should only be used by clients writing wrapper-function + /// caller utilities (like TargetProcessControl). + WrapperFunctionResult(detail::CWrapperFunctionResult R) : R(R) { + // Reset R. + init(R); + } + + WrapperFunctionResult(const WrapperFunctionResult &) = delete; + WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete; + + WrapperFunctionResult(WrapperFunctionResult &&Other) { + init(R); + std::swap(R, Other.R); + } + + WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) { + WrapperFunctionResult Tmp(std::move(Other)); + std::swap(R, Tmp.R); + return *this; + } + + ~WrapperFunctionResult() { + if ((R.Size > sizeof(R.Data.Value)) || + (R.Size == 0 && R.Data.ValuePtr != nullptr)) + free((void *)R.Data.ValuePtr); + } + + /// Release ownership of the contained detail::CWrapperFunctionResult. + /// Warning: Do not use -- this method will be removed in the future. It only + /// exists to temporarily support some code that will eventually be moved to + /// the ORC runtime. + detail::CWrapperFunctionResult release() { + detail::CWrapperFunctionResult Tmp; + init(Tmp); + std::swap(R, Tmp); + return Tmp; + } + + /// Get a pointer to the data contained in this instance. + const char *data() const { + assert((R.Size != 0 || R.Data.ValuePtr == nullptr) && + "Cannot get data for out-of-band error value"); + return R.Size > sizeof(R.Data.Value) ? R.Data.ValuePtr : R.Data.Value; + } + + /// Returns the size of the data contained in this instance. + size_t size() const { + assert((R.Size != 0 || R.Data.ValuePtr == nullptr) && + "Cannot get data for out-of-band error value"); + return R.Size; + } + + /// Returns true if this value is equivalent to a default-constructed + /// WrapperFunctionResult. + bool empty() const { return R.Size == 0 && R.Data.ValuePtr == nullptr; } + + /// Create a WrapperFunctionResult with the given size and return a pointer + /// to the underlying memory. + static char *allocate(WrapperFunctionResult &WFR, size_t Size) { + // Reset. + WFR = WrapperFunctionResult(); + WFR.R.Size = Size; + char *DataPtr; + if (WFR.R.Size > sizeof(WFR.R.Data.Value)) { + DataPtr = (char *)malloc(WFR.R.Size); + WFR.R.Data.ValuePtr = DataPtr; + } else + DataPtr = WFR.R.Data.Value; + return DataPtr; + } + + /// Copy from the given char range. + static WrapperFunctionResult copyFrom(const char *Source, size_t Size) { + WrapperFunctionResult WFR; + char *DataPtr = allocate(WFR, Size); + memcpy(DataPtr, Source, Size); + return WFR; + } + + /// Copy from the given null-terminated string (includes the null-terminator). + static WrapperFunctionResult copyFrom(const char *Source) { + return copyFrom(Source, strlen(Source) + 1); + } + + /// Copy from the given std::string (includes the null terminator). + static WrapperFunctionResult copyFrom(const std::string &Source) { + return copyFrom(Source.c_str()); + } + + /// Create an out-of-band error by copying the given string. + static WrapperFunctionResult createOutOfBandError(const char *Msg) { + // Reset. + WrapperFunctionResult WFR; + char *Tmp = (char *)malloc(strlen(Msg) + 1); + strcpy(Tmp, Msg); + WFR.R.Data.ValuePtr = Tmp; + return WFR; + } + + /// Create an out-of-band error by copying the given string. + static WrapperFunctionResult createOutOfBandError(const std::string &Msg) { + return createOutOfBandError(Msg.c_str()); + } + + /// If this value is an out-of-band error then this returns the error message, + /// otherwise returns nullptr. + const char *getOutOfBandError() const { + return R.Size == 0 ? R.Data.ValuePtr : nullptr; + } + +private: + static void init(detail::CWrapperFunctionResult &R) { + R.Data.ValuePtr = nullptr; + R.Size = 0; + } + + detail::CWrapperFunctionResult R; +}; + +namespace detail { + +template +Expected +serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) { + WrapperFunctionResult Result; + char *DataPtr = + WrapperFunctionResult::allocate(Result, SPSArgListT::size(Args...)); + SPSOutputBuffer OB(DataPtr, Result.size()); + if (!SPSArgListT::serialize(OB, Args...)) + return make_error( + "Error serializing arguments to blob in call", + inconvertibleErrorCode()); + return Result; +} + +template class WrapperFunctionHandlerCaller { +public: + template + static decltype(auto) call(HandlerT &&H, ArgTupleT &Args, + std::index_sequence) { + return std::forward(H)(std::get(Args)...); + } +}; + +template <> class WrapperFunctionHandlerCaller { +public: + template + static SPSEmpty call(HandlerT &&H, ArgTupleT &Args, + std::index_sequence) { + std::forward(H)(std::get(Args)...); + return SPSEmpty(); + } +}; + +template class ResultSerializer, typename... SPSTagTs> +class WrapperFunctionHandlerHelper + : public WrapperFunctionHandlerHelper< + decltype(&std::remove_reference_t::operator()), + ResultSerializer, SPSTagTs...> {}; + +template class ResultSerializer, typename... SPSTagTs> +class WrapperFunctionHandlerHelper { +public: + using ArgTuple = std::tuple...>; + using ArgIndices = std::make_index_sequence::value>; + + template + static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData, + size_t ArgSize) { + ArgTuple Args; + if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) + return WrapperFunctionResult::createOutOfBandError( + "Could not deserialize arguments for wrapper function call"); + + auto HandlerResult = WrapperFunctionHandlerCaller::call( + std::forward(H), Args, ArgIndices{}); + + if (auto Result = ResultSerializer::serialize( + std::move(HandlerResult))) + return std::move(*Result); + else + return WrapperFunctionResult::createOutOfBandError( + toString(Result.takeError())); + } + +private: + template + static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args, + std::index_sequence) { + SPSInputBuffer IB(ArgData, ArgSize); + return SPSArgList::deserialize(IB, std::get(Args)...); + } +}; + +// Map function references to function types. +template class ResultSerializer, typename... SPSTagTs> +class WrapperFunctionHandlerHelper + : public WrapperFunctionHandlerHelper {}; + +// Map non-const member function types to function types. +template class ResultSerializer, typename... SPSTagTs> +class WrapperFunctionHandlerHelper + : public WrapperFunctionHandlerHelper {}; + +// Map const member function types to function types. +template class ResultSerializer, typename... SPSTagTs> +class WrapperFunctionHandlerHelper + : public WrapperFunctionHandlerHelper {}; + +template class ResultSerializer { +public: + static Expected serialize(RetT Result) { + return serializeViaSPSToWrapperFunctionResult>( + Result); + } +}; + +template class ResultSerializer { +public: + static Expected serialize(Error Err) { + return serializeViaSPSToWrapperFunctionResult>( + toSPSSerializable(std::move(Err))); + } +}; + +template +class ResultSerializer> { +public: + static Expected serialize(Expected E) { + return serializeViaSPSToWrapperFunctionResult>( + toSPSSerializable(std::move(E))); + } +}; + +template class ResultDeserializer { +public: + static void makeSafe(RetT &Result) {} + + static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) { + SPSInputBuffer IB(ArgData, ArgSize); + if (!SPSArgList::deserialize(IB, Result)) + return make_error( + "Error deserializing return value from blob in call", + inconvertibleErrorCode()); + return Error::success(); + } +}; + +template <> class ResultDeserializer { +public: + static void makeSafe(Error &Err) { cantFail(std::move(Err)); } + + static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) { + SPSInputBuffer IB(ArgData, ArgSize); + SPSSerializableError BSE; + if (!SPSArgList::deserialize(IB, BSE)) + return make_error( + "Error deserializing return value from blob in call", + inconvertibleErrorCode()); + Err = fromSPSSerializable(std::move(BSE)); + return Error::success(); + } +}; + +template +class ResultDeserializer, Expected> { +public: + static void makeSafe(Expected &E) { cantFail(E.takeError()); } + + static Error deserialize(Expected &E, const char *ArgData, + size_t ArgSize) { + SPSInputBuffer IB(ArgData, ArgSize); + SPSSerializableExpected BSE; + if (!SPSArgList>::deserialize(IB, BSE)) + return make_error( + "Error deserializing return value from blob in call", + inconvertibleErrorCode()); + E = fromSPSSerializable(std::move(BSE)); + return Error::success(); + } +}; + +} // end namespace detail + +template class WrapperFunction; + +template +class WrapperFunction { +private: + template + using ResultSerializer = detail::ResultSerializer; + +public: + /// Call a wrapper function. Callere should be callable as + /// WrapperFunctionResult Fn(const char *ArgData, size_t ArgSize); + template + static Error call(const CallerFn &Caller, RetT &Result, + const ArgTs &...Args) { + + // RetT might be an Error or Expected value. Set the checked flag now: + // we don't want the user to have to check the unused result if this + // operation fails. + detail::ResultDeserializer::makeSafe(Result); + + auto ArgBuffer = + detail::serializeViaSPSToWrapperFunctionResult>( + Args...); + if (!ArgBuffer) + return ArgBuffer.takeError(); + + WrapperFunctionResult ResultBuffer = + Caller(ArgBuffer->data(), ArgBuffer->size()); + if (auto ErrMsg = ResultBuffer.getOutOfBandError()) + return make_error(ErrMsg, inconvertibleErrorCode()); + + return detail::ResultDeserializer::deserialize( + Result, ResultBuffer.data(), ResultBuffer.size()); + } + + /// Handle a call to a wrapper function. + template + static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize, + HandlerT &&Handler) { + using WFHH = + detail::WrapperFunctionHandlerHelper; + return WFHH::apply(std::forward(Handler), ArgData, ArgSize); + } + +private: + template static const T &makeSerializable(const T &Value) { + return Value; + } + + static detail::SPSSerializableError makeSerializable(Error Err) { + return detail::toSPSSerializable(std::move(Err)); + } + + template + static detail::SPSSerializableExpected makeSerializable(Expected E) { + return detail::toSPSSerializable(std::move(E)); + } +}; + +template +class WrapperFunction + : private WrapperFunction { +public: + template + static Error call(const void *FnTag, const ArgTs &...Args) { + SPSEmpty BE; + return WrapperFunction::call(FnTag, BE, Args...); + } + + using WrapperFunction::handle; +}; + +} // end namespace shared +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_WRAPPERFUNCTIONUTILS_H diff --git a/include/llvm/ExecutionEngine/Orc/TargetProcess/JITLoaderGDB.h b/include/llvm/ExecutionEngine/Orc/TargetProcess/JITLoaderGDB.h index 4188cea2b71..3fad98b5f17 100644 --- a/include/llvm/ExecutionEngine/Orc/TargetProcess/JITLoaderGDB.h +++ b/include/llvm/ExecutionEngine/Orc/TargetProcess/JITLoaderGDB.h @@ -13,10 +13,10 @@ #ifndef LLVM_EXECUTIONENGINE_ORC_TARGETPROCESS_JITLOADERGDB_H #define LLVM_EXECUTIONENGINE_ORC_TARGETPROCESS_JITLOADERGDB_H -#include "llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h" +#include "llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h" #include -extern "C" llvm::orc::tpctypes::CWrapperFunctionResult -llvm_orc_registerJITLoaderGDBWrapper(uint8_t *Data, uint64_t Size); +extern "C" llvm::orc::shared::detail::CWrapperFunctionResult +llvm_orc_registerJITLoaderGDBWrapper(const char *Data, uint64_t Size); #endif // LLVM_EXECUTIONENGINE_ORC_TARGETPROCESS_JITLOADERGDB_H diff --git a/include/llvm/ExecutionEngine/Orc/TargetProcess/OrcRPCTPCServer.h b/include/llvm/ExecutionEngine/Orc/TargetProcess/OrcRPCTPCServer.h index 639f30dd9b1..458947cc4d4 100644 --- a/include/llvm/ExecutionEngine/Orc/TargetProcess/OrcRPCTPCServer.h +++ b/include/llvm/ExecutionEngine/Orc/TargetProcess/OrcRPCTPCServer.h @@ -17,6 +17,7 @@ #include "llvm/ExecutionEngine/Orc/Shared/RPCUtils.h" #include "llvm/ExecutionEngine/Orc/Shared/RawByteChannel.h" #include "llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h" +#include "llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h" #include "llvm/ExecutionEngine/Orc/TargetProcess/TargetExecutionUtils.h" #include "llvm/Support/DynamicLibrary.h" #include "llvm/Support/FormatVariadic.h" @@ -135,7 +136,7 @@ public: static const char *getName() { return "ReleaseOrFinalizeMemRequestElement"; } }; -template <> class SerializationTypeName { +template <> class SerializationTypeName { public: static const char *getName() { return "WrapperFunctionResult"; } }; @@ -234,40 +235,25 @@ public: template class SerializationTraits< - ChannelT, tpctypes::WrapperFunctionResult, tpctypes::WrapperFunctionResult, + ChannelT, shared::WrapperFunctionResult, shared::WrapperFunctionResult, std::enable_if_t::value>> { public: - static Error serialize(ChannelT &C, - const tpctypes::WrapperFunctionResult &E) { - auto Data = E.getData(); - if (auto Err = serializeSeq(C, static_cast(Data.size()))) + static Error serialize(ChannelT &C, const shared::WrapperFunctionResult &E) { + if (auto Err = serializeSeq(C, static_cast(E.size()))) return Err; - if (Data.size() == 0) + if (E.size() == 0) return Error::success(); - return C.appendBytes(reinterpret_cast(Data.data()), - Data.size()); + return C.appendBytes(E.data(), E.size()); } - static Error deserialize(ChannelT &C, tpctypes::WrapperFunctionResult &E) { - tpctypes::CWrapperFunctionResult R; + static Error deserialize(ChannelT &C, shared::WrapperFunctionResult &E) { - R.Size = 0; - R.Data.ValuePtr = nullptr; - R.Destroy = nullptr; - - if (auto Err = deserializeSeq(C, R.Size)) + uint64_t Size; + if (auto Err = deserializeSeq(C, Size)) return Err; - if (R.Size == 0) - return Error::success(); - R.Data.ValuePtr = new uint8_t[R.Size]; - if (auto Err = - C.readBytes(reinterpret_cast(R.Data.ValuePtr), R.Size)) { - R.Destroy = tpctypes::WrapperFunctionResult::destroyWithDeleteArray; - return Err; - } - E = tpctypes::WrapperFunctionResult(R); - return Error::success(); + char *DataPtr = shared::WrapperFunctionResult::allocate(E, Size); + return C.readBytes(DataPtr, E.size()); } }; @@ -371,7 +357,7 @@ public: class RunWrapper : public shared::RPCFunction)> { public: static const char *getName() { return "RunWrapper"; } @@ -594,13 +580,14 @@ private: ProgramNameOverride); } - tpctypes::WrapperFunctionResult + shared::WrapperFunctionResult runWrapper(JITTargetAddress WrapperFnAddr, const std::vector &ArgBuffer) { - using WrapperFnTy = tpctypes::CWrapperFunctionResult (*)( - const uint8_t *Data, uint64_t Size); + using WrapperFnTy = shared::detail::CWrapperFunctionResult (*)( + const char *Data, uint64_t Size); auto *WrapperFn = jitTargetAddressToFunction(WrapperFnAddr); - return WrapperFn(ArgBuffer.data(), ArgBuffer.size()); + return WrapperFn(reinterpret_cast(ArgBuffer.data()), + ArgBuffer.size()); } void closeConnection() { Finished = true; } diff --git a/include/llvm/ExecutionEngine/Orc/TargetProcess/RegisterEHFrames.h b/include/llvm/ExecutionEngine/Orc/TargetProcess/RegisterEHFrames.h index 14d30c4aa34..3b4aabb9037 100644 --- a/include/llvm/ExecutionEngine/Orc/TargetProcess/RegisterEHFrames.h +++ b/include/llvm/ExecutionEngine/Orc/TargetProcess/RegisterEHFrames.h @@ -9,12 +9,14 @@ // Support for dynamically registering and deregistering eh-frame sections // in-process via libunwind. // +// FIXME: The functionality in this file should be moved to the ORC runtime. +// //===----------------------------------------------------------------------===// #ifndef LLVM_EXECUTIONENGINE_ORC_TARGETPROCESS_REGISTEREHFRAMES_H #define LLVM_EXECUTIONENGINE_ORC_TARGETPROCESS_REGISTEREHFRAMES_H -#include "llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h" +#include "llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h" #include "llvm/Support/Error.h" namespace llvm { @@ -31,10 +33,10 @@ Error deregisterEHFrameSection(const void *EHFrameSectionAddr, } // end namespace orc } // end namespace llvm -extern "C" llvm::orc::tpctypes::CWrapperFunctionResult -llvm_orc_registerEHFrameSectionWrapper(uint8_t *Data, uint64_t Size); +extern "C" llvm::orc::shared::detail::CWrapperFunctionResult +llvm_orc_registerEHFrameSectionWrapper(const char *Data, uint64_t Size); -extern "C" llvm::orc::tpctypes::CWrapperFunctionResult -llvm_orc_deregisterEHFrameSectionWrapper(uint8_t *Data, uint64_t Size); +extern "C" llvm::orc::shared::detail::CWrapperFunctionResult +llvm_orc_deregisterEHFrameSectionWrapper(const char *Data, uint64_t Size); #endif // LLVM_EXECUTIONENGINE_ORC_TARGETPROCESS_REGISTEREHFRAMES_H diff --git a/include/llvm/ExecutionEngine/Orc/TargetProcessControl.h b/include/llvm/ExecutionEngine/Orc/TargetProcessControl.h index b60b1ca6e37..774ec278045 100644 --- a/include/llvm/ExecutionEngine/Orc/TargetProcessControl.h +++ b/include/llvm/ExecutionEngine/Orc/TargetProcessControl.h @@ -19,6 +19,7 @@ #include "llvm/ExecutionEngine/JITLink/JITLinkMemoryManager.h" #include "llvm/ExecutionEngine/Orc/Core.h" #include "llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h" +#include "llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h" #include "llvm/Support/DynamicLibrary.h" #include "llvm/Support/MSVCErrorWorkarounds.h" @@ -137,13 +138,13 @@ public: virtual Expected runAsMain(JITTargetAddress MainFnAddr, ArrayRef Args) = 0; - /// Run a wrapper function with signature: + /// Run a wrapper function in the executor. /// /// \code{.cpp} /// CWrapperFunctionResult fn(uint8_t *Data, uint64_t Size); /// \endcode{.cpp} /// - virtual Expected + virtual Expected runWrapper(JITTargetAddress WrapperFnAddr, ArrayRef ArgBuffer) = 0; /// Disconnect from the target process. @@ -185,7 +186,7 @@ public: Expected runAsMain(JITTargetAddress MainFnAddr, ArrayRef Args) override; - Expected + Expected runWrapper(JITTargetAddress WrapperFnAddr, ArrayRef ArgBuffer) override; diff --git a/lib/ExecutionEngine/Orc/Shared/CMakeLists.txt b/lib/ExecutionEngine/Orc/Shared/CMakeLists.txt index 62da0c71fb3..dddfda1a895 100644 --- a/lib/ExecutionEngine/Orc/Shared/CMakeLists.txt +++ b/lib/ExecutionEngine/Orc/Shared/CMakeLists.txt @@ -1,7 +1,6 @@ add_llvm_component_library(LLVMOrcShared OrcError.cpp RPCError.cpp - TargetProcessControlTypes.cpp ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/ExecutionEngine/Orc diff --git a/lib/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.cpp b/lib/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.cpp deleted file mode 100644 index 52d11f0741d..00000000000 --- a/lib/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.cpp +++ /dev/null @@ -1,44 +0,0 @@ -//===---------- TargetProcessControlTypes.cpp - Shared TPC types ----------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// TargetProcessControl types. -// -//===----------------------------------------------------------------------===// - -#include "llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h" - -namespace llvm { -namespace orc { -namespace tpctypes { - -WrapperFunctionResult WrapperFunctionResult::from(StringRef S) { - CWrapperFunctionResult R; - zeroInit(R); - R.Size = S.size(); - if (R.Size > sizeof(uint64_t)) { - R.Data.ValuePtr = new uint8_t[R.Size]; - memcpy(R.Data.ValuePtr, S.data(), R.Size); - R.Destroy = destroyWithDeleteArray; - } else - memcpy(R.Data.Value, S.data(), R.Size); - return R; -} - -void WrapperFunctionResult::destroyWithFree(CWrapperFunctionResultData Data, - uint64_t Size) { - free(Data.ValuePtr); -} - -void WrapperFunctionResult::destroyWithDeleteArray( - CWrapperFunctionResultData Data, uint64_t Size) { - delete[] Data.ValuePtr; -} - -} // end namespace tpctypes -} // end namespace orc -} // end namespace llvm diff --git a/lib/ExecutionEngine/Orc/TargetProcess/JITLoaderGDB.cpp b/lib/ExecutionEngine/Orc/TargetProcess/JITLoaderGDB.cpp index 30f833ec29b..55b1220a035 100644 --- a/lib/ExecutionEngine/Orc/TargetProcess/JITLoaderGDB.cpp +++ b/lib/ExecutionEngine/Orc/TargetProcess/JITLoaderGDB.cpp @@ -68,26 +68,11 @@ using namespace llvm; // Serialize rendezvous with the debugger as well as access to shared data. ManagedStatic JITDebugLock; -static std::pair readDebugObjectInfo(uint8_t *ArgData, - uint64_t ArgSize) { - BinaryStreamReader ArgReader(ArrayRef(ArgData, ArgSize), - support::endianness::big); - uint64_t Addr, Size; - cantFail(ArgReader.readInteger(Addr)); - cantFail(ArgReader.readInteger(Size)); - - return std::make_pair(jitTargetAddressToPointer(Addr), Size); -} - -extern "C" orc::tpctypes::CWrapperFunctionResult -llvm_orc_registerJITLoaderGDBWrapper(uint8_t *Data, uint64_t Size) { - if (Size != sizeof(uint64_t) + sizeof(uint64_t)) - return orc::tpctypes::WrapperFunctionResult::from( - "Invalid arguments to llvm_orc_registerJITLoaderGDBWrapper") - .release(); - +// Register debug object, return error message or null for success. +static void registerJITLoaderGDBImpl(JITTargetAddress Addr, uint64_t Size) { jit_code_entry *E = new jit_code_entry; - std::tie(E->symfile_addr, E->symfile_size) = readDebugObjectInfo(Data, Size); + E->symfile_addr = jitTargetAddressToPointer(Addr); + E->symfile_size = Size; E->prev_entry = nullptr; std::lock_guard Lock(*JITDebugLock); @@ -105,6 +90,12 @@ llvm_orc_registerJITLoaderGDBWrapper(uint8_t *Data, uint64_t Size) { // Run into the rendezvous breakpoint. __jit_debug_descriptor.action_flag = JIT_REGISTER_FN; __jit_debug_register_code(); - - return orc::tpctypes::WrapperFunctionResult().release(); +} + +extern "C" orc::shared::detail::CWrapperFunctionResult +llvm_orc_registerJITLoaderGDBWrapper(const char *Data, uint64_t Size) { + using namespace orc::shared; + return WrapperFunction::handle( + Data, Size, registerJITLoaderGDBImpl) + .release(); } diff --git a/lib/ExecutionEngine/Orc/TargetProcess/RegisterEHFrames.cpp b/lib/ExecutionEngine/Orc/TargetProcess/RegisterEHFrames.cpp index aff7296cb6e..9463a36668c 100644 --- a/lib/ExecutionEngine/Orc/TargetProcess/RegisterEHFrames.cpp +++ b/lib/ExecutionEngine/Orc/TargetProcess/RegisterEHFrames.cpp @@ -23,7 +23,7 @@ using namespace llvm; using namespace llvm::orc; -using namespace llvm::orc::tpctypes; +using namespace llvm::orc::shared; namespace llvm { namespace orc { @@ -155,54 +155,26 @@ Error deregisterEHFrameSection(const void *EHFrameSectionAddr, } // end namespace orc } // end namespace llvm -extern "C" CWrapperFunctionResult -llvm_orc_registerEHFrameSectionWrapper(uint8_t *Data, uint64_t Size) { - if (Size != sizeof(uint64_t) + sizeof(uint64_t)) - return WrapperFunctionResult::from( - "Invalid arguments to llvm_orc_registerEHFrameSectionWrapper") - .release(); - - uint64_t EHFrameSectionAddr; - uint64_t EHFrameSectionSize; - - { - BinaryStreamReader ArgReader(ArrayRef(Data, Size), - support::endianness::big); - cantFail(ArgReader.readInteger(EHFrameSectionAddr)); - cantFail(ArgReader.readInteger(EHFrameSectionSize)); - } - - if (auto Err = registerEHFrameSection( - jitTargetAddressToPointer(EHFrameSectionAddr), - EHFrameSectionSize)) { - auto ErrMsg = toString(std::move(Err)); - return WrapperFunctionResult::from(ErrMsg).release(); - } - return WrapperFunctionResult().release(); +static Error registerEHFrameWrapper(JITTargetAddress Addr, uint64_t Size) { + return llvm::orc::registerEHFrameSection( + jitTargetAddressToPointer(Addr), Size); } -extern "C" CWrapperFunctionResult -llvm_orc_deregisterEHFrameSectionWrapper(uint8_t *Data, uint64_t Size) { - if (Size != sizeof(uint64_t) + sizeof(uint64_t)) - return WrapperFunctionResult::from( - "Invalid arguments to llvm_orc_registerEHFrameSectionWrapper") - .release(); - - uint64_t EHFrameSectionAddr; - uint64_t EHFrameSectionSize; - - { - BinaryStreamReader ArgReader(ArrayRef(Data, Size), - support::endianness::big); - cantFail(ArgReader.readInteger(EHFrameSectionAddr)); - cantFail(ArgReader.readInteger(EHFrameSectionSize)); - } - - if (auto Err = deregisterEHFrameSection( - jitTargetAddressToPointer(EHFrameSectionAddr), - EHFrameSectionSize)) { - auto ErrMsg = toString(std::move(Err)); - return WrapperFunctionResult::from(ErrMsg).release(); - } - return WrapperFunctionResult().release(); +static Error deregisterEHFrameWrapper(JITTargetAddress Addr, uint64_t Size) { + return llvm::orc::deregisterEHFrameSection( + jitTargetAddressToPointer(Addr), Size); +} + +extern "C" orc::shared::detail::CWrapperFunctionResult +llvm_orc_registerEHFrameSectionWrapper(const char *Data, uint64_t Size) { + return WrapperFunction::handle( + Data, Size, registerEHFrameWrapper) + .release(); +} + +extern "C" orc::shared::detail::CWrapperFunctionResult +llvm_orc_deregisterEHFrameSectionWrapper(const char *Data, uint64_t Size) { + return WrapperFunction::handle( + Data, Size, deregisterEHFrameWrapper) + .release(); } diff --git a/lib/ExecutionEngine/Orc/TargetProcessControl.cpp b/lib/ExecutionEngine/Orc/TargetProcessControl.cpp index 7bf874e88c2..240adb41d62 100644 --- a/lib/ExecutionEngine/Orc/TargetProcessControl.cpp +++ b/lib/ExecutionEngine/Orc/TargetProcessControl.cpp @@ -102,13 +102,14 @@ SelfTargetProcessControl::runAsMain(JITTargetAddress MainFnAddr, return orc::runAsMain(jitTargetAddressToFunction(MainFnAddr), Args); } -Expected +Expected SelfTargetProcessControl::runWrapper(JITTargetAddress WrapperFnAddr, ArrayRef ArgBuffer) { - using WrapperFnTy = - tpctypes::CWrapperFunctionResult (*)(const uint8_t *Data, uint64_t Size); + using WrapperFnTy = shared::detail::CWrapperFunctionResult (*)( + const char *Data, uint64_t Size); auto *WrapperFn = jitTargetAddressToFunction(WrapperFnAddr); - return WrapperFn(ArgBuffer.data(), ArgBuffer.size()); + return WrapperFn(reinterpret_cast(ArgBuffer.data()), + ArgBuffer.size()); } Error SelfTargetProcessControl::disconnect() { return Error::success(); } diff --git a/unittests/ExecutionEngine/Orc/CMakeLists.txt b/unittests/ExecutionEngine/Orc/CMakeLists.txt index 088e1c8c2d8..b1cfd18e5d4 100644 --- a/unittests/ExecutionEngine/Orc/CMakeLists.txt +++ b/unittests/ExecutionEngine/Orc/CMakeLists.txt @@ -26,8 +26,10 @@ add_llvm_unittest(OrcJITTests ResourceTrackerTest.cpp RPCUtilsTest.cpp RTDyldObjectLinkingLayerTest.cpp + SimplePackedSerializationTest.cpp SymbolStringPoolTest.cpp ThreadSafeModuleTest.cpp + WrapperFunctionUtilsTest.cpp ) target_link_libraries(OrcJITTests PRIVATE diff --git a/unittests/ExecutionEngine/Orc/SimplePackedSerializationTest.cpp b/unittests/ExecutionEngine/Orc/SimplePackedSerializationTest.cpp new file mode 100644 index 00000000000..5c784c16a4c --- /dev/null +++ b/unittests/ExecutionEngine/Orc/SimplePackedSerializationTest.cpp @@ -0,0 +1,160 @@ +//===-------- SimplePackedSerializationTest.cpp - Test SPS scheme ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/ExecutionEngine/Orc/Shared/SimplePackedSerialization.h" +#include "gtest/gtest.h" + +using namespace llvm; +using namespace llvm::orc::shared; + +TEST(SimplePackedSerializationTest, SPSOutputBuffer) { + constexpr unsigned NumBytes = 8; + char Buffer[NumBytes]; + char Zero = 0; + SPSOutputBuffer OB(Buffer, NumBytes); + + // Expect that we can write NumBytes of content. + for (unsigned I = 0; I != NumBytes; ++I) { + char C = I; + EXPECT_TRUE(OB.write(&C, 1)); + } + + // Expect an error when we attempt to write an extra byte. + EXPECT_FALSE(OB.write(&Zero, 1)); + + // Check that the buffer contains the expected content. + for (unsigned I = 0; I != NumBytes; ++I) + EXPECT_EQ(Buffer[I], (char)I); +} + +TEST(SimplePackedSerializationTest, SPSInputBuffer) { + char Buffer[] = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}; + SPSInputBuffer IB(Buffer, sizeof(Buffer)); + + char C; + for (unsigned I = 0; I != sizeof(Buffer); ++I) { + EXPECT_TRUE(IB.read(&C, 1)); + EXPECT_EQ(C, (char)I); + } + + EXPECT_FALSE(IB.read(&C, 1)); +} + +template +static void blobSerializationRoundTrip(const T &Value) { + using BST = SPSSerializationTraits; + + size_t Size = BST::size(Value); + auto Buffer = std::make_unique(Size); + SPSOutputBuffer OB(Buffer.get(), Size); + + EXPECT_TRUE(BST::serialize(OB, Value)); + + SPSInputBuffer IB(Buffer.get(), Size); + + T DSValue; + EXPECT_TRUE(BST::deserialize(IB, DSValue)); + + EXPECT_EQ(Value, DSValue) + << "Incorrect value after serialization/deserialization round-trip"; +} + +template static void testFixedIntegralTypeSerialization() { + blobSerializationRoundTrip(0); + blobSerializationRoundTrip(static_cast(1)); + if (std::is_signed::value) { + blobSerializationRoundTrip(static_cast(-1)); + blobSerializationRoundTrip(std::numeric_limits::min()); + } + blobSerializationRoundTrip(std::numeric_limits::max()); +} + +TEST(SimplePackedSerializationTest, BoolSerialization) { + blobSerializationRoundTrip(true); + blobSerializationRoundTrip(false); +} + +TEST(SimplePackedSerializationTest, CharSerialization) { + blobSerializationRoundTrip((char)0x00); + blobSerializationRoundTrip((char)0xAA); + blobSerializationRoundTrip((char)0xFF); +} + +TEST(SimplePackedSerializationTest, Int8Serialization) { + testFixedIntegralTypeSerialization(); +} + +TEST(SimplePackedSerializationTest, UInt8Serialization) { + testFixedIntegralTypeSerialization(); +} + +TEST(SimplePackedSerializationTest, Int16Serialization) { + testFixedIntegralTypeSerialization(); +} + +TEST(SimplePackedSerializationTest, UInt16Serialization) { + testFixedIntegralTypeSerialization(); +} + +TEST(SimplePackedSerializationTest, Int32Serialization) { + testFixedIntegralTypeSerialization(); +} + +TEST(SimplePackedSerializationTest, UInt32Serialization) { + testFixedIntegralTypeSerialization(); +} + +TEST(SimplePackedSerializationTest, Int64Serialization) { + testFixedIntegralTypeSerialization(); +} + +TEST(SimplePackedSerializationTest, UInt64Serialization) { + testFixedIntegralTypeSerialization(); +} + +TEST(SimplePackedSerializationTest, SequenceSerialization) { + std::vector V({1, 2, -47, 139}); + blobSerializationRoundTrip, std::vector>(V); +} + +TEST(SimplePackedSerializationTest, StringViewCharSequenceSerialization) { + const char *HW = "Hello, world!"; + blobSerializationRoundTrip(StringRef(HW)); +} + +TEST(SimplePackedSerializationTest, StdPairSerialization) { + std::pair P(42, "foo"); + blobSerializationRoundTrip, + std::pair>(P); +} + +TEST(SimplePackedSerializationTest, ArgListSerialization) { + using BAL = SPSArgList; + + bool Arg1 = true; + int32_t Arg2 = 42; + std::string Arg3 = "foo"; + + size_t Size = BAL::size(Arg1, Arg2, Arg3); + auto Buffer = std::make_unique(Size); + SPSOutputBuffer OB(Buffer.get(), Size); + + EXPECT_TRUE(BAL::serialize(OB, Arg1, Arg2, Arg3)); + + SPSInputBuffer IB(Buffer.get(), Size); + + bool ArgOut1; + int32_t ArgOut2; + std::string ArgOut3; + + EXPECT_TRUE(BAL::deserialize(IB, ArgOut1, ArgOut2, ArgOut3)); + + EXPECT_EQ(Arg1, ArgOut1); + EXPECT_EQ(Arg2, ArgOut2); + EXPECT_EQ(Arg3, ArgOut3); +} diff --git a/unittests/ExecutionEngine/Orc/WrapperFunctionUtilsTest.cpp b/unittests/ExecutionEngine/Orc/WrapperFunctionUtilsTest.cpp new file mode 100644 index 00000000000..d93637ea128 --- /dev/null +++ b/unittests/ExecutionEngine/Orc/WrapperFunctionUtilsTest.cpp @@ -0,0 +1,75 @@ +//===----- WrapperFunctionUtilsTest.cpp - Test Wrapper-Function utils -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h" +#include "gtest/gtest.h" + +using namespace llvm; +using namespace llvm::orc::shared; + +namespace { +constexpr const char *TestString = "test string"; +} // end anonymous namespace + +TEST(WrapperFunctionUtilsTest, DefaultWrapperFunctionResult) { + WrapperFunctionResult R; + EXPECT_TRUE(R.empty()); + EXPECT_EQ(R.size(), 0U); + EXPECT_EQ(R.getOutOfBandError(), nullptr); +} + +TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromRange) { + auto R = WrapperFunctionResult::copyFrom(TestString, strlen(TestString) + 1); + EXPECT_EQ(R.size(), strlen(TestString) + 1); + EXPECT_TRUE(strcmp(R.data(), TestString) == 0); + EXPECT_FALSE(R.empty()); + EXPECT_EQ(R.getOutOfBandError(), nullptr); +} + +TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromCString) { + auto R = WrapperFunctionResult::copyFrom(TestString); + EXPECT_EQ(R.size(), strlen(TestString) + 1); + EXPECT_TRUE(strcmp(R.data(), TestString) == 0); + EXPECT_FALSE(R.empty()); + EXPECT_EQ(R.getOutOfBandError(), nullptr); +} + +TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromStdString) { + auto R = WrapperFunctionResult::copyFrom(std::string(TestString)); + EXPECT_EQ(R.size(), strlen(TestString) + 1); + EXPECT_TRUE(strcmp(R.data(), TestString) == 0); + EXPECT_FALSE(R.empty()); + EXPECT_EQ(R.getOutOfBandError(), nullptr); +} + +TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromOutOfBandError) { + auto R = WrapperFunctionResult::createOutOfBandError(TestString); + EXPECT_FALSE(R.empty()); + EXPECT_TRUE(strcmp(R.getOutOfBandError(), TestString) == 0); +} + +static WrapperFunctionResult voidNoopWrapper(const char *ArgData, + size_t ArgSize) { + return WrapperFunction::handle(ArgData, ArgSize, voidNoop); +} + +static WrapperFunctionResult addWrapper(const char *ArgData, size_t ArgSize) { + return WrapperFunction::handle( + ArgData, ArgSize, [](int32_t X, int32_t Y) -> int32_t { return X + Y; }); +} + +TEST(WrapperFunctionUtilsTest, WrapperFunctionCallVoidNoopAndHandle) { + EXPECT_FALSE(!!WrapperFunction::call((void *)&voidNoopWrapper)); +} + +TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAndHandle) { + int32_t Result; + EXPECT_FALSE(!!WrapperFunction::call( + addWrapper, Result, 1, 2)); + EXPECT_EQ(Result, (int32_t)3); +}