aboutsummaryrefslogtreecommitdiff
path: root/lib/Transforms/NaCl/ExpandMulWithOverflow.cpp
blob: 171dda1f09290fd06a04c45caf73503b73b2c402 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
//===- ExpandMulWithOverflow.cpp - Expand out usage of umul.with.overflow--===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// The llvm.*.with.overflow.*() intrinsics are awkward for PNaCl
// support because they return structs, and we want to omit struct
// types from IR in PNaCl's stable ABI.
//
// However, llvm.umul.with.overflow.*() is used by Clang to implement
// an overflow check for C++'s new[] operator.  This pass expands out
// these uses so that PNaCl does not have to support
// umul.with.overflow as part of PNaCl's stable ABI.
//
// This pass only handles multiplication by a constant, which is the
// only case of umul.with.overflow that is currently generated by
// Clang (unless '-ftrapv' is passed to Clang).
//
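// X * Const overflows iff X > UINT_MAX / Const, where UINT_MAX is the
// maximum value for the integer type being multiplied and the division
// truncates.  This criterion only applies for Const != 0; X * 0 never
// overflows.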
//
//===----------------------------------------------------------------------===//

#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/NaCl.h"
#include <cassert>

using namespace llvm;

namespace {
  // This is a ModulePass so that the pass can easily iterate over all
  // uses of the intrinsics.
  class ExpandMulWithOverflow : public ModulePass {
  public:
    static char ID; // Pass identification, replacement for typeid
    ExpandMulWithOverflow() : ModulePass(ID) {
      initializeExpandMulWithOverflowPass(*PassRegistry::getPassRegistry());
    }

    virtual bool runOnModule(Module &M);
  };
}

// Storage for the static pass ID declared in the class (used for pass
// identification in place of typeid).
char ExpandMulWithOverflow::ID = 0;
// Registers the pass under the argument name "expand-mul-with-overflow" with
// the given description.  The trailing false, false are presumably the
// "CFG only" / "is analysis" flags of INITIALIZE_PASS -- confirm against
// llvm/PassSupport.h.
INITIALIZE_PASS(ExpandMulWithOverflow, "expand-mul-with-overflow",
                "Expand out uses of llvm.umul.with.overflow intrinsics",
                false, false)

// Returns the largest value representable by an unsigned integer that is
// Bits wide, i.e. 2^Bits - 1, as a uint64_t (0 for Bits == 0).
//
// Precondition: Bits <= 64.  Wider types cannot be represented in the
// result, and shifting a uint64_t by 64 or more bits is undefined behaviour.
static uint64_t UintTypeMax(unsigned Bits) {
  assert(Bits <= 64 &&
         "UintTypeMax: integer types wider than 64 bits are not supported");
  // Avoid doing 1 << 64 because that is undefined on a uint64_t.
  if (Bits == 64)
    return ~(uint64_t) 0;
  return (((uint64_t) 1) << Bits) - 1;
}

// Expands every call to the umul.with.overflow intrinsic on Bits-wide
// integers in M.  Each call becomes a plain "mul" plus an overflow flag, and
// the intrinsic declaration is then deleted.
//
// Only calls where at least one operand is a ConstantInt are supported.
// Every use of a call must be an extractvalue with a single index: 0 (the
// product) or 1 (the overflow flag).  Anything else is reported with
// report_fatal_error.
//
// Returns true if the intrinsic was declared in M (and has now been
// removed); returns false if there was nothing to do.
static bool ExpandForIntSize(Module *M, unsigned Bits) {
  IntegerType *IntTy = IntegerType::get(M->getContext(), Bits);
  SmallVector<Type *, 1> Types;
  Types.push_back(IntTy);
  std::string Name = Intrinsic::getName(Intrinsic::umul_with_overflow, Types);
  // Named IntrinsicFunc so that it does not shadow the llvm::Intrinsic
  // namespace used on the line above.
  Function *IntrinsicFunc = M->getFunction(Name);
  if (!IntrinsicFunc)
    return false;
  for (Value::use_iterator CallIter = IntrinsicFunc->use_begin(),
         E = IntrinsicFunc->use_end(); CallIter != E; ) {
    // Advance the iterator now, because Call is erased at the end of this
    // iteration.
    CallInst *Call = dyn_cast<CallInst>(*CallIter++);
    if (!Call) {
      report_fatal_error("ExpandMulWithOverflow: Taking the address of a "
                         "umul.with.overflow intrinsic is not allowed");
    }
    Value *VariableArg;
    ConstantInt *ConstantArg;
    if (ConstantInt *C = dyn_cast<ConstantInt>(Call->getArgOperand(0))) {
      VariableArg = Call->getArgOperand(1);
      ConstantArg = C;
    } else if (ConstantInt *C = dyn_cast<ConstantInt>(Call->getArgOperand(1))) {
      VariableArg = Call->getArgOperand(0);
      ConstantArg = C;
    } else {
      errs() << "Use: " << *Call << "\n";
      report_fatal_error("ExpandMulWithOverflow: At least one argument of "
                         "umul.with.overflow must be a constant");
    }

    Value *Mul = BinaryOperator::Create(
        Instruction::Mul, VariableArg, ConstantArg,
        Call->getName() + ".mul", Call);

    Value *Overflow;
    if (ConstantArg->isZero()) {
      // X * 0 never overflows.  Handling this case separately also avoids
      // dividing by zero when computing the threshold below.
      Overflow = ConstantInt::getFalse(M->getContext());
    } else {
      // X * C overflows iff X > UINT_MAX / C (truncating unsigned division).
      uint64_t ArgMax = UintTypeMax(Bits) / ConstantArg->getZExtValue();
      Overflow = new ICmpInst(
          Call, CmpInst::ICMP_UGT, VariableArg, ConstantInt::get(IntTy, ArgMax),
          Call->getName() + ".overflow");
    }

    // Rewrite each extractvalue of the {product, overflow} result to use
    // the corresponding scalar computed above.
    for (Value::use_iterator FieldIter = Call->use_begin(),
           E = Call->use_end(); FieldIter != E; ) {
      User *U = *FieldIter++;
      ExtractValueInst *Field = dyn_cast<ExtractValueInst>(U);
      if (!Field) {
        errs() << "Use: " << *U << "\n";
        report_fatal_error(
            "ExpandMulWithOverflow: Use is not an extractvalue");
      }
      if (Field->getNumIndices() != 1) {
        report_fatal_error("ExpandMulWithOverflow: Unexpected indices");
      }
      unsigned Index = Field->getIndices()[0];
      if (Index == 0) {
        Field->replaceAllUsesWith(Mul);
      } else if (Index == 1) {
        Field->replaceAllUsesWith(Overflow);
      } else {
        report_fatal_error("ExpandMulWithOverflow: Unexpected index");
      }
      Field->eraseFromParent();
    }
    // All uses of Call were extractvalues, which have been erased above.
    Call->eraseFromParent();
  }
  IntrinsicFunc->eraseFromParent();
  return true;
}

bool ExpandMulWithOverflow::runOnModule(Module &M) {
  // Integer widths to expand, widest first.  Every width is processed even
  // after an earlier one has reported a change.
  static const unsigned Widths[] = { 64, 32, 16, 8 };
  const unsigned NumWidths = sizeof(Widths) / sizeof(Widths[0]);

  bool Changed = false;
  for (unsigned I = 0; I != NumWidths; ++I) {
    if (ExpandForIntSize(&M, Widths[I]))
      Changed = true;
  }
  return Changed;
}

// Factory for the pass.  Returns a newly allocated pass object; presumably
// the caller (e.g. a PassManager) takes ownership -- confirm at call sites.
ModulePass *llvm::createExpandMulWithOverflowPass() {
  ExpandMulWithOverflow *P = new ExpandMulWithOverflow;
  return P;
}