diff options
-rw-r--r-- | lib/VMCore/ConstantFold.cpp | 115 | ||||
-rw-r--r-- | lib/VMCore/ConstantFold.h | 11 | ||||
-rw-r--r-- | lib/VMCore/Constants.cpp | 5 | ||||
-rw-r--r-- | test/Assembler/insertextractvalue.ll | 12 |
4 files changed, 127 insertions, 16 deletions
diff --git a/lib/VMCore/ConstantFold.cpp b/lib/VMCore/ConstantFold.cpp index 0913c481ad..069c99ac83 100644 --- a/lib/VMCore/ConstantFold.cpp +++ b/lib/VMCore/ConstantFold.cpp @@ -394,6 +394,7 @@ Constant *llvm::ConstantFoldInsertElementInstruction(const Constant *Val, } return ConstantVector::get(Ops); } + return 0; } @@ -447,18 +448,112 @@ Constant *llvm::ConstantFoldShuffleVectorInstruction(const Constant *V1, return ConstantVector::get(&Result[0], Result.size()); } -Constant *llvm::ConstantFoldExtractValue(const Constant *Agg, - Constant* const *Idxs, - unsigned NumIdx) { - // FIXME: implement some constant folds - return 0; +Constant *llvm::ConstantFoldExtractValueInstruction(const Constant *Agg, + const unsigned *Idxs, + unsigned NumIdx) { + // Base case: no indices, so return the entire value. + if (NumIdx == 0) + return const_cast<Constant *>(Agg); + + if (isa<UndefValue>(Agg)) // ev(undef, x) -> undef + return UndefValue::get(ExtractValueInst::getIndexedType(Agg->getType(), + Idxs, + Idxs + NumIdx)); + + if (isa<ConstantAggregateZero>(Agg)) // ev(0, x) -> 0 + return + Constant::getNullValue(ExtractValueInst::getIndexedType(Agg->getType(), + Idxs, + Idxs + NumIdx)); + + // Otherwise recurse. + return ConstantFoldExtractValueInstruction(Agg->getOperand(*Idxs), + Idxs+1, NumIdx-1); } -Constant *llvm::ConstantFoldInsertValue(const Constant *Agg, - const Constant *Val, - Constant* const *Idxs, - unsigned NumIdx) { - // FIXME: implement some constant folds +Constant *llvm::ConstantFoldInsertValueInstruction(const Constant *Agg, + const Constant *Val, + const unsigned *Idxs, + unsigned NumIdx) { + // Base case: no indices, so replace the entire value. + if (NumIdx == 0) + return const_cast<Constant *>(Val); + + if (isa<UndefValue>(Agg)) { + // Insertion of constant into aggregate undef + // Optimize away insertion of undef + if (isa<UndefValue>(Val)) + return const_cast<Constant*>(Agg); + // Otherwise break the aggregate undef into multiple undefs and do + // the insertion + const CompositeType *AggTy = cast<CompositeType>(Agg->getType()); + unsigned numOps; + if (const ArrayType *AR = dyn_cast<ArrayType>(AggTy)) + numOps = AR->getNumElements(); + else + numOps = cast<StructType>(AggTy)->getNumElements(); + std::vector<Constant*> Ops(numOps); + for (unsigned i = 0; i < numOps; ++i) { + const Type *MemberTy = AggTy->getTypeAtIndex(i); + const Constant *Op = + (*Idxs == i) ? + ConstantFoldInsertValueInstruction(UndefValue::get(MemberTy), + Val, Idxs+1, NumIdx-1) : + UndefValue::get(MemberTy); + Ops[i] = const_cast<Constant*>(Op); + } + if (isa<StructType>(AggTy)) + return ConstantStruct::get(Ops); + else + return ConstantArray::get(cast<ArrayType>(AggTy), Ops); + } + if (isa<ConstantAggregateZero>(Agg)) { + // Insertion of constant into aggregate zero + // Optimize away insertion of zero + if (Val->isNullValue()) + return const_cast<Constant*>(Agg); + // Otherwise break the aggregate zero into multiple zeros and do + // the insertion + const CompositeType *AggTy = cast<CompositeType>(Agg->getType()); + unsigned numOps; + if (const ArrayType *AR = dyn_cast<ArrayType>(AggTy)) + numOps = AR->getNumElements(); + else + numOps = cast<StructType>(AggTy)->getNumElements(); + std::vector<Constant*> Ops(numOps); + for (unsigned i = 0; i < numOps; ++i) { + const Type *MemberTy = AggTy->getTypeAtIndex(i); + const Constant *Op = + (*Idxs == i) ? + ConstantFoldInsertValueInstruction(Constant::getNullValue(MemberTy), + Val, Idxs+1, NumIdx-1) : + Constant::getNullValue(MemberTy); + Ops[i] = const_cast<Constant*>(Op); + } + if (isa<StructType>(AggTy)) + return ConstantStruct::get(Ops); + else + return ConstantArray::get(cast<ArrayType>(AggTy), Ops); + } + if (isa<ConstantStruct>(Agg) || isa<ConstantArray>(Agg)) { + // Insertion of constant into aggregate constant + std::vector<Constant*> Ops(Agg->getNumOperands()); + for (unsigned i = 0; i < Agg->getNumOperands(); ++i) { + const Constant *Op = + (*Idxs == i) ? + ConstantFoldInsertValueInstruction(Agg->getOperand(i), + Val, Idxs+1, NumIdx-1) : + Agg->getOperand(i); + Ops[i] = const_cast<Constant*>(Op); + } + Constant *C; + if (isa<StructType>(Agg->getType())) + C = ConstantStruct::get(Ops); + else + C = ConstantArray::get(cast<ArrayType>(Agg->getType()), Ops); + return C; + } + return 0; } diff --git a/lib/VMCore/ConstantFold.h b/lib/VMCore/ConstantFold.h index bfa6f289d9..fddee23765 100644 --- a/lib/VMCore/ConstantFold.h +++ b/lib/VMCore/ConstantFold.h @@ -41,10 +41,13 @@ namespace llvm { Constant *ConstantFoldShuffleVectorInstruction(const Constant *V1, const Constant *V2, const Constant *Mask); - Constant *ConstantFoldExtractValue(const Constant *Agg, - Constant* const *Idxs, unsigned NumIdx); - Constant *ConstantFoldInsertValue(const Constant *Agg, const Constant *Val, - Constant* const *Idxs, unsigned NumIdx); + Constant *ConstantFoldExtractValueInstruction(const Constant *Agg, + const unsigned *Idxs, + unsigned NumIdx); + Constant *ConstantFoldInsertValueInstruction(const Constant *Agg, + const Constant *Val, + const unsigned* Idxs, + unsigned NumIdx); Constant *ConstantFoldBinaryInstruction(unsigned Opcode, const Constant *V1, const Constant *V2); Constant *ConstantFoldCompareInstruction(unsigned short predicate, diff --git a/lib/VMCore/Constants.cpp b/lib/VMCore/Constants.cpp index 9c10e75c8d..530f7ba61a 100644 --- a/lib/VMCore/Constants.cpp +++ b/lib/VMCore/Constants.cpp @@ -2305,9 +2305,10 @@ Constant *ConstantExpr::getInsertValueTy(const Type *ReqTy, Constant *Agg, "insertvalue indices invalid!"); assert(Agg->getType() == ReqTy && "insertvalue type invalid!"); - assert(Agg->getType()->isFirstClassType() && "Non-first-class type for constant InsertValue expression"); + if (Constant *FC = ConstantFoldInsertValueInstruction(Agg, Val, Idxs, NumIdx)) + return FC; // Fold a few common cases... // Look up the constant in the table first to ensure uniqueness std::vector<Constant*> ArgVec; ArgVec.push_back(Agg); @@ -2336,6 +2337,8 @@ Constant *ConstantExpr::getExtractValueTy(const Type *ReqTy, Constant *Agg, "extractvalue indices invalid!"); assert(Agg->getType()->isFirstClassType() && "Non-first-class type for constant extractvalue expression"); + if (Constant *FC = ConstantFoldExtractValueInstruction(Agg, Idxs, NumIdx)) + return FC; // Fold a few common cases... // Look up the constant in the table first to ensure uniqueness std::vector<Constant*> ArgVec; ArgVec.push_back(Agg); diff --git a/test/Assembler/insertextractvalue.ll b/test/Assembler/insertextractvalue.ll index bdd0932a1d..0da0b77b72 100644 --- a/test/Assembler/insertextractvalue.ll +++ b/test/Assembler/insertextractvalue.ll @@ -1,4 +1,6 @@ -; RUN: llvm-as < %s +; RUN: llvm-as < %s | llvm-dis > %t +; RUN: grep insertvalue %t | count 1 +; RUN: grep extractvalue %t | count 1 define float @foo({{i32},{float, double}}* %p) { %t = load {{i32},{float, double}}* %p @@ -11,3 +13,11 @@ define float @bar({{i32},{float, double}}* %p) { store {{i32},{float, double}} insertvalue ({{i32},{float, double}}{{i32}{i32 4},{float, double}{float 4.0, double 5.0}}, double 20.0, 1, 1), {{i32},{float, double}}* %p ret float extractvalue ({{i32},{float, double}}{{i32}{i32 3},{float, double}{float 7.0, double 9.0}}, 1, 0) } +define float @car({{i32},{float, double}}* %p) { + store {{i32},{float, double}} insertvalue ({{i32},{float, double}} undef, double 20.0, 1, 1), {{i32},{float, double}}* %p + ret float extractvalue ({{i32},{float, double}} undef, 1, 0) +} +define float @dar({{i32},{float, double}}* %p) { + store {{i32},{float, double}} insertvalue ({{i32},{float, double}} zeroinitializer, double 20.0, 1, 1), {{i32},{float, double}}* %p + ret float extractvalue ({{i32},{float, double}} zeroinitializer, 1, 0) +} |