aboutsummaryrefslogtreecommitdiff
path: root/lib/Transforms
diff options
context:
space:
mode:
authorNick Lewycky <nicholas@mxc.ca>2008-03-06 06:48:30 +0000
committerNick Lewycky <nicholas@mxc.ca>2008-03-06 06:48:30 +0000
commitc1a2a612019ea1c764f3ccb5959104aea3d4df2f (patch)
tree76a41bfda9334dfc749c6ac1ad1ff2a370950834 /lib/Transforms
parent4cb8bd8effdc999128d9ab82e1b2fe860b01c556 (diff)
Don't try to simplify urem and srem using arithmetic rules that don't work
under modulo (overflow). Fixes PR1933. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@47987 91177308-0d34-0410-b5e6-96231b3b80d8
Diffstat (limited to 'lib/Transforms')
-rw-r--r--lib/Transforms/Scalar/InstructionCombining.cpp139
1 files changed, 96 insertions, 43 deletions
diff --git a/lib/Transforms/Scalar/InstructionCombining.cpp b/lib/Transforms/Scalar/InstructionCombining.cpp
index 1000ba6036..8e99dcc7db 100644
--- a/lib/Transforms/Scalar/InstructionCombining.cpp
+++ b/lib/Transforms/Scalar/InstructionCombining.cpp
@@ -834,6 +834,49 @@ static void ComputeMaskedBits(Value *V, const APInt &Mask, APInt& KnownZero,
return;
}
break;
+ case Instruction::SRem:
+ if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) {
+ APInt RA = Rem->getValue();
+ if (RA.isPowerOf2() || (-RA).isPowerOf2()) {
+ APInt LowBits = RA.isStrictlyPositive() ? ((RA - 1) | RA) : ~RA;
+ APInt Mask2 = LowBits | APInt::getSignBit(BitWidth);
+ ComputeMaskedBits(I->getOperand(0), Mask2,KnownZero2,KnownOne2,Depth+1);
+
+ // The sign of a remainder is equal to the sign of the first
+ // operand (zero being positive).
+ if (KnownZero2[BitWidth-1] || ((KnownZero2 & LowBits) == LowBits))
+ KnownZero2 |= ~LowBits;
+ else if (KnownOne2[BitWidth-1])
+ KnownOne2 |= ~LowBits;
+
+ KnownZero |= KnownZero2 & Mask;
+ KnownOne |= KnownOne2 & Mask;
+
+ assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?");
+ }
+ }
+ break;
+ case Instruction::URem:
+ if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) {
+ APInt RA = Rem->getValue();
+ if (RA.isStrictlyPositive() && RA.isPowerOf2()) {
+ APInt LowBits = (RA - 1) | RA;
+ APInt Mask2 = LowBits & Mask;
+ KnownZero |= ~LowBits & Mask;
+ ComputeMaskedBits(I->getOperand(0), Mask2, KnownZero, KnownOne,Depth+1);
+ assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?");
+ }
+ } else {
+ // Since the result is less than or equal to RHS, any leading zero bits
+ // in RHS must also exist in the result.
+ APInt AllOnes = APInt::getAllOnesValue(BitWidth);
+ ComputeMaskedBits(I->getOperand(1), AllOnes, KnownZero2, KnownOne2, Depth+1);
+
+ uint32_t Leaders = KnownZero2.countLeadingOnes();
+ KnownZero |= APInt::getHighBitsSet(BitWidth, Leaders) & Mask;
+ assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?");
+ }
+ break;
}
}
@@ -1418,6 +1461,52 @@ bool InstCombiner::SimplifyDemandedBits(Value *V, APInt DemandedMask,
}
}
break;
+ case Instruction::SRem:
+ if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) {
+ APInt RA = Rem->getValue();
+ if (RA.isPowerOf2() || (-RA).isPowerOf2()) {
+ APInt LowBits = RA.isStrictlyPositive() ? (RA - 1) | RA : ~RA;
+ APInt Mask2 = LowBits | APInt::getSignBit(BitWidth);
+ if (SimplifyDemandedBits(I->getOperand(0), Mask2,
+ LHSKnownZero, LHSKnownOne, Depth+1))
+ return true;
+
+ if (LHSKnownZero[BitWidth-1] || ((LHSKnownZero & LowBits) == LowBits))
+ LHSKnownZero |= ~LowBits;
+ else if (LHSKnownOne[BitWidth-1])
+ LHSKnownOne |= ~LowBits;
+
+ KnownZero |= LHSKnownZero & DemandedMask;
+ KnownOne |= LHSKnownOne & DemandedMask;
+
+ assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?");
+ }
+ }
+ break;
+ case Instruction::URem:
+ if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) {
+ APInt RA = Rem->getValue();
+ if (RA.isPowerOf2()) {
+ APInt LowBits = (RA - 1) | RA;
+ APInt Mask2 = LowBits & DemandedMask;
+ KnownZero |= ~LowBits & DemandedMask;
+ if (SimplifyDemandedBits(I->getOperand(0), Mask2,
+ KnownZero, KnownOne, Depth+1))
+ return true;
+
+ assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?");
+ }
+ } else {
+ APInt KnownZero2(BitWidth, 0), KnownOne2(BitWidth, 0);
+ APInt AllOnes = APInt::getAllOnesValue(BitWidth);
+ if (SimplifyDemandedBits(I->getOperand(1), AllOnes,
+ KnownZero2, KnownOne2, Depth+1))
+ return true;
+
+ uint32_t Leaders = KnownZero2.countLeadingOnes();
+ KnownZero |= APInt::getHighBitsSet(BitWidth, Leaders) & DemandedMask;
+ }
+ break;
}
// If the client is only demanding bits that we know, return the known
@@ -2780,46 +2869,6 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) {
return commonDivTransforms(I);
}
-/// GetFactor - If we can prove that the specified value is at least a multiple
-/// of some factor, return that factor.
-static Constant *GetFactor(Value *V) {
- if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
- return CI;
-
- // Unless we can be tricky, we know this is a multiple of 1.
- Constant *Result = ConstantInt::get(V->getType(), 1);
-
- Instruction *I = dyn_cast<Instruction>(V);
- if (!I) return Result;
-
- if (I->getOpcode() == Instruction::Mul) {
- // Handle multiplies by a constant, etc.
- return ConstantExpr::getMul(GetFactor(I->getOperand(0)),
- GetFactor(I->getOperand(1)));
- } else if (I->getOpcode() == Instruction::Shl) {
- // (X<<C) -> X * (1 << C)
- if (Constant *ShRHS = dyn_cast<Constant>(I->getOperand(1))) {
- ShRHS = ConstantExpr::getShl(Result, ShRHS);
- return ConstantExpr::getMul(GetFactor(I->getOperand(0)), ShRHS);
- }
- } else if (I->getOpcode() == Instruction::And) {
- if (ConstantInt *RHS = dyn_cast<ConstantInt>(I->getOperand(1))) {
- // X & 0xFFF0 is known to be a multiple of 16.
- uint32_t Zeros = RHS->getValue().countTrailingZeros();
- if (Zeros != V->getType()->getPrimitiveSizeInBits())// don't shift by "32"
- return ConstantExpr::getShl(Result,
- ConstantInt::get(Result->getType(), Zeros));
- }
- } else if (CastInst *CI = dyn_cast<CastInst>(I)) {
- // Only handle int->int casts.
- if (!CI->isIntegerCast())
- return Result;
- Value *Op = CI->getOperand(0);
- return ConstantExpr::getCast(CI->getOpcode(), GetFactor(Op), V->getType());
- }
- return Result;
-}
-
/// This function implements the transforms on rem instructions that work
/// regardless of the kind of rem instruction it is (urem, srem, or frem). It
/// is used by the visitors to those instructions.
@@ -2901,9 +2950,13 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) {
if (Instruction *NV = FoldOpIntoPhi(I))
return NV;
}
- // (X * C1) % C2 --> 0 iff C1 % C2 == 0
- if (ConstantExpr::getSRem(GetFactor(Op0I), RHS)->isNullValue())
- return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType()));
+
+ // See if we can fold away this rem instruction.
+ uint32_t BitWidth = cast<IntegerType>(I.getType())->getBitWidth();
+ APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
+ if (SimplifyDemandedBits(&I, APInt::getAllOnesValue(BitWidth),
+ KnownZero, KnownOne))
+ return &I;
}
}