mirror of
https://github.com/RPCS3/llvm-mirror.git
synced 2024-11-25 12:12:47 +01:00
[TFUtils] Expose untyped accessor to evaluation result tensors
These were implementation detail, but become necessary for generic data copying. Also added const variations to them, and move assignment, since we had a move ctor (and the move assignment helps in a subsequent patch). Differential Revision: https://reviews.llvm.org/D85262
This commit is contained in:
parent
8e671cc375
commit
3bd1a7f753
@ -101,18 +101,29 @@ public:
|
||||
class EvaluationResult {
|
||||
public:
|
||||
EvaluationResult(const EvaluationResult &) = delete;
|
||||
EvaluationResult &operator=(const EvaluationResult &Other) = delete;
|
||||
|
||||
EvaluationResult(EvaluationResult &&Other);
|
||||
EvaluationResult &operator=(EvaluationResult &&Other);
|
||||
|
||||
~EvaluationResult();
|
||||
|
||||
/// Get a pointer to the first element of the tensor at Index.
|
||||
/// Get a (const) pointer to the first element of the tensor at Index.
|
||||
template <typename T> T *getTensorValue(size_t Index) {
|
||||
return static_cast<T *>(getUntypedTensorValue(Index));
|
||||
}
|
||||
|
||||
template <typename T> const T *getTensorValue(size_t Index) const {
|
||||
return static_cast<T *>(getUntypedTensorValue(Index));
|
||||
}
|
||||
|
||||
/// Get a (const) pointer to the untyped data of the tensor.
|
||||
void *getUntypedTensorValue(size_t Index);
|
||||
const void *getUntypedTensorValue(size_t Index) const;
|
||||
|
||||
private:
|
||||
friend class TFModelEvaluator;
|
||||
EvaluationResult(std::unique_ptr<EvaluationResultImpl> Impl);
|
||||
void *getUntypedTensorValue(size_t Index);
|
||||
std::unique_ptr<EvaluationResultImpl> Impl;
|
||||
};
|
||||
|
||||
|
@ -292,10 +292,21 @@ TFModelEvaluator::EvaluationResult::EvaluationResult(
|
||||
TFModelEvaluator::EvaluationResult::EvaluationResult(EvaluationResult &&Other)
|
||||
: Impl(std::move(Other.Impl)) {}
|
||||
|
||||
TFModelEvaluator::EvaluationResult &
|
||||
TFModelEvaluator::EvaluationResult::operator=(EvaluationResult &&Other) {
|
||||
Impl = std::move(Other.Impl);
|
||||
return *this;
|
||||
}
|
||||
|
||||
void *TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) {
|
||||
return TF_TensorData(Impl->getOutput()[Index]);
|
||||
}
|
||||
|
||||
const void *
|
||||
TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) const {
|
||||
return TF_TensorData(Impl->getOutput()[Index]);
|
||||
}
|
||||
|
||||
#define TFUTILS_GETDATATYPE_IMPL(T, S, E) \
|
||||
template <> int TensorSpec::getDataType<T>() { return TF_##E; }
|
||||
|
||||
|
@ -56,6 +56,8 @@ TEST(TFUtilsTest, LoadAndExecuteTest) {
|
||||
EXPECT_TRUE(ER.hasValue());
|
||||
float Ret = *ER->getTensorValue<float>(0);
|
||||
EXPECT_EQ(static_cast<size_t>(Ret), 80);
|
||||
EXPECT_EQ(ER->getUntypedTensorValue(0),
|
||||
reinterpret_cast<const void *>(ER->getTensorValue<float>(0)));
|
||||
}
|
||||
// The input vector should be unchanged
|
||||
for (auto I = 0; I < KnownSize; ++I) {
|
||||
@ -137,4 +139,4 @@ TEST(TFUtilsTest, TensorSpecSizesAndTypes) {
|
||||
EXPECT_EQ(Spec3DLarge.getElementCount(), 80);
|
||||
EXPECT_EQ(Spec3DLarge.getElementByteSize(), sizeof(float));
|
||||
EXPECT_EQ(Spec1D.getElementByteSize(), sizeof(int16_t));
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user