diff --git a/include/llvm/ExecutionEngine/Orc/RPCSerialization.h b/include/llvm/ExecutionEngine/Orc/RPCSerialization.h index be4379e8867..84a037b2f99 100644 --- a/include/llvm/ExecutionEngine/Orc/RPCSerialization.h +++ b/include/llvm/ExecutionEngine/Orc/RPCSerialization.h @@ -342,29 +342,35 @@ public: assert(!Name.empty() && "The empty string is reserved for the Success value"); - std::lock_guard Lock(SerializersMutex); - - // We're abusing the stability of std::map here: We take a reference to the - // key of the deserializers map to save us from duplicating the string in - // the serializer. This should be changed to use a stringpool if we switch - // to a map type that may move keys in memory. - auto I = - Deserializers.insert(Deserializers.begin(), - std::make_pair(std::move(Name), - std::move(Deserialize))); - - const std::string &KeyName = I->first; - // FIXME: Move capture Serialize once we have C++14. - Serializers[ErrorInfoT::classID()] = - [&KeyName, Serialize](ChannelT &C, const ErrorInfoBase &EIB) -> Error { - assert(EIB.dynamicClassID() == ErrorInfoT::classID() && - "Serializer called for wrong error type"); - if (auto Err = serializeSeq(C, KeyName)) - return Err; - return Serialize(C, static_cast(EIB)); - }; + const std::string *KeyName = nullptr; + { + // We're abusing the stability of std::map here: We take a reference to the + // key of the deserializers map to save us from duplicating the string in + // the serializer. This should be changed to use a stringpool if we switch + // to a map type that may move keys in memory. + std::lock_guard Lock(DeserializersMutex); + auto I = + Deserializers.insert(Deserializers.begin(), + std::make_pair(std::move(Name), + std::move(Deserialize))); + KeyName = &I->first; + } + + { + assert(KeyName != nullptr && "No keyname pointer"); + std::lock_guard Lock(SerializersMutex); + // FIXME: Move capture Serialize once we have C++14. + Serializers[ErrorInfoT::classID()] = + [KeyName, Serialize](ChannelT &C, const ErrorInfoBase &EIB) -> Error { + assert(EIB.dynamicClassID() == ErrorInfoT::classID() && + "Serializer called for wrong error type"); + if (auto Err = serializeSeq(C, *KeyName)) + return Err; + return Serialize(C, static_cast(EIB)); + }; + } } - + static Error serialize(ChannelT &C, Error &&Err) { std::lock_guard Lock(SerializersMutex); if (!Err) @@ -380,7 +386,7 @@ public: } static Error deserialize(ChannelT &C, Error &Err) { - std::lock_guard Lock(SerializersMutex); + std::lock_guard Lock(DeserializersMutex); std::string Key; if (auto Err = deserializeSeq(C, Key)) @@ -412,6 +418,7 @@ private: } static std::mutex SerializersMutex; + static std::mutex DeserializersMutex; static std::map Serializers; static std::map Deserializers; }; @@ -419,6 +426,9 @@ private: template std::mutex SerializationTraits::SerializersMutex; +template +std::mutex SerializationTraits::DeserializersMutex; + template std::map::WrappedErrorSerializer> diff --git a/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp b/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp index 8f6dddc62cd..1c9764b555f 100644 --- a/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp +++ b/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp @@ -551,141 +551,139 @@ TEST(DummyRPC, TestWithAltCustomType) { ServerThread.join(); } -// FIXME: Temporarily disabled to investigate bot failure. +TEST(DummyRPC, ReturnErrorSuccess) { + registerDummyErrorSerialization(); -// TEST(DummyRPC, ReturnErrorSuccess) { -// registerDummyErrorSerialization(); + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); -// auto Channels = createPairedQueueChannels(); -// DummyRPCEndpoint Client(*Channels.first); -// DummyRPCEndpoint Server(*Channels.second); + std::thread ServerThread([&]() { + Server.addHandler( + []() { + return Error::success(); + }); -// std::thread ServerThread([&]() { -// Server.addHandler( -// []() { -// return Error::success(); -// }); + // Handle the negotiate plus one call. + for (unsigned I = 0; I != 2; ++I) + cantFail(Server.handleOne()); + }); -// // Handle the negotiate plus one call. -// for (unsigned I = 0; I != 2; ++I) -// cantFail(Server.handleOne()); -// }); + cantFail(Client.callAsync( + [&](Error Err) { + EXPECT_FALSE(!!Err) << "Expected success value"; + return Error::success(); + })); -// cantFail(Client.callAsync( -// [&](Error Err) { -// EXPECT_FALSE(!!Err) << "Expected success value"; -// return Error::success(); -// })); + cantFail(Client.handleOne()); -// cantFail(Client.handleOne()); + ServerThread.join(); +} -// ServerThread.join(); -// } +TEST(DummyRPC, ReturnErrorFailure) { + registerDummyErrorSerialization(); -// TEST(DummyRPC, ReturnErrorFailure) { -// registerDummyErrorSerialization(); + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); -// auto Channels = createPairedQueueChannels(); -// DummyRPCEndpoint Client(*Channels.first); -// DummyRPCEndpoint Server(*Channels.second); + std::thread ServerThread([&]() { + Server.addHandler( + []() { + return make_error(42); + }); -// std::thread ServerThread([&]() { -// Server.addHandler( -// []() { -// return make_error(42); -// }); + // Handle the negotiate plus one call. + for (unsigned I = 0; I != 2; ++I) + cantFail(Server.handleOne()); + }); -// // Handle the negotiate plus one call. -// for (unsigned I = 0; I != 2; ++I) -// cantFail(Server.handleOne()); -// }); + cantFail(Client.callAsync( + [&](Error Err) { + EXPECT_TRUE(Err.isA()) + << "Incorrect error type"; + return handleErrors( + std::move(Err), + [](const DummyError &DE) { + EXPECT_EQ(DE.getValue(), 42ULL) + << "Incorrect DummyError serialization"; + }); + })); -// cantFail(Client.callAsync( -// [&](Error Err) { -// EXPECT_TRUE(Err.isA()) -// << "Incorrect error type"; -// return handleErrors( -// std::move(Err), -// [](const DummyError &DE) { -// EXPECT_EQ(DE.getValue(), 42ULL) -// << "Incorrect DummyError serialization"; -// }); -// })); + cantFail(Client.handleOne()); -// cantFail(Client.handleOne()); + ServerThread.join(); +} -// ServerThread.join(); -// } +TEST(DummyRPC, ReturnExpectedSuccess) { + registerDummyErrorSerialization(); -// TEST(DummyRPC, ReturnExpectedSuccess) { -// registerDummyErrorSerialization(); + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); -// auto Channels = createPairedQueueChannels(); -// DummyRPCEndpoint Client(*Channels.first); -// DummyRPCEndpoint Server(*Channels.second); + std::thread ServerThread([&]() { + Server.addHandler( + []() -> uint32_t { + return 42; + }); -// std::thread ServerThread([&]() { -// Server.addHandler( -// []() -> uint32_t { -// return 42; -// }); + // Handle the negotiate plus one call. + for (unsigned I = 0; I != 2; ++I) + cantFail(Server.handleOne()); + }); -// // Handle the negotiate plus one call. -// for (unsigned I = 0; I != 2; ++I) -// cantFail(Server.handleOne()); -// }); + cantFail(Client.callAsync( + [&](Expected ValOrErr) { + EXPECT_TRUE(!!ValOrErr) + << "Expected success value"; + EXPECT_EQ(*ValOrErr, 42ULL) + << "Incorrect Expected deserialization"; + return Error::success(); + })); -// cantFail(Client.callAsync( -// [&](Expected ValOrErr) { -// EXPECT_TRUE(!!ValOrErr) -// << "Expected success value"; -// EXPECT_EQ(*ValOrErr, 42ULL) -// << "Incorrect Expected deserialization"; -// return Error::success(); -// })); + cantFail(Client.handleOne()); -// cantFail(Client.handleOne()); + ServerThread.join(); +} -// ServerThread.join(); -// } +TEST(DummyRPC, ReturnExpectedFailure) { + registerDummyErrorSerialization(); -// TEST(DummyRPC, ReturnExpectedFailure) { -// registerDummyErrorSerialization(); + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); -// auto Channels = createPairedQueueChannels(); -// DummyRPCEndpoint Client(*Channels.first); -// DummyRPCEndpoint Server(*Channels.second); + std::thread ServerThread([&]() { + Server.addHandler( + []() -> Expected { + return make_error(7); + }); -// std::thread ServerThread([&]() { -// Server.addHandler( -// []() -> Expected { -// return make_error(7); -// }); + // Handle the negotiate plus one call. + for (unsigned I = 0; I != 2; ++I) + cantFail(Server.handleOne()); + }); -// // Handle the negotiate plus one call. -// for (unsigned I = 0; I != 2; ++I) -// cantFail(Server.handleOne()); -// }); + cantFail(Client.callAsync( + [&](Expected ValOrErr) { + EXPECT_FALSE(!!ValOrErr) + << "Expected failure value"; + auto Err = ValOrErr.takeError(); + EXPECT_TRUE(Err.isA()) + << "Incorrect error type"; + return handleErrors( + std::move(Err), + [](const DummyError &DE) { + EXPECT_EQ(DE.getValue(), 7ULL) + << "Incorrect DummyError serialization"; + }); + })); -// cantFail(Client.callAsync( -// [&](Expected ValOrErr) { -// EXPECT_FALSE(!!ValOrErr) -// << "Expected failure value"; -// auto Err = ValOrErr.takeError(); -// EXPECT_TRUE(Err.isA()) -// << "Incorrect error type"; -// return handleErrors( -// std::move(Err), -// [](const DummyError &DE) { -// EXPECT_EQ(DE.getValue(), 7ULL) -// << "Incorrect DummyError serialization"; -// }); -// })); + cantFail(Client.handleOne()); -// cantFail(Client.handleOne()); - -// ServerThread.join(); -// } + ServerThread.join(); +} TEST(DummyRPC, TestParallelCallGroup) { auto Channels = createPairedQueueChannels();