1
0
mirror of https://github.com/RPCS3/llvm-mirror.git synced 2024-11-25 12:12:47 +01:00

[TFUtils] Expose untyped accessor to evaluation result tensors

These were implementation detail, but become necessary for generic data
copying.

Also added const variations to them, and move assignment, since we had a
move ctor (and the move assignment helps in a subsequent patch).

Differential Revision: https://reviews.llvm.org/D85262
This commit is contained in:
Mircea Trofin 2020-08-05 10:22:45 -07:00
parent 8e671cc375
commit 3bd1a7f753
3 changed files with 27 additions and 3 deletions

View File

@ -101,18 +101,29 @@ public:
class EvaluationResult {
public:
EvaluationResult(const EvaluationResult &) = delete;
EvaluationResult &operator=(const EvaluationResult &Other) = delete;
EvaluationResult(EvaluationResult &&Other);
EvaluationResult &operator=(EvaluationResult &&Other);
~EvaluationResult();
/// Get a pointer to the first element of the tensor at Index.
/// Get a (const) pointer to the first element of the tensor at Index.
template <typename T> T *getTensorValue(size_t Index) {
return static_cast<T *>(getUntypedTensorValue(Index));
}
template <typename T> const T *getTensorValue(size_t Index) const {
return static_cast<T *>(getUntypedTensorValue(Index));
}
/// Get a (const) pointer to the untyped data of the tensor.
void *getUntypedTensorValue(size_t Index);
const void *getUntypedTensorValue(size_t Index) const;
private:
friend class TFModelEvaluator;
EvaluationResult(std::unique_ptr<EvaluationResultImpl> Impl);
void *getUntypedTensorValue(size_t Index);
std::unique_ptr<EvaluationResultImpl> Impl;
};

View File

@ -292,10 +292,21 @@ TFModelEvaluator::EvaluationResult::EvaluationResult(
TFModelEvaluator::EvaluationResult::EvaluationResult(EvaluationResult &&Other)
: Impl(std::move(Other.Impl)) {}
TFModelEvaluator::EvaluationResult &
TFModelEvaluator::EvaluationResult::operator=(EvaluationResult &&Other) {
Impl = std::move(Other.Impl);
return *this;
}
void *TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) {
return TF_TensorData(Impl->getOutput()[Index]);
}
const void *
TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) const {
return TF_TensorData(Impl->getOutput()[Index]);
}
#define TFUTILS_GETDATATYPE_IMPL(T, S, E) \
template <> int TensorSpec::getDataType<T>() { return TF_##E; }

View File

@ -56,6 +56,8 @@ TEST(TFUtilsTest, LoadAndExecuteTest) {
EXPECT_TRUE(ER.hasValue());
float Ret = *ER->getTensorValue<float>(0);
EXPECT_EQ(static_cast<size_t>(Ret), 80);
EXPECT_EQ(ER->getUntypedTensorValue(0),
reinterpret_cast<const void *>(ER->getTensorValue<float>(0)));
}
// The input vector should be unchanged
for (auto I = 0; I < KnownSize; ++I) {
@ -137,4 +139,4 @@ TEST(TFUtilsTest, TensorSpecSizesAndTypes) {
EXPECT_EQ(Spec3DLarge.getElementCount(), 80);
EXPECT_EQ(Spec3DLarge.getElementByteSize(), sizeof(float));
EXPECT_EQ(Spec1D.getElementByteSize(), sizeof(int16_t));
}
}