diff --git a/lib/Transforms/Instrumentation/BlockProfiling.cpp b/lib/Transforms/Instrumentation/BlockProfiling.cpp index 7b38d70d9e3..8dacc92bc26 100644 --- a/lib/Transforms/Instrumentation/BlockProfiling.cpp +++ b/lib/Transforms/Instrumentation/BlockProfiling.cpp @@ -31,7 +31,7 @@ static void insertInitializationCall(Function *MainFn, const char *FnName, const Type *ArgVTy = PointerType::get(PointerType::get(Type::SByteTy)); const Type *UIntPtr = PointerType::get(Type::UIntTy); Module &M = *MainFn->getParent(); - Function *InitFn = M.getOrInsertFunction(FnName, Type::VoidTy, Type::IntTy, + Function *InitFn = M.getOrInsertFunction(FnName, Type::IntTy, Type::IntTy, ArgVTy, UIntPtr, Type::UIntTy, 0); // This could force argc and argv into programs that wouldn't otherwise have @@ -45,29 +45,6 @@ static void insertInitializationCall(Function *MainFn, const char *FnName, BasicBlock::iterator InsertPos = Entry->begin(); while (isa(InsertPos)) ++InsertPos; - Function::aiterator AI; - switch (MainFn->asize()) { - default: - case 2: - AI = MainFn->abegin(); ++AI; - if (AI->getType() != ArgVTy) { - Args[1] = new CastInst(AI, ArgVTy, "argv.cast", InsertPos); - } else { - Args[1] = AI; - } - - case 1: - AI = MainFn->abegin(); - if (AI->getType() != Type::IntTy) { - Args[0] = new CastInst(AI, Type::IntTy, "argc.cast", InsertPos); - } else { - Args[0] = AI; - } - - case 0: - break; - } - ConstantPointerRef *ArrayCPR = ConstantPointerRef::get(Array); std::vector GEPIndices(2, Constant::getNullValue(Type::LongTy)); Args[2] = ConstantExpr::getGetElementPtr(ArrayCPR, GEPIndices); @@ -76,7 +53,37 @@ static void insertInitializationCall(Function *MainFn, const char *FnName, cast(Array->getType()->getElementType())->getNumElements(); Args[3] = ConstantUInt::get(Type::UIntTy, NumElements); - new CallInst(InitFn, Args, "", InsertPos); + Instruction *InitCall = new CallInst(InitFn, Args, "newargc", InsertPos); + + // If argc or argv are not available in main, just pass null values in. + Function::aiterator AI; + switch (MainFn->asize()) { + default: + case 2: + AI = MainFn->abegin(); ++AI; + if (AI->getType() != ArgVTy) { + InitCall->setOperand(2, new CastInst(AI, ArgVTy, "argv.cast", InitCall)); + } else { + InitCall->setOperand(2, AI); + } + + case 1: + AI = MainFn->abegin(); + // If the program looked at argc, have it look at the return value of the + // init call instead. + if (AI->getType() != Type::IntTy) { + if (!AI->use_empty()) + AI->replaceAllUsesWith(new CastInst(InitCall, AI->getType(), "", + InsertPos)); + InitCall->setOperand(1, new CastInst(AI, Type::IntTy, "argc.cast", + InitCall)); + } else { + AI->replaceAllUsesWith(InitCall); + InitCall->setOperand(1, AI); + } + + case 0: break; + } } static void IncrementCounterInBlock(BasicBlock *BB, unsigned CounterNum,