diff --git a/include/llvm/IR/DerivedTypes.h b/include/llvm/IR/DerivedTypes.h index 7a0900d3c7d..4cac62150a2 100644 --- a/include/llvm/IR/DerivedTypes.h +++ b/include/llvm/IR/DerivedTypes.h @@ -62,6 +62,11 @@ public: /// Get or create an IntegerType instance. static IntegerType *get(LLVMContext &C, unsigned NumBits); + /// Returns type twice as wide the input type. + IntegerType *getExtendedType() const { + return Type::getIntNTy(getContext(), 2 * getScalarSizeInBits()); + } + /// Get the number of bits in this IntegerType unsigned getBitWidth() const { return getSubclassData(); } @@ -470,9 +475,9 @@ public: /// This static method is like getInteger except that the element types are /// twice as wide as the elements in the input type. static VectorType *getExtendedElementVectorType(VectorType *VTy) { - unsigned EltBits = VTy->getElementType()->getPrimitiveSizeInBits(); - Type *EltTy = IntegerType::get(VTy->getContext(), EltBits * 2); - return VectorType::get(EltTy, VTy->getElementCount()); + assert(VTy->isIntOrIntVectorTy() && "VTy expected to be a vector of ints."); + auto *EltTy = cast(VTy->getElementType()); + return VectorType::get(EltTy->getExtendedType(), VTy->getElementCount()); } // This static method gets a VectorType with the same number of elements as @@ -603,6 +608,16 @@ public: } }; +Type *Type::getExtendedType() const { + assert( + isIntOrIntVectorTy() && + "Original type expected to be a vector of integers or a scalar integer."); + if (auto *VTy = dyn_cast(this)) + return VectorType::getExtendedElementVectorType( + const_cast(VTy)); + return cast(this)->getExtendedType(); +} + unsigned Type::getPointerAddressSpace() const { return cast(getScalarType())->getAddressSpace(); } diff --git a/include/llvm/IR/Type.h b/include/llvm/IR/Type.h index f2aa49030aa..34271aae49c 100644 --- a/include/llvm/IR/Type.h +++ b/include/llvm/IR/Type.h @@ -378,6 +378,10 @@ public: return ContainedTys[0]; } + /// Given scalar/vector integer type, returns a type with elements twice as + /// wide as in the original type. For vectors, preserves element count. + inline Type *getExtendedType() const; + /// Get the address space of this pointer or pointer vector type. inline unsigned getPointerAddressSpace() const; diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp index ba77696f1f1..c88827f916c 100644 --- a/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -173,11 +173,7 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, return nullptr; // Else we can't perform the fold. // The mask must be computed in a type twice as wide to ensure // that no bits are lost if the sum-of-shifts is wider than the base type. - Type *ExtendedScalarTy = Type::getIntNTy(Ty->getContext(), 2 * BitWidth); - Type *ExtendedTy = - Ty->isVectorTy() - ? VectorType::get(ExtendedScalarTy, Ty->getVectorNumElements()) - : ExtendedScalarTy; + Type *ExtendedTy = Ty->getExtendedType(); auto *ExtendedSumOfShAmts = ConstantExpr::getZExt(SumOfShAmts, ExtendedTy); // And compute the mask as usual: ~(-1 << (SumOfShAmts)) @@ -213,11 +209,7 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, unsigned BitWidth = Ty->getScalarSizeInBits(); // The mask must be computed in a type twice as wide to ensure // that no bits are lost if the sum-of-shifts is wider than the base type. - Type *ExtendedScalarTy = Type::getIntNTy(Ty->getContext(), 2 * BitWidth); - Type *ExtendedTy = - Ty->isVectorTy() - ? VectorType::get(ExtendedScalarTy, Ty->getVectorNumElements()) - : ExtendedScalarTy; + Type *ExtendedTy = Ty->getExtendedType(); auto *ExtendedNumHighBitsToClear = ConstantExpr::getZExt( ConstantExpr::getAdd( ConstantExpr::getNeg(ShAmtsDiff),