diff options
Diffstat (limited to 'lib/Transforms/Scalar/PredicateSimplifier.cpp')
-rw-r--r-- | lib/Transforms/Scalar/PredicateSimplifier.cpp | 122 |
1 files changed, 90 insertions, 32 deletions
diff --git a/lib/Transforms/Scalar/PredicateSimplifier.cpp b/lib/Transforms/Scalar/PredicateSimplifier.cpp index a11d5bdfcd..a01ae7fd73 100644 --- a/lib/Transforms/Scalar/PredicateSimplifier.cpp +++ b/lib/Transforms/Scalar/PredicateSimplifier.cpp @@ -52,7 +52,7 @@ // responsible for analyzing the variable and seeing what new inferences // can be made from each property. For example: // -// %P = setne int* %ptr, null +// %P = icmp ne int* %ptr, null // %a = and bool %P, %Q // br bool %a label %cond_true, label %cond_false // @@ -140,14 +140,14 @@ namespace { static bool validPredicate(LatticeVal LV) { switch (LV) { - case GT: case GE: case LT: case LE: case NE: - case SGTULT: case SGT: case SGEULE: - case SLTUGT: case SLT: case SLEUGE: - case ULT: case UGT: - case SLE: case SGE: case ULE: case UGE: - return true; - default: - return false; + case GT: case GE: case LT: case LE: case NE: + case SGTULT: case SGT: case SGEULE: + case SLTUGT: case SLT: case SLEUGE: + case ULT: case UGT: + case SLE: case SGE: case ULE: case UGE: + return true; + default: + return false; } } @@ -415,7 +415,7 @@ namespace { if (iULT == end || iUGT == end) { if (iULT == end) iSLT = last; else iSLT = iULT; if (iUGT == end) iSGT = begin; else iSGT = iUGT; - } else if (iULT->first->getSExtValue() < 0) { + } else if (iULT->first->getSExtValue() < 0) { assert(iUGT->first->getSExtValue() >= 0 && "Bad sign comparison."); iSGT = iUGT; iSLT = iULT; @@ -424,7 +424,7 @@ namespace { iUGT->first->getSExtValue() < 0 && "Bad sign comparison."); iSGT = iULT; iSLT = iUGT; - } + } if (iSGT != end && iSGT->first->getSExtValue() < CI->getSExtValue()) iSGT = end; @@ -436,13 +436,13 @@ namespace { if (iSLT == end || begin->first->getSExtValue() > iSLT->first->getSExtValue()) iSLT = begin; - } + } if (last != end) { if (last->first->getSExtValue() > CI->getSExtValue()) if (iSGT == end || last->first->getSExtValue() < iSGT->first->getSExtValue()) iSGT = last; - } + } } if (iULT != end) addInequality(iULT->second, index, TreeRoot, ULT); @@ -868,7 +868,7 @@ namespace { if (n1) assert(V1 == IG.node(n1)->getValue() && "Value isn't canonical."); if (n2) assert(V2 == IG.node(n2)->getValue() && "Value isn't canonical."); - if (compare(V2, V1)) { std::swap(V1, V2); std::swap(n1, n2); } + assert(!compare(V2, V1) && "Please order parameters to makeEqual."); assert(!isa<Constant>(V2) && "Tried to remove a constant."); @@ -1398,10 +1398,22 @@ namespace { DEBUG(IG.dump()); - // TODO: actually check the constants and add to UB. - if (isa<Constant>(O.LHS) && isa<Constant>(O.RHS)) { - WorkList.pop_front(); - continue; + // If they're both Constant, skip it. Check for contradiction and mark + // the BB as unreachable if so. + if (Constant *CI_L = dyn_cast<Constant>(O.LHS)) { + if (Constant *CI_R = dyn_cast<Constant>(O.RHS)) { + if (ConstantExpr::getCompare(O.Op, CI_L, CI_R) == + ConstantInt::getFalse()) + UB.mark(TopBB); + + WorkList.pop_front(); + continue; + } + } + + if (compare(O.RHS, O.LHS)) { + std::swap(O.LHS, O.RHS); + O.Op = ICmpInst::getSwappedPredicate(O.Op); } if (O.Op == ICmpInst::ICMP_EQ) { @@ -1416,7 +1428,7 @@ namespace { UB.mark(TopBB); } else { if (isRelatedBy(O.LHS, O.RHS, ICmpInst::getInversePredicate(O.Op))){ - DOUT << "inequality contradiction!\n"; + UB.mark(TopBB); WorkList.pop_front(); continue; } @@ -1438,6 +1450,31 @@ namespace { continue; } + // Generalize %x u> -10 to %x > -10. + if (ConstantInt *CI = dyn_cast<ConstantInt>(O.RHS)) { + // xform doesn't apply to i1 + if (CI->getType()->getBitWidth() > 1) { + if (LV == SLT && CI->getSExtValue() < 0) { + // i8 %x s< -5 implies %x < -5 and %x u> 127 + + const IntegerType *Ty = CI->getType(); + LV = LT; + add(O.LHS, ConstantInt::get(Ty, Ty->getBitMask() >> 1), + ICmpInst::ICMP_UGT); + } else if (LV == SGT && CI->getSExtValue() >= 0) { + // i8 %x s> 5 implies %x > 5 and %x u< 128 + + const IntegerType *Ty = CI->getType(); + LV = LT; + add(O.LHS, ConstantInt::get(Ty, 1 << Ty->getBitWidth()), + ICmpInst::ICMP_ULT); + } else if (CI->getSExtValue() >= 0) { + if (LV == ULT || LV == SLT) LV = LT; + if (LV == UGT || LV == SGT) LV = GT; + } + } + } + IG.addInequality(n1, n2, Top, LV); if (Instruction *I1 = dyn_cast<Instruction>(O.LHS)) { @@ -1531,6 +1568,9 @@ namespace { void visitLoadInst(LoadInst &LI); void visitStoreInst(StoreInst &SI); + void visitSExtInst(SExtInst &SI); + void visitZExtInst(ZExtInst &ZI); + void visitBinaryOperator(BinaryOperator &BO); }; @@ -1730,23 +1770,41 @@ namespace { VRP.solve(); } + void PredicateSimplifier::Forwards::visitSExtInst(SExtInst &SI) { + VRPSolver VRP(IG, UB, PS->Forest, PS->modified, &SI); + const IntegerType *Ty = cast<IntegerType>(SI.getSrcTy()); + VRP.add(ConstantInt::get(SI.getDestTy(), ~(Ty->getBitMask() >> 1)), + &SI, ICmpInst::ICMP_SLE); + VRP.add(ConstantInt::get(SI.getDestTy(), (1 << (Ty->getBitWidth()-1)) - 1), + &SI, ICmpInst::ICMP_SGE); + VRP.solve(); + } + + void PredicateSimplifier::Forwards::visitZExtInst(ZExtInst &ZI) { + VRPSolver VRP(IG, UB, PS->Forest, PS->modified, &ZI); + const IntegerType *Ty = cast<IntegerType>(ZI.getSrcTy()); + VRP.add(ConstantInt::get(ZI.getDestTy(), Ty->getBitMask()), + &ZI, ICmpInst::ICMP_UGE); + VRP.solve(); + } + void PredicateSimplifier::Forwards::visitBinaryOperator(BinaryOperator &BO) { Instruction::BinaryOps ops = BO.getOpcode(); switch (ops) { - case Instruction::URem: - case Instruction::SRem: - case Instruction::UDiv: - case Instruction::SDiv: { - Value *Divisor = BO.getOperand(1); - VRPSolver VRP(IG, UB, PS->Forest, PS->modified, &BO); - VRP.add(Constant::getNullValue(Divisor->getType()), Divisor, - ICmpInst::ICMP_NE); - VRP.solve(); - break; - } - default: - break; + case Instruction::URem: + case Instruction::SRem: + case Instruction::UDiv: + case Instruction::SDiv: { + Value *Divisor = BO.getOperand(1); + VRPSolver VRP(IG, UB, PS->Forest, PS->modified, &BO); + VRP.add(Constant::getNullValue(Divisor->getType()), Divisor, + ICmpInst::ICMP_NE); + VRP.solve(); + break; + } + default: + break; } } |