From e529b62823131b956fec0c480825e5303f4428ec Mon Sep 17 00:00:00 2001 From: Nick Lewycky Date: Tue, 9 Mar 2021 15:37:04 -0800 Subject: [PATCH] Add ConstantDataVector::getRaw() to create a constant data vector from raw data. This parallels ConstantDataArray::getRaw() and can be used with ConstantDataSequential::getRawDataValues() in the base class for both types. Update BuildConstantData{Array,Vector} tests to test the getRaw API. Also removes its unused Module. In passing, update some comments to include the support for half and bfloat. Update tests to include testing for bfloat. Differential Revision: https://reviews.llvm.org/D98302 --- include/llvm/IR/Constants.h | 28 +++++++++++++++------ unittests/IR/ConstantsTest.cpp | 46 +++++++++++++++++++++------------- 2 files changed, 48 insertions(+), 26 deletions(-) diff --git a/include/llvm/IR/Constants.h b/include/llvm/IR/Constants.h index 510163abe6e..223e47aa84e 100644 --- a/include/llvm/IR/Constants.h +++ b/include/llvm/IR/Constants.h @@ -558,10 +558,10 @@ public: //===----------------------------------------------------------------------===// /// ConstantDataSequential - A vector or array constant whose element type is a -/// simple 1/2/4/8-byte integer or float/double, and whose elements are just -/// simple data values (i.e. ConstantInt/ConstantFP). This Constant node has no -/// operands because it stores all of the elements of the constant as densely -/// packed data, instead of as Value*'s. +/// simple 1/2/4/8-byte integer or half/bfloat/float/double, and whose elements +/// are just simple data values (i.e. ConstantInt/ConstantFP). This Constant +/// node has no operands because it stores all of the elements of the constant +/// as densely packed data, instead of as Value*'s. /// /// This is the common base class of ConstantDataArray and ConstantDataVector. /// @@ -700,11 +700,11 @@ public: return ConstantDataArray::get(Context, makeArrayRef(Elts)); } - /// get() constructor - Return a constant with array type with an element + /// getRaw() constructor - Return a constant with array type with an element /// count and element type matching the NumElements and ElementTy parameters /// passed in. Note that this can return a ConstantAggregateZero object. - /// ElementTy needs to be one of i8/i16/i32/i64/float/double. Data is the - /// buffer containing the elements. Be careful to make sure Data uses the + /// ElementTy must be one of i8/i16/i32/i64/half/bfloat/float/double. Data is + /// the buffer containing the elements. Be careful to make sure Data uses the /// right endianness, the buffer will be used as-is. static Constant *getRaw(StringRef Data, uint64_t NumElements, Type *ElementTy) { @@ -772,6 +772,18 @@ public: static Constant *get(LLVMContext &Context, ArrayRef Elts); static Constant *get(LLVMContext &Context, ArrayRef Elts); + /// getRaw() constructor - Return a constant with vector type with an element + /// count and element type matching the NumElements and ElementTy parameters + /// passed in. Note that this can return a ConstantAggregateZero object. + /// ElementTy must be one of i8/i16/i32/i64/half/bfloat/float/double. Data is + /// the buffer containing the elements. Be careful to make sure Data uses the + /// right endianness, the buffer will be used as-is. + static Constant *getRaw(StringRef Data, uint64_t NumElements, + Type *ElementTy) { + Type *Ty = VectorType::get(ElementTy, ElementCount::getFixed(NumElements)); + return getImpl(Data, Ty); + } + /// getFP() constructors - Return a constant of vector type with a float /// element type taken from argument `ElementType', and count taken from /// argument `Elts'. The amount of bits of the contained type must match the @@ -784,7 +796,7 @@ public: /// Return a ConstantVector with the specified constant in each element. /// The specified constant has to be a of a compatible type (i8/i16/ - /// i32/i64/float/double) and must be a ConstantFP or ConstantInt. + /// i32/i64/half/bfloat/float/double) and must be a ConstantFP or ConstantInt. static Constant *getSplat(unsigned NumElts, Constant *Elt); /// Returns true if this is a splat constant, meaning that all elements have diff --git a/unittests/IR/ConstantsTest.cpp b/unittests/IR/ConstantsTest.cpp index 44dbb90758a..50eb3e0df1f 100644 --- a/unittests/IR/ConstantsTest.cpp +++ b/unittests/IR/ConstantsTest.cpp @@ -418,45 +418,55 @@ static std::string getNameOfType(Type *T) { TEST(ConstantsTest, BuildConstantDataArrays) { LLVMContext Context; - std::unique_ptr M(new Module("MyModule", Context)); for (Type *T : {Type::getInt8Ty(Context), Type::getInt16Ty(Context), Type::getInt32Ty(Context), Type::getInt64Ty(Context)}) { ArrayType *ArrayTy = ArrayType::get(T, 2); Constant *Vals[] = {ConstantInt::get(T, 0), ConstantInt::get(T, 1)}; - Constant *CDV = ConstantArray::get(ArrayTy, Vals); - ASSERT_TRUE(dyn_cast(CDV) != nullptr) - << " T = " << getNameOfType(T); + Constant *CA = ConstantArray::get(ArrayTy, Vals); + ASSERT_TRUE(isa(CA)) << " T = " << getNameOfType(T); + auto *CDA = cast(CA); + Constant *CA2 = ConstantDataArray::getRaw( + CDA->getRawDataValues(), CDA->getNumElements(), CDA->getElementType()); + ASSERT_TRUE(CA == CA2) << " T = " << getNameOfType(T); } - for (Type *T : {Type::getHalfTy(Context), Type::getFloatTy(Context), - Type::getDoubleTy(Context)}) { + for (Type *T : {Type::getHalfTy(Context), Type::getBFloatTy(Context), + Type::getFloatTy(Context), Type::getDoubleTy(Context)}) { ArrayType *ArrayTy = ArrayType::get(T, 2); Constant *Vals[] = {ConstantFP::get(T, 0), ConstantFP::get(T, 1)}; - Constant *CDV = ConstantArray::get(ArrayTy, Vals); - ASSERT_TRUE(dyn_cast(CDV) != nullptr) - << " T = " << getNameOfType(T); + Constant *CA = ConstantArray::get(ArrayTy, Vals); + ASSERT_TRUE(isa(CA)) << " T = " << getNameOfType(T); + auto *CDA = cast(CA); + Constant *CA2 = ConstantDataArray::getRaw( + CDA->getRawDataValues(), CDA->getNumElements(), CDA->getElementType()); + ASSERT_TRUE(CA == CA2) << " T = " << getNameOfType(T); } } TEST(ConstantsTest, BuildConstantDataVectors) { LLVMContext Context; - std::unique_ptr M(new Module("MyModule", Context)); for (Type *T : {Type::getInt8Ty(Context), Type::getInt16Ty(Context), Type::getInt32Ty(Context), Type::getInt64Ty(Context)}) { Constant *Vals[] = {ConstantInt::get(T, 0), ConstantInt::get(T, 1)}; - Constant *CDV = ConstantVector::get(Vals); - ASSERT_TRUE(dyn_cast(CDV) != nullptr) - << " T = " << getNameOfType(T); + Constant *CV = ConstantVector::get(Vals); + ASSERT_TRUE(isa(CV)) << " T = " << getNameOfType(T); + auto *CDV = cast(CV); + Constant *CV2 = ConstantDataVector::getRaw( + CDV->getRawDataValues(), CDV->getNumElements(), CDV->getElementType()); + ASSERT_TRUE(CV == CV2) << " T = " << getNameOfType(T); } - for (Type *T : {Type::getHalfTy(Context), Type::getFloatTy(Context), - Type::getDoubleTy(Context)}) { + for (Type *T : {Type::getHalfTy(Context), Type::getBFloatTy(Context), + Type::getFloatTy(Context), Type::getDoubleTy(Context)}) { Constant *Vals[] = {ConstantFP::get(T, 0), ConstantFP::get(T, 1)}; - Constant *CDV = ConstantVector::get(Vals); - ASSERT_TRUE(dyn_cast(CDV) != nullptr) - << " T = " << getNameOfType(T); + Constant *CV = ConstantVector::get(Vals); + ASSERT_TRUE(isa(CV)) << " T = " << getNameOfType(T); + auto *CDV = cast(CV); + Constant *CV2 = ConstantDataVector::getRaw( + CDV->getRawDataValues(), CDV->getNumElements(), CDV->getElementType()); + ASSERT_TRUE(CV == CV2) << " T = " << getNameOfType(T); } }