diff options
-rw-r--r-- | lib/Analysis/ScalarEvolution.cpp | 5 | ||||
-rw-r--r-- | lib/Transforms/Scalar/InstructionCombining.cpp | 144 | ||||
-rw-r--r-- | test/Transforms/InstCombine/adjust-for-sminmax.ll | 85 | ||||
-rw-r--r-- | test/Transforms/InstCombine/preserve-sminmax.ll | 22 |
4 files changed, 222 insertions, 34 deletions
diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 60980d26aa..4e85e90e9e 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -1831,11 +1831,6 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) { case Instruction::Select: // This could be a smax or umax that was lowered earlier. // Try to recover it. - // - // FIXME: This doesn't recognize code like this: - // %t = icmp sgt i32 %n, -1 - // %max = select i1 %t, i32 %n, i32 0 - // if (ICmpInst *ICI = dyn_cast<ICmpInst>(U->getOperand(0))) { Value *LHS = ICI->getOperand(0); Value *RHS = ICI->getOperand(1); diff --git a/lib/Transforms/Scalar/InstructionCombining.cpp b/lib/Transforms/Scalar/InstructionCombining.cpp index 1cb4fa2c21..e68e646b80 100644 --- a/lib/Transforms/Scalar/InstructionCombining.cpp +++ b/lib/Transforms/Scalar/InstructionCombining.cpp @@ -219,7 +219,8 @@ namespace { Instruction *visitBitCast(BitCastInst &CI); Instruction *FoldSelectOpOp(SelectInst &SI, Instruction *TI, Instruction *FI); - Instruction *visitSelectInst(SelectInst &CI); + Instruction *visitSelectInst(SelectInst &SI); + Instruction *visitSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI); Instruction *visitCallInst(CallInst &CI); Instruction *visitInvokeInst(InvokeInst &II); Instruction *visitPHINode(PHINode &PN); @@ -5312,8 +5313,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { } } - // See if we are doing a comparison between a constant and an instruction that - // can be folded into the comparison. + // See if we are doing a comparison with a constant. if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { Value *A, *B; @@ -5324,9 +5324,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { return new ICmpInst(I.getPredicate(), A, B); } - // If we have a icmp le or icmp ge instruction, turn it into the appropriate - // icmp lt or icmp gt instruction. This allows us to rely on them being - // folded in the code below. 
+ // If we have an icmp le or icmp ge instruction, turn it into the + // appropriate icmp lt or icmp gt instruction. This allows us to rely on + // them being folded in the code below. switch (I.getPredicate()) { default: break; case ICmpInst::ICMP_ULE: @@ -5446,7 +5446,24 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { return new ICmpInst(ICmpInst::ICMP_EQ, Op0, AddOne(CI)); break; } - + } + + // Test if the ICmpInst instruction is used exclusively by a select as + // part of a minimum or maximum operation. If so, refrain from doing + // any other folding. This helps out other analyses which understand + // non-obfuscated minimum and maximum idioms, such as ScalarEvolution + // and CodeGen. And in this case, at least one of the comparison + // operands has at least one user besides the compare (the select), + // which would often largely negate the benefit of folding anyway. + if (I.hasOneUse()) + if (SelectInst *SI = dyn_cast<SelectInst>(*I.use_begin())) + if ((SI->getOperand(1) == Op0 && SI->getOperand(2) == Op1) || + (SI->getOperand(2) == Op0 && SI->getOperand(1) == Op1)) + return 0; + + // See if we are doing a comparison between a constant and an instruction that + // can be folded into the comparison. + if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { // Since the RHS is a ConstantInt (CI), if the left hand side is an // instruction, see if that instruction also has constants so that the // instruction can be folded into the icmp @@ -8181,6 +8198,91 @@ Instruction *InstCombiner::FoldSelectOpOp(SelectInst &SI, Instruction *TI, return 0; } +/// visitSelectInstWithICmp - Visit a SelectInst that has an +/// ICmpInst as its first operand. 
+/// +Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, + ICmpInst *ICI) { + bool Changed = false; + ICmpInst::Predicate Pred = ICI->getPredicate(); + Value *CmpLHS = ICI->getOperand(0); + Value *CmpRHS = ICI->getOperand(1); + Value *TrueVal = SI.getTrueValue(); + Value *FalseVal = SI.getFalseValue(); + + // Check cases where the comparison is with a constant that + // can be adjusted to fit the min/max idiom. We may edit ICI in + // place here, so make sure the select is the only user. + if (ICI->hasOneUse()) + if (ConstantInt *CI = dyn_cast<ConstantInt>(CmpRHS)) + switch (Pred) { + default: break; + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_SLT: { + // X < MIN ? T : F --> F + if (CI->isMinValue(Pred == ICmpInst::ICMP_SLT)) + return ReplaceInstUsesWith(SI, FalseVal); + // X < C ? X : C-1 --> X > C-1 ? C-1 : X + Constant *AdjustedRHS = SubOne(CI); + if ((CmpLHS == TrueVal && AdjustedRHS == FalseVal) || + (CmpLHS == FalseVal && AdjustedRHS == TrueVal)) { + Pred = ICmpInst::getSwappedPredicate(Pred); + CmpRHS = AdjustedRHS; + std::swap(FalseVal, TrueVal); + ICI->setPredicate(Pred); + ICI->setOperand(1, CmpRHS); + SI.setOperand(1, TrueVal); + SI.setOperand(2, FalseVal); + Changed = true; + } + break; + } + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_SGT: { + // X > MAX ? T : F --> F + if (CI->isMaxValue(Pred == ICmpInst::ICMP_SGT)) + return ReplaceInstUsesWith(SI, FalseVal); + // X > C ? X : C+1 --> X < C+1 ? C+1 : X + Constant *AdjustedRHS = AddOne(CI); + if ((CmpLHS == TrueVal && AdjustedRHS == FalseVal) || + (CmpLHS == FalseVal && AdjustedRHS == TrueVal)) { + Pred = ICmpInst::getSwappedPredicate(Pred); + CmpRHS = AdjustedRHS; + std::swap(FalseVal, TrueVal); + ICI->setPredicate(Pred); + ICI->setOperand(1, CmpRHS); + SI.setOperand(1, TrueVal); + SI.setOperand(2, FalseVal); + Changed = true; + } + break; + } + } + + if (CmpLHS == TrueVal && CmpRHS == FalseVal) { + // Transform (X == Y) ? 
X : Y -> Y + if (Pred == ICmpInst::ICMP_EQ) + return ReplaceInstUsesWith(SI, FalseVal); + // Transform (X != Y) ? X : Y -> X + if (Pred == ICmpInst::ICMP_NE) + return ReplaceInstUsesWith(SI, TrueVal); + /// NOTE: if we wanted to, this is where to detect integer MIN/MAX + + } else if (CmpLHS == FalseVal && CmpRHS == TrueVal) { + // Transform (X == Y) ? Y : X -> X + if (Pred == ICmpInst::ICMP_EQ) + return ReplaceInstUsesWith(SI, FalseVal); + // Transform (X != Y) ? Y : X -> Y + if (Pred == ICmpInst::ICMP_NE) + return ReplaceInstUsesWith(SI, TrueVal); + /// NOTE: if we wanted to, this is where to detect integer MIN/MAX + } + + /// NOTE: if we wanted to, this is where to detect integer ABS + + return Changed ? &SI : 0; +} + Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); @@ -8329,7 +8431,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // Transform (X != Y) ? X : Y -> X if (FCI->getPredicate() == FCmpInst::FCMP_ONE) return ReplaceInstUsesWith(SI, TrueVal); - // NOTE: if we wanted to, this is where to detect MIN/MAX/ABS/etc. + // NOTE: if we wanted to, this is where to detect MIN/MAX } else if (FCI->getOperand(0) == FalseVal && FCI->getOperand(1) == TrueVal){ // Transform (X == Y) ? Y : X -> X @@ -8347,31 +8449,15 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // Transform (X != Y) ? Y : X -> Y if (FCI->getPredicate() == FCmpInst::FCMP_ONE) return ReplaceInstUsesWith(SI, TrueVal); - // NOTE: if we wanted to, this is where to detect MIN/MAX/ABS/etc. + // NOTE: if we wanted to, this is where to detect MIN/MAX } + // NOTE: if we wanted to, this is where to detect ABS } // See if we are selecting two values based on a comparison of the two values. - if (ICmpInst *ICI = dyn_cast<ICmpInst>(CondVal)) { - if (ICI->getOperand(0) == TrueVal && ICI->getOperand(1) == FalseVal) { - // Transform (X == Y) ? 
X : Y -> Y - if (ICI->getPredicate() == ICmpInst::ICMP_EQ) - return ReplaceInstUsesWith(SI, FalseVal); - // Transform (X != Y) ? X : Y -> X - if (ICI->getPredicate() == ICmpInst::ICMP_NE) - return ReplaceInstUsesWith(SI, TrueVal); - // NOTE: if we wanted to, this is where to detect MIN/MAX/ABS/etc. - - } else if (ICI->getOperand(0) == FalseVal && ICI->getOperand(1) == TrueVal){ - // Transform (X == Y) ? Y : X -> X - if (ICI->getPredicate() == ICmpInst::ICMP_EQ) - return ReplaceInstUsesWith(SI, FalseVal); - // Transform (X != Y) ? Y : X -> Y - if (ICI->getPredicate() == ICmpInst::ICMP_NE) - return ReplaceInstUsesWith(SI, TrueVal); - // NOTE: if we wanted to, this is where to detect MIN/MAX/ABS/etc. - } - } + if (ICmpInst *ICI = dyn_cast<ICmpInst>(CondVal)) + if (Instruction *Result = visitSelectInstWithICmp(SI, ICI)) + return Result; if (Instruction *TI = dyn_cast<Instruction>(TrueVal)) if (Instruction *FI = dyn_cast<Instruction>(FalseVal)) diff --git a/test/Transforms/InstCombine/adjust-for-sminmax.ll b/test/Transforms/InstCombine/adjust-for-sminmax.ll new file mode 100644 index 0000000000..9328ad3649 --- /dev/null +++ b/test/Transforms/InstCombine/adjust-for-sminmax.ll @@ -0,0 +1,85 @@ +; RUN: llvm-as < %s | opt -instcombine | llvm-dis | grep {icmp s\[lg\]t i32 %n, 0} | count 16 + +; Instcombine should recognize that this code can be adjusted +; to fit the canonical smax/smin pattern. 
+ +define i32 @floor_a(i32 %n) { + %t = icmp sgt i32 %n, -1 + %m = select i1 %t, i32 %n, i32 0 + ret i32 %m +} +define i32 @ceil_a(i32 %n) { + %t = icmp slt i32 %n, 1 + %m = select i1 %t, i32 %n, i32 0 + ret i32 %m +} +define i32 @floor_b(i32 %n) { + %t = icmp sgt i32 %n, 0 + %m = select i1 %t, i32 %n, i32 0 + ret i32 %m +} +define i32 @ceil_b(i32 %n) { + %t = icmp slt i32 %n, 0 + %m = select i1 %t, i32 %n, i32 0 + ret i32 %m +} +define i32 @floor_c(i32 %n) { + %t = icmp sge i32 %n, 0 + %m = select i1 %t, i32 %n, i32 0 + ret i32 %m +} +define i32 @ceil_c(i32 %n) { + %t = icmp sle i32 %n, 0 + %m = select i1 %t, i32 %n, i32 0 + ret i32 %m +} +define i32 @floor_d(i32 %n) { + %t = icmp sge i32 %n, 1 + %m = select i1 %t, i32 %n, i32 0 + ret i32 %m +} +define i32 @ceil_d(i32 %n) { + %t = icmp sle i32 %n, -1 + %m = select i1 %t, i32 %n, i32 0 + ret i32 %m +} +define i32 @floor_e(i32 %n) { + %t = icmp sgt i32 %n, -1 + %m = select i1 %t, i32 %n, i32 0 + ret i32 %m +} +define i32 @ceil_e(i32 %n) { + %t = icmp slt i32 %n, 1 + %m = select i1 %t, i32 %n, i32 0 + ret i32 %m +} +define i32 @floor_f(i32 %n) { + %t = icmp sgt i32 %n, 0 + %m = select i1 %t, i32 %n, i32 0 + ret i32 %m +} +define i32 @ceil_f(i32 %n) { + %t = icmp slt i32 %n, 0 + %m = select i1 %t, i32 %n, i32 0 + ret i32 %m +} +define i32 @floor_g(i32 %n) { + %t = icmp sge i32 %n, 0 + %m = select i1 %t, i32 %n, i32 0 + ret i32 %m +} +define i32 @ceil_g(i32 %n) { + %t = icmp sle i32 %n, 0 + %m = select i1 %t, i32 %n, i32 0 + ret i32 %m +} +define i32 @floor_h(i32 %n) { + %t = icmp sge i32 %n, 1 + %m = select i1 %t, i32 %n, i32 0 + ret i32 %m +} +define i32 @ceil_h(i32 %n) { + %t = icmp sle i32 %n, -1 + %m = select i1 %t, i32 %n, i32 0 + ret i32 %m +} diff --git a/test/Transforms/InstCombine/preserve-sminmax.ll b/test/Transforms/InstCombine/preserve-sminmax.ll new file mode 100644 index 0000000000..24fb7dabe3 --- /dev/null +++ b/test/Transforms/InstCombine/preserve-sminmax.ll @@ -0,0 +1,22 @@ +; RUN: llvm-as < %s | opt 
-instcombine | llvm-dis | grep { i32 \[%\]sd, \[\[:alnum:\]\]* \\?1\\>} | count 4 + +; Instcombine normally would fold the sdiv into the comparison, +; making "icmp slt i32 %h, 2", but in this case the sdiv has +; another use, so it wouldn't be a big win, and it would also +; obfuscate an otherwise obvious smax pattern to the point where +; other analyses wouldn't recognize it. + +define i32 @foo(i32 %h) { + %sd = sdiv i32 %h, 2 + %t = icmp slt i32 %sd, 1 + %r = select i1 %t, i32 %sd, i32 1 + ret i32 %r +} + +define i32 @bar(i32 %h) { + %sd = sdiv i32 %h, 2 + %t = icmp sgt i32 %sd, 1 + %r = select i1 %t, i32 %sd, i32 1 + ret i32 %r +} + |