diff --git a/lib/Support/KnownBits.cpp b/lib/Support/KnownBits.cpp index e10f97bb751..7b94e7b9594 100644 --- a/lib/Support/KnownBits.cpp +++ b/lib/Support/KnownBits.cpp @@ -160,12 +160,16 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) { return Known; } - // Minimum shift amount low bits are known zero. - if (RHS.getMinValue().ult(BitWidth)) - Known.Zero.setLowBits(RHS.getMinValue().getZExtValue()); - // No matter the shift amount, the trailing zeros will stay zero. - Known.Zero.setLowBits(LHS.countMinTrailingZeros()); + unsigned MinTrailingZeros = LHS.countMinTrailingZeros(); + + // Minimum shift amount low bits are known zero. + if (RHS.getMinValue().ult(BitWidth)) { + MinTrailingZeros += RHS.getMinValue().getZExtValue(); + MinTrailingZeros = std::min(MinTrailingZeros, BitWidth); + } + + Known.Zero.setLowBits(MinTrailingZeros); return Known; } @@ -183,12 +187,16 @@ KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS) { return Known; } - // Minimum shift amount high bits are known zero. - if (RHS.getMinValue().ult(BitWidth)) - Known.Zero.setHighBits(RHS.getMinValue().getZExtValue()); - // No matter the shift amount, the leading zeros will stay zero. - Known.Zero.setHighBits(LHS.countMinLeadingZeros()); + unsigned MinLeadingZeros = LHS.countMinLeadingZeros(); + + // Minimum shift amount high bits are known zero. + if (RHS.getMinValue().ult(BitWidth)) { + MinLeadingZeros += RHS.getMinValue().getZExtValue(); + MinLeadingZeros = std::min(MinLeadingZeros, BitWidth); + } + + Known.Zero.setHighBits(MinLeadingZeros); return Known; } @@ -204,8 +212,24 @@ KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS) { return Known; } - // TODO: Minimum shift amount high bits are known sign bits. - // TODO: No matter the shift amount, the leading sign bits will stay. + // No matter the shift amount, the leading sign bits will stay. + unsigned MinLeadingZeros = LHS.countMinLeadingZeros(); + unsigned MinLeadingOnes = LHS.countMinLeadingOnes(); + + // Minimum shift amount high bits are known sign bits. + if (RHS.getMinValue().ult(BitWidth)) { + if (MinLeadingZeros) { + MinLeadingZeros += RHS.getMinValue().getZExtValue(); + MinLeadingZeros = std::min(MinLeadingZeros, BitWidth); + } + if (MinLeadingOnes) { + MinLeadingOnes += RHS.getMinValue().getZExtValue(); + MinLeadingOnes = std::min(MinLeadingOnes, BitWidth); + } + } + + Known.Zero.setHighBits(MinLeadingZeros); + Known.One.setHighBits(MinLeadingOnes); return Known; }