1
0
mirror of https://github.com/RPCS3/llvm-mirror.git synced 2024-10-19 19:12:56 +02:00

[Orc] Add some static-assert checks to improve the error messages for RPC calls

and handler registrations.

Also add a unit test for alternate-type serialization/deserialization.

llvm-svn: 290223
This commit is contained in:
Lang Hames 2016-12-21 00:59:33 +00:00
parent 525caf6bd1
commit 7277362100
2 changed files with 257 additions and 2 deletions

View File

@ -82,6 +82,17 @@ std::mutex Function<DerivedFunc, RetT(ArgTs...)>::NameMutex;
template <typename DerivedFunc, typename RetT, typename... ArgTs> template <typename DerivedFunc, typename RetT, typename... ArgTs>
std::string Function<DerivedFunc, RetT(ArgTs...)>::Name; std::string Function<DerivedFunc, RetT(ArgTs...)>::Name;
/// Provides a typedef for a tuple containing the decayed argument types.
template <typename T>
class FunctionArgsTuple;
template <typename RetT, typename... ArgTs>
class FunctionArgsTuple<RetT(ArgTs...)> {
public:
using Type = std::tuple<typename std::decay<
typename std::remove_reference<ArgTs>::type>::type...>;
};
/// Allocates RPC function ids during autonegotiation. /// Allocates RPC function ids during autonegotiation.
/// Specializations of this class must provide four members: /// Specializations of this class must provide four members:
/// ///
@ -349,8 +360,7 @@ public:
using ReturnType = RetT; using ReturnType = RetT;
// A std::tuple wrapping the handler arguments. // A std::tuple wrapping the handler arguments.
using ArgStorage = std::tuple<typename std::decay< using ArgStorage = typename FunctionArgsTuple<RetT(ArgTs...)>::Type;
typename std::remove_reference<ArgTs>::type>::type...>;
// Call the given handler with the given arguments. // Call the given handler with the given arguments.
template <typename HandlerT> template <typename HandlerT>
@ -589,6 +599,84 @@ private:
std::vector<SequenceNumberT> FreeSequenceNumbers; std::vector<SequenceNumberT> FreeSequenceNumbers;
}; };
// Checks that predicate P holds for each corresponding pair of type arguments
// from T1 and T2 tuple.
template <template<class, class> class P, typename T1Tuple,
typename T2Tuple>
class RPCArgTypeCheckHelper;
template <template<class, class> class P>
class RPCArgTypeCheckHelper<P, std::tuple<>, std::tuple<>> {
public:
static const bool value = true;
};
template <template<class, class> class P, typename T, typename... Ts,
typename U, typename... Us>
class RPCArgTypeCheckHelper<P, std::tuple<T, Ts...>, std::tuple<U, Us...>> {
public:
static const bool value =
P<T, U>::value &&
RPCArgTypeCheckHelper<P, std::tuple<Ts...>, std::tuple<Us...>>::value;
};
template <template<class, class> class P, typename T1Sig, typename T2Sig>
class RPCArgTypeCheck {
public:
using T1Tuple = typename FunctionArgsTuple<T1Sig>::Type;
using T2Tuple = typename FunctionArgsTuple<T2Sig>::Type;
static_assert(std::tuple_size<T1Tuple>::value >= std::tuple_size<T2Tuple>::value,
"Too many arguments to RPC call");
static_assert(std::tuple_size<T1Tuple>::value <= std::tuple_size<T2Tuple>::value,
"Too few arguments to RPC call");
static const bool value = RPCArgTypeCheckHelper<P, T1Tuple, T2Tuple>::value;
};
template <typename ChannelT, typename WireT, typename ConcreteT>
class CanSerialize {
private:
using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
template <typename T>
static std::true_type
check(typename std::enable_if<
std::is_same<
decltype(T::serialize(std::declval<ChannelT&>(),
std::declval<const ConcreteT&>())),
Error>::value,
void*>::type);
template <typename>
static std::false_type check(...);
public:
static const bool value = decltype(check<S>(0))::value;
};
template <typename ChannelT, typename WireT, typename ConcreteT>
class CanDeserialize {
private:
using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
template <typename T>
static std::true_type
check(typename std::enable_if<
std::is_same<
decltype(T::deserialize(std::declval<ChannelT&>(),
std::declval<ConcreteT&>())),
Error>::value,
void*>::type);
template <typename>
static std::false_type check(...);
public:
static const bool value = decltype(check<S>(0))::value;
};
/// Contains primitive utilities for defining, calling and handling calls to /// Contains primitive utilities for defining, calling and handling calls to
/// remote procedures. ChannelT is a bidirectional stream conforming to the /// remote procedures. ChannelT is a bidirectional stream conforming to the
/// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure /// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure
@ -603,6 +691,7 @@ template <typename ImplT, typename ChannelT, typename FunctionIdT,
typename SequenceNumberT> typename SequenceNumberT>
class RPCBase { class RPCBase {
protected: protected:
class OrcRPCInvalid : public Function<OrcRPCInvalid, void()> { class OrcRPCInvalid : public Function<OrcRPCInvalid, void()> {
public: public:
static const char *getName() { return "__orc_rpc$invalid"; } static const char *getName() { return "__orc_rpc$invalid"; }
@ -619,6 +708,31 @@ protected:
static const char *getName() { return "__orc_rpc$negotiate"; } static const char *getName() { return "__orc_rpc$negotiate"; }
}; };
// Helper predicate for testing for the presence of SerializeTraits
// serializers.
template <typename WireT, typename ConcreteT>
class CanSerializeCheck : detail::CanSerialize<ChannelT, WireT, ConcreteT> {
public:
using detail::CanSerialize<ChannelT, WireT, ConcreteT>::value;
static_assert(value, "Missing serializer for argument (Can't serialize the "
"first template type argument of CanSerializeCheck "
"from the second)");
};
// Helper predicate for testing for the presence of SerializeTraits
// deserializers.
template <typename WireT, typename ConcreteT>
class CanDeserializeCheck
: detail::CanDeserialize<ChannelT, WireT, ConcreteT> {
public:
using detail::CanDeserialize<ChannelT, WireT, ConcreteT>::value;
static_assert(value, "Missing deserializer for argument (Can't deserialize "
"the second template type argument of "
"CanDeserializeCheck from the first)");
};
public: public:
/// Construct an RPC instance on a channel. /// Construct an RPC instance on a channel.
RPCBase(ChannelT &C, bool LazyAutoNegotiation) RPCBase(ChannelT &C, bool LazyAutoNegotiation)
@ -643,6 +757,13 @@ public:
/// with an error if the return value is abandoned due to a channel error. /// with an error if the return value is abandoned due to a channel error.
template <typename Func, typename HandlerT, typename... ArgTs> template <typename Func, typename HandlerT, typename... ArgTs>
Error appendCallAsync(HandlerT Handler, const ArgTs &... Args) { Error appendCallAsync(HandlerT Handler, const ArgTs &... Args) {
static_assert(
detail::RPCArgTypeCheck<CanSerializeCheck, typename Func::Type,
void(ArgTs...)>
::value,
"");
// Look up the function ID. // Look up the function ID.
FunctionIdT FnId; FunctionIdT FnId;
if (auto FnIdOrErr = getRemoteFunctionId<Func>()) if (auto FnIdOrErr = getRemoteFunctionId<Func>())
@ -738,6 +859,14 @@ protected:
/// autonegotiation and execution. /// autonegotiation and execution.
template <typename Func, typename HandlerT> template <typename Func, typename HandlerT>
void addHandlerImpl(HandlerT Handler, LaunchPolicy Launch) { void addHandlerImpl(HandlerT Handler, LaunchPolicy Launch) {
static_assert(
detail::RPCArgTypeCheck<CanDeserializeCheck,
typename Func::Type,
typename detail::HandlerTraits<HandlerT>::Type>
::value,
"");
FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
LocalFunctionIds[Func::getPrototype()] = NewFnId; LocalFunctionIds[Func::getPrototype()] = NewFnId;
Handlers[NewFnId] = Handlers[NewFnId] =

View File

@ -58,6 +58,40 @@ private:
Queue &OutQueue; Queue &OutQueue;
}; };
class RPCFoo {};
template <>
class RPCTypeName<RPCFoo> {
public:
static const char* getName() { return "RPCFoo"; }
};
template <>
class SerializationTraits<QueueChannel, RPCFoo, RPCFoo> {
public:
static Error serialize(QueueChannel&, const RPCFoo&) {
return Error::success();
}
static Error deserialize(QueueChannel&, RPCFoo&) {
return Error::success();
}
};
class RPCBar {};
template <>
class SerializationTraits<QueueChannel, RPCFoo, RPCBar> {
public:
static Error serialize(QueueChannel&, const RPCBar&) {
return Error::success();
}
static Error deserialize(QueueChannel&, RPCBar&) {
return Error::success();
}
};
class DummyRPCAPI { class DummyRPCAPI {
public: public:
@ -79,6 +113,12 @@ public:
public: public:
static const char* getName() { return "AllTheTypes"; } static const char* getName() { return "AllTheTypes"; }
}; };
class CustomType : public Function<CustomType, RPCFoo(RPCFoo)> {
public:
static const char* getName() { return "CustomType"; }
};
}; };
class DummyRPCEndpoint : public DummyRPCAPI, class DummyRPCEndpoint : public DummyRPCAPI,
@ -244,3 +284,89 @@ TEST(DummyRPC, TestSerialization) {
ServerThread.join(); ServerThread.join();
} }
TEST(DummyRPC, TestCustomType) {
Queue Q1, Q2;
DummyRPCEndpoint Client(Q1, Q2);
DummyRPCEndpoint Server(Q2, Q1);
std::thread ServerThread([&]() {
Server.addHandler<DummyRPCAPI::CustomType>(
[](RPCFoo F) {});
{
// Poke the server to handle the negotiate call.
auto Err = Server.handleOne();
EXPECT_FALSE(!!Err) << "Server failed to handle call to negotiate";
}
{
// Poke the server to handle the CustomType call.
auto Err = Server.handleOne();
EXPECT_FALSE(!!Err) << "Server failed to handle call to RPCFoo(RPCFoo)";
}
});
{
// Make an async call.
auto Err = Client.callAsync<DummyRPCAPI::CustomType>(
[](Expected<RPCFoo> FOrErr) {
EXPECT_TRUE(!!FOrErr)
<< "Async RPCFoo(RPCFoo) response handler failed";
return Error::success();
}, RPCFoo());
EXPECT_FALSE(!!Err) << "Client.callAsync failed for RPCFoo(RPCFoo)";
}
{
// Poke the client to process the result of the RPCFoo() call.
auto Err = Client.handleOne();
EXPECT_FALSE(!!Err)
<< "Client failed to handle response from RPCFoo(RPCFoo)";
}
ServerThread.join();
}
TEST(DummyRPC, TestWithAltCustomType) {
Queue Q1, Q2;
DummyRPCEndpoint Client(Q1, Q2);
DummyRPCEndpoint Server(Q2, Q1);
std::thread ServerThread([&]() {
Server.addHandler<DummyRPCAPI::CustomType>(
[](RPCBar F) {});
{
// Poke the server to handle the negotiate call.
auto Err = Server.handleOne();
EXPECT_FALSE(!!Err) << "Server failed to handle call to negotiate";
}
{
// Poke the server to handle the CustomType call.
auto Err = Server.handleOne();
EXPECT_FALSE(!!Err) << "Server failed to handle call to RPCFoo(RPCFoo)";
}
});
{
// Make an async call.
auto Err = Client.callAsync<DummyRPCAPI::CustomType>(
[](Expected<RPCBar> FOrErr) {
EXPECT_TRUE(!!FOrErr)
<< "Async RPCFoo(RPCFoo) response handler failed";
return Error::success();
}, RPCBar());
EXPECT_FALSE(!!Err) << "Client.callAsync failed for RPCFoo(RPCFoo)";
}
{
// Poke the client to process the result of the RPCFoo() call.
auto Err = Client.handleOne();
EXPECT_FALSE(!!Err)
<< "Client failed to handle response from RPCFoo(RPCFoo)";
}
ServerThread.join();
}