From a334ec01fc93b4f3b6858b9a691e4edcf20a4fd5 Mon Sep 17 00:00:00 2001 From: Florian Hahn Date: Sun, 14 Feb 2021 18:06:09 +0000 Subject: [PATCH] [ConstraintElimination] Fix variables used for pattern matching. Re-using the matched variable in the pattern does not work as expected. This patch fixes that by introducing a new variable for the 2nd level match. --- .../Scalar/ConstraintElimination.cpp | 10 ++-- test/Transforms/ConstraintElimination/geps.ll | 55 ++++++++++++++++++- 2 files changed, 59 insertions(+), 6 deletions(-) diff --git a/lib/Transforms/Scalar/ConstraintElimination.cpp b/lib/Transforms/Scalar/ConstraintElimination.cpp index 09b0b4a8618..00f1c488ff5 100644 --- a/lib/Transforms/Scalar/ConstraintElimination.cpp +++ b/lib/Transforms/Scalar/ConstraintElimination.cpp @@ -54,20 +54,20 @@ static SmallVector<std::pair<int64_t, Value *>, 4> decompose(Value *V) { } auto *GEP = dyn_cast<GetElementPtrInst>(V); if (GEP && GEP->getNumOperands() == 2 && GEP->isInBounds()) { - Value *Op0; + Value *Op0, *Op1; ConstantInt *CI; // If the index is zero-extended, it is guaranteed to be positive. 
if (match(GEP->getOperand(GEP->getNumOperands() - 1), m_ZExt(m_Value(Op0)))) { - if (match(Op0, m_NUWShl(m_Value(Op0), m_ConstantInt(CI)))) + if (match(Op0, m_NUWShl(m_Value(Op1), m_ConstantInt(CI)))) return {{0, nullptr}, {1, GEP->getPointerOperand()}, - {std::pow(int64_t(2), CI->getSExtValue()), Op0}}; - if (match(Op0, m_NSWAdd(m_Value(Op0), m_ConstantInt(CI)))) + {std::pow(int64_t(2), CI->getSExtValue()), Op1}}; + if (match(Op0, m_NSWAdd(m_Value(Op1), m_ConstantInt(CI)))) return {{CI->getSExtValue(), nullptr}, {1, GEP->getPointerOperand()}, - {1, Op0}}; + {1, Op1}}; return {{0, nullptr}, {1, GEP->getPointerOperand()}, {1, Op0}}; } diff --git a/test/Transforms/ConstraintElimination/geps.ll b/test/Transforms/ConstraintElimination/geps.ll index 5c891f3c390..9141ace2e56 100644 --- a/test/Transforms/ConstraintElimination/geps.ll +++ b/test/Transforms/ConstraintElimination/geps.ll @@ -516,7 +516,7 @@ if.end: ; preds = %entry } ; Test which requires decomposing GEP %ptr, SHL(). -define void @test.ult.gep.shl(i32* readonly %src, i32* readnone %max, i32 %idx, i32 %j) { +define void @test.ult.gep.shl(i32* readonly %src, i32* readnone %max, i32 %idx) { ; CHECK-LABEL: @test.ult.gep.shl( ; CHECK-NEXT: check.0.min: ; CHECK-NEXT: [[ADD_10:%.*]] = getelementptr inbounds i32, i32* [[SRC:%.*]], i32 10 @@ -646,4 +646,57 @@ check.max: ; preds = %check.0.min ret void } +; Make sure non-constant shift amounts are handled correctly. 
+define i1 @test.ult.gep.shl.nonconst.zext(i16 %B, i16* readonly %src, i16* readnone %max, i16 %idx, i16 %j) { +; CHECK-LABEL: @test.ult.gep.shl.nonconst.zext( +; CHECK-NEXT: check.0.min: +; CHECK-NEXT: [[ADD_10:%.*]] = getelementptr inbounds i16, i16* [[SRC:%.*]], i16 10 +; CHECK-NEXT: [[C_ADD_10_MAX:%.*]] = icmp ugt i16* [[ADD_10]], [[MAX:%.*]] +; CHECK-NEXT: br i1 [[C_ADD_10_MAX]], label [[EXIT_1:%.*]], label [[CHECK_IDX:%.*]] +; CHECK: exit.1: +; CHECK-NEXT: ret i1 true +; CHECK: check.idx: +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i16 [[IDX:%.*]], 5 +; CHECK-NEXT: br i1 [[CMP]], label [[CHECK_MAX:%.*]], label [[TRAP:%.*]] +; CHECK: check.max: +; CHECK-NEXT: [[IDX_SHL:%.*]] = shl nuw i16 [[IDX]], [[B:%.*]] +; CHECK-NEXT: [[EXT:%.*]] = zext i16 [[IDX_SHL]] to i64 +; CHECK-NEXT: [[ADD_PTR_SHL:%.*]] = getelementptr inbounds i16, i16* [[SRC]], i64 [[EXT]] +; CHECK-NEXT: [[C_MAX:%.*]] = icmp ult i16* [[ADD_PTR_SHL]], [[MAX]] +; CHECK-NEXT: ret i1 [[C_MAX]] +; CHECK: trap: +; CHECK-NEXT: [[IDX_SHL_1:%.*]] = shl nuw i16 [[IDX]], [[B]] +; CHECK-NEXT: [[EXT_1:%.*]] = zext i16 [[IDX_SHL_1]] to i64 +; CHECK-NEXT: [[ADD_PTR_SHL_1:%.*]] = getelementptr inbounds i16, i16* [[SRC]], i64 [[EXT_1]] +; CHECK-NEXT: [[C_MAX_1:%.*]] = icmp ult i16* [[ADD_PTR_SHL_1]], [[MAX]] +; CHECK-NEXT: ret i1 [[C_MAX_1]] +; +check.0.min: + %add.10 = getelementptr inbounds i16, i16* %src, i16 10 + %c.add.10.max = icmp ugt i16* %add.10, %max + br i1 %c.add.10.max, label %exit.1, label %check.idx + +exit.1: + ret i1 true + + +check.idx: ; preds = %check.0.min + %cmp = icmp ult i16 %idx, 5 + br i1 %cmp, label %check.max, label %trap + +check.max: ; preds = %check.idx + %idx.shl = shl nuw i16 %idx, %B + %ext = zext i16 %idx.shl to i64 + %add.ptr.shl = getelementptr inbounds i16, i16* %src, i64 %ext + %c.max = icmp ult i16* %add.ptr.shl, %max + ret i1 %c.max + +trap: ; preds = %check.idx, %check.0.min + %idx.shl.1 = shl nuw i16 %idx, %B + %ext.1 = zext i16 %idx.shl.1 to i64 + %add.ptr.shl.1 = 
getelementptr inbounds i16, i16* %src, i64 %ext.1 + %c.max.1 = icmp ult i16* %add.ptr.shl.1, %max + ret i1 %c.max.1 +} + declare void @use(i1)