diff --git a/include/llvm/ExecutionEngine/Orc/RPCUtils.h b/include/llvm/ExecutionEngine/Orc/RPCUtils.h index be5cea41054..4e63a84ed17 100644 --- a/include/llvm/ExecutionEngine/Orc/RPCUtils.h +++ b/include/llvm/ExecutionEngine/Orc/RPCUtils.h @@ -338,7 +338,9 @@ public: return Err; // Close the response message. - return C.endSendMessage(); + if (auto Err = C.endSendMessage()) + return Err; + return C.send(); } template @@ -350,7 +352,9 @@ public: return Err2; if (auto Err2 = serializeSeq(C, std::move(Err))) return Err2; - return C.endSendMessage(); + if (auto Err2 = C.endSendMessage()) + return Err2; + return C.send(); } }; @@ -378,8 +382,11 @@ public: C, *ResultOrErr)) return Err; - // Close the response message. - return C.endSendMessage(); + // End the response message. + if (auto Err = C.endSendMessage()) + return Err; + + return C.send(); } template @@ -389,7 +396,9 @@ public: return Err; if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) return Err2; - return C.endSendMessage(); + if (auto Err2 = C.endSendMessage()) + return Err2; + return C.send(); } }; @@ -1520,6 +1529,12 @@ public: return std::move(Err); } + if (auto Err = this->C.send()) { + detail::ResultTraits::consumeAbandoned( + std::move(Result)); + return std::move(Err); + } + while (!ReceivedResponse) { if (auto Err = this->handleOne()) { detail::ResultTraits::consumeAbandoned( diff --git a/unittests/ExecutionEngine/Orc/QueueChannel.h b/unittests/ExecutionEngine/Orc/QueueChannel.h index 511f038dec1..1909693ecb1 100644 --- a/unittests/ExecutionEngine/Orc/QueueChannel.h +++ b/unittests/ExecutionEngine/Orc/QueueChannel.h @@ -80,6 +80,30 @@ public: QueueChannel(QueueChannel&&) = delete; QueueChannel& operator=(QueueChannel&&) = delete; + template + Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) { + ++InFlightOutgoingMessages; + return orc::rpc::RawByteChannel::startSendMessage(FnId, SeqNo); + } + + Error endSendMessage() { + --InFlightOutgoingMessages; + ++CompletedOutgoingMessages; + return orc::rpc::RawByteChannel::endSendMessage(); + } + + template + Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) { + ++InFlightIncomingMessages; + return orc::rpc::RawByteChannel::startReceiveMessage(FnId, SeqNo); + } + + Error endReceiveMessage() { + --InFlightIncomingMessages; + ++CompletedIncomingMessages; + return orc::rpc::RawByteChannel::endReceiveMessage(); + } + Error readBytes(char *Dst, unsigned Size) override { std::unique_lock Lock(InQueue->getMutex()); while (Size) { @@ -112,7 +136,10 @@ public: return Error::success(); } - Error send() override { return Error::success(); } + Error send() override { + ++SendCalls; + return Error::success(); + } void close() { auto ChannelClosed = []() { return make_error(); }; @@ -124,6 +151,11 @@ public: uint64_t NumWritten = 0; uint64_t NumRead = 0; + std::atomic InFlightIncomingMessages{0}; + std::atomic CompletedIncomingMessages{0}; + std::atomic InFlightOutgoingMessages{0}; + std::atomic CompletedOutgoingMessages{0}; + std::atomic SendCalls{0}; private: diff --git a/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp b/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp index 1f7c88d93d0..8e4c5330d90 100644 --- a/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp +++ b/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp @@ -214,6 +214,17 @@ TEST(DummyRPC, TestCallAsyncVoidBool) { EXPECT_FALSE(!!Err) << "Client failed to handle response from void(bool)"; } + // The client should have made two calls to send: One implicit call to + // negotiate the VoidBool function key, and a second to make the VoidBool + // call. + EXPECT_EQ(Channels.first->SendCalls, 2U) + << "Expected one send call to have been made by client"; + + // The server should have made two calls to send: One to send the response to + // the negotiate call, and another to send the response to the VoidBool call. + EXPECT_EQ(Channels.second->SendCalls, 2U) + << "Expected two send calls to have been made by server"; + ServerThread.join(); }