1 /* 2 * Copyright 2016 laf-intel 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * https://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include <stdio.h> 18 #include <stdlib.h> 19 #include <unistd.h> 20 21 #include <list> 22 #include <string> 23 #include <fstream> 24 #include <sys/time.h> 25 #include "llvm/Config/llvm-config.h" 26 27 #include "llvm/ADT/Statistic.h" 28 #include "llvm/IR/IRBuilder.h" 29 #if LLVM_MAJOR >= 11 /* use new pass manager */ 30 #include "llvm/Passes/PassPlugin.h" 31 #include "llvm/Passes/PassBuilder.h" 32 #include "llvm/IR/PassManager.h" 33 #else 34 #include "llvm/IR/LegacyPassManager.h" 35 #include "llvm/Transforms/IPO/PassManagerBuilder.h" 36 #endif 37 #include "llvm/IR/Module.h" 38 #include "llvm/Support/Debug.h" 39 #include "llvm/Support/raw_ostream.h" 40 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 41 #include "llvm/Pass.h" 42 #include "llvm/Analysis/ValueTracking.h" 43 #if LLVM_VERSION_MAJOR >= 14 /* how about stable interfaces? */ 44 #include "llvm/Passes/OptimizationLevel.h" 45 #endif 46 47 #if LLVM_VERSION_MAJOR >= 4 || \ 48 (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4) 49 #include "llvm/IR/Verifier.h" 50 #include "llvm/IR/DebugInfo.h" 51 #else 52 #include "llvm/Analysis/Verifier.h" 53 #include "llvm/DebugInfo.h" 54 #define nullptr 0 55 #endif 56 57 #include <set> 58 #include "afl-llvm-common.h" 59 60 using namespace llvm; 61 62 namespace { 63 64 #if LLVM_MAJOR >= 11 /* use new pass manager */ 65 class CompareTransform : public PassInfoMixin<CompareTransform> { 66 67 public: CompareTransform()68 CompareTransform() { 69 70 #else 71 class CompareTransform : public ModulePass { 72 73 public: 74 static char ID; 75 CompareTransform() : ModulePass(ID) { 76 77 #endif 78 79 initInstrumentList(); 80 81 } 82 83 #if LLVM_MAJOR < 11 84 #if LLVM_VERSION_MAJOR >= 4 85 StringRef getPassName() const override { 86 87 #else 88 const char *getPassName() const override { 89 90 #endif 91 92 return "cmplog transform"; 93 94 } 95 96 #endif 97 98 #if LLVM_MAJOR >= 11 /* use new pass manager */ 99 PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM); 100 #else 101 bool runOnModule(Module &M) override; 102 #endif 103 104 private: 105 bool transformCmps(Module &M, const bool processStrcmp, 106 const bool processMemcmp, const bool processStrncmp, 107 const bool processStrcasecmp, 108 const bool processStrncasecmp); 109 110 }; 111 112 } // namespace 113 114 #if LLVM_MAJOR >= 11 /* use new pass manager */ 115 extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK 116 llvmGetPassPluginInfo() { 117 118 return {LLVM_PLUGIN_API_VERSION, "comparetransform", "v0.1", 119 /* lambda to insert our pass into the pass pipeline. */ 120 [](PassBuilder &PB) { 121 122 #if 1 123 #if LLVM_VERSION_MAJOR <= 13 124 using OptimizationLevel = typename PassBuilder::OptimizationLevel; 125 #endif 126 PB.registerOptimizerLastEPCallback( 127 [](ModulePassManager &MPM, OptimizationLevel OL) { 128 129 MPM.addPass(CompareTransform()); 130 131 }); 132 133 /* TODO LTO registration */ 134 #else 135 using PipelineElement = typename PassBuilder::PipelineElement; 136 PB.registerPipelineParsingCallback([](StringRef Name, 137 ModulePassManager &MPM, 138 ArrayRef<PipelineElement>) { 139 140 if (Name == "comparetransform") { 141 142 MPM.addPass(CompareTransform()); 143 return true; 144 145 } else { 146 147 return false; 148 149 } 150 151 }); 152 153 #endif 154 155 }}; 156 157 } 158 159 #else 160 char CompareTransform::ID = 0; 161 #endif 162 163 bool CompareTransform::transformCmps(Module &M, const bool processStrcmp, 164 const bool processMemcmp, 165 const bool processStrncmp, 166 const bool processStrcasecmp, 167 const bool processStrncasecmp) { 168 169 DenseMap<Value *, std::string *> valueMap; 170 std::vector<CallInst *> calls; 171 LLVMContext &C = M.getContext(); 172 IntegerType *Int1Ty = IntegerType::getInt1Ty(C); 173 IntegerType *Int8Ty = IntegerType::getInt8Ty(C); 174 IntegerType *Int32Ty = IntegerType::getInt32Ty(C); 175 IntegerType *Int64Ty = IntegerType::getInt64Ty(C); 176 177 #if LLVM_VERSION_MAJOR >= 9 178 FunctionCallee tolowerFn; 179 #else 180 Function *tolowerFn; 181 #endif 182 { 183 184 #if LLVM_VERSION_MAJOR >= 9 185 FunctionCallee 186 #else 187 Constant * 188 #endif 189 c = M.getOrInsertFunction("tolower", Int32Ty, Int32Ty 190 #if LLVM_VERSION_MAJOR < 5 191 , 192 NULL 193 #endif 194 ); 195 #if LLVM_VERSION_MAJOR >= 9 196 tolowerFn = c; 197 #else 198 tolowerFn = cast<Function>(c); 199 #endif 200 201 } 202 203 /* iterate over all functions, bbs and instruction and add suitable calls to 204 * strcmp/memcmp/strncmp/strcasecmp/strncasecmp */ 205 for (auto &F : M) { 206 207 if (!isInInstrumentList(&F, MNAME)) continue; 208 209 for (auto &BB : F) { 210 211 for (auto &IN : BB) { 212 213 CallInst *callInst = nullptr; 214 215 if ((callInst = dyn_cast<CallInst>(&IN))) { 216 217 bool isStrcmp = processStrcmp; 218 bool isMemcmp = processMemcmp; 219 bool isStrncmp = processStrncmp; 220 bool isStrcasecmp = processStrcasecmp; 221 bool isStrncasecmp = processStrncasecmp; 222 bool isIntMemcpy = true; 223 224 Function *Callee = callInst->getCalledFunction(); 225 if (!Callee) continue; 226 if (callInst->getCallingConv() != llvm::CallingConv::C) continue; 227 StringRef FuncName = Callee->getName(); 228 isStrcmp &= 229 (!FuncName.compare("strcmp") || !FuncName.compare("xmlStrcmp") || 230 !FuncName.compare("xmlStrEqual") || 231 !FuncName.compare("curl_strequal") || 232 !FuncName.compare("strcsequal") || 233 !FuncName.compare("g_strcmp0")); 234 isMemcmp &= 235 (!FuncName.compare("memcmp") || !FuncName.compare("bcmp") || 236 !FuncName.compare("CRYPTO_memcmp") || 237 !FuncName.compare("OPENSSL_memcmp") || 238 !FuncName.compare("memcmp_const_time") || 239 !FuncName.compare("memcmpct")); 240 isStrncmp &= (!FuncName.compare("strncmp") || 241 !FuncName.compare("curl_strnequal") || 242 !FuncName.compare("xmlStrncmp")); 243 isStrcasecmp &= (!FuncName.compare("strcasecmp") || 244 !FuncName.compare("stricmp") || 245 !FuncName.compare("ap_cstr_casecmp") || 246 !FuncName.compare("OPENSSL_strcasecmp") || 247 !FuncName.compare("xmlStrcasecmp") || 248 !FuncName.compare("g_strcasecmp") || 249 !FuncName.compare("g_ascii_strcasecmp") || 250 !FuncName.compare("Curl_strcasecompare") || 251 !FuncName.compare("Curl_safe_strcasecompare") || 252 !FuncName.compare("cmsstrcasecmp")); 253 isStrncasecmp &= (!FuncName.compare("strncasecmp") || 254 !FuncName.compare("strnicmp") || 255 !FuncName.compare("ap_cstr_casecmpn") || 256 !FuncName.compare("OPENSSL_strncasecmp") || 257 !FuncName.compare("xmlStrncasecmp") || 258 !FuncName.compare("g_ascii_strncasecmp") || 259 !FuncName.compare("Curl_strncasecompare") || 260 !FuncName.compare("g_strncasecmp")); 261 isIntMemcpy &= !FuncName.compare("llvm.memcpy.p0i8.p0i8.i64"); 262 263 if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && 264 !isStrncasecmp && !isIntMemcpy) 265 continue; 266 267 /* Verify the strcmp/memcmp/strncmp/strcasecmp/strncasecmp function 268 * prototype */ 269 FunctionType *FT = Callee->getFunctionType(); 270 271 isStrcmp &= 272 FT->getNumParams() == 2 && FT->getReturnType()->isIntegerTy(32) && 273 FT->getParamType(0) == FT->getParamType(1) && 274 FT->getParamType(0) == 275 IntegerType::getInt8Ty(M.getContext())->getPointerTo(0); 276 isStrcasecmp &= 277 FT->getNumParams() == 2 && FT->getReturnType()->isIntegerTy(32) && 278 FT->getParamType(0) == FT->getParamType(1) && 279 FT->getParamType(0) == 280 IntegerType::getInt8Ty(M.getContext())->getPointerTo(0); 281 isMemcmp &= FT->getNumParams() == 3 && 282 FT->getReturnType()->isIntegerTy(32) && 283 FT->getParamType(0)->isPointerTy() && 284 FT->getParamType(1)->isPointerTy() && 285 FT->getParamType(2)->isIntegerTy(); 286 isStrncmp &= 287 FT->getNumParams() == 3 && FT->getReturnType()->isIntegerTy(32) && 288 FT->getParamType(0) == FT->getParamType(1) && 289 FT->getParamType(0) == 290 IntegerType::getInt8Ty(M.getContext())->getPointerTo(0) && 291 FT->getParamType(2)->isIntegerTy(); 292 isStrncasecmp &= 293 FT->getNumParams() == 3 && FT->getReturnType()->isIntegerTy(32) && 294 FT->getParamType(0) == FT->getParamType(1) && 295 FT->getParamType(0) == 296 IntegerType::getInt8Ty(M.getContext())->getPointerTo(0) && 297 FT->getParamType(2)->isIntegerTy(); 298 299 if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && 300 !isStrncasecmp && !isIntMemcpy) 301 continue; 302 303 /* is a str{n,}{case,}cmp/memcmp, check if we have 304 * str{case,}cmp(x, "const") or str{case,}cmp("const", x) 305 * strn{case,}cmp(x, "const", ..) or strn{case,}cmp("const", x, ..) 306 * memcmp(x, "const", ..) or memcmp("const", x, ..) */ 307 Value *Str1P = callInst->getArgOperand(0), 308 *Str2P = callInst->getArgOperand(1); 309 StringRef Str1, Str2; 310 bool HasStr1 = getConstantStringInfo(Str1P, Str1); 311 bool HasStr2 = getConstantStringInfo(Str2P, Str2); 312 313 if (isIntMemcpy && HasStr2) { 314 315 valueMap[Str1P] = new std::string(Str2.str()); 316 // fprintf(stderr, "saved %s for %p\n", Str2.str().c_str(), Str1P); 317 continue; 318 319 } 320 321 // not literal? maybe global or local variable 322 if (!(HasStr1 || HasStr2)) { 323 324 auto *Ptr = dyn_cast<ConstantExpr>(Str2P); 325 if (Ptr && Ptr->getOpcode() == Instruction::GetElementPtr) { 326 327 if (auto *Var = dyn_cast<GlobalVariable>(Ptr->getOperand(0))) { 328 329 if (Var->hasInitializer()) { 330 331 if (auto *Array = 332 dyn_cast<ConstantDataArray>(Var->getInitializer())) { 333 334 HasStr2 = true; 335 Str2 = Array->getRawDataValues(); 336 valueMap[Str2P] = new std::string(Str2.str()); 337 // fprintf(stderr, "glo2 %s\n", Str2.str().c_str()); 338 339 } 340 341 } 342 343 } 344 345 } 346 347 if (!HasStr2) { 348 349 Ptr = dyn_cast<ConstantExpr>(Str1P); 350 if (Ptr && Ptr->getOpcode() == Instruction::GetElementPtr) { 351 352 if (auto *Var = dyn_cast<GlobalVariable>(Ptr->getOperand(0))) { 353 354 if (Var->hasInitializer()) { 355 356 if (auto *Array = dyn_cast<ConstantDataArray>( 357 Var->getInitializer())) { 358 359 HasStr1 = true; 360 Str1 = Array->getRawDataValues(); 361 valueMap[Str1P] = new std::string(Str1.str()); 362 // fprintf(stderr, "glo1 %s\n", Str1.str().c_str()); 363 364 } 365 366 } 367 368 } 369 370 } 371 372 } else if (isIntMemcpy) { 373 374 valueMap[Str1P] = new std::string(Str2.str()); 375 // fprintf(stderr, "saved\n"); 376 377 } 378 379 } 380 381 if (isIntMemcpy) continue; 382 383 if (!(HasStr1 || HasStr2)) { 384 385 // do we have a saved local variable initialization? 386 std::string *val = valueMap[Str1P]; 387 if (val && !val->empty()) { 388 389 Str1 = StringRef(*val); 390 HasStr1 = true; 391 // fprintf(stderr, "loaded1 %s\n", Str1.str().c_str()); 392 393 } else { 394 395 val = valueMap[Str2P]; 396 if (val && !val->empty()) { 397 398 Str2 = StringRef(*val); 399 HasStr2 = true; 400 // fprintf(stderr, "loaded2 %s\n", Str2.str().c_str()); 401 402 } 403 404 } 405 406 } 407 408 /* handle cases of one string is const, one string is variable */ 409 if (!(HasStr1 || HasStr2)) continue; 410 411 if (isMemcmp || isStrncmp || isStrncasecmp) { 412 413 /* check if third operand is a constant integer 414 * strlen("constStr") and sizeof() are treated as constant */ 415 Value *op2 = callInst->getArgOperand(2); 416 ConstantInt *ilen = dyn_cast<ConstantInt>(op2); 417 if (ilen) { 418 419 // if len is zero this is a pointless call but allow real 420 // implementation to worry about that 421 if (ilen->getZExtValue() < 2) { continue; } 422 423 } else if (isMemcmp) { 424 425 // this *may* supply a len greater than the constant string at 426 // runtime so similarly we don't want to have to handle that 427 continue; 428 429 } 430 431 } 432 433 calls.push_back(callInst); 434 435 } 436 437 } 438 439 } 440 441 } 442 443 if (!calls.size()) return false; 444 if (!be_quiet) 445 printf( 446 "Replacing %zu calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp\n", 447 calls.size()); 448 449 for (auto &callInst : calls) { 450 451 Value *Str1P = callInst->getArgOperand(0), 452 *Str2P = callInst->getArgOperand(1); 453 StringRef Str1, Str2, ConstStr; 454 std::string TmpConstStr; 455 Value *VarStr; 456 bool HasStr1 = getConstantStringInfo(Str1P, Str1); 457 bool HasStr2 = getConstantStringInfo(Str2P, Str2); 458 uint64_t constStrLen, unrollLen, constSizedLen = 0; 459 bool isMemcmp = false; 460 bool isSizedcmp = false; 461 bool isCaseInsensitive = false; 462 bool needs_null = false; 463 bool success_is_one = false; 464 Function *Callee = callInst->getCalledFunction(); 465 466 if (Callee) { 467 468 if (!Callee->getName().compare("memcmp") || 469 !Callee->getName().compare("bcmp") || 470 !Callee->getName().compare("CRYPTO_memcmp") || 471 !Callee->getName().compare("OPENSSL_memcmp") || 472 !Callee->getName().compare("memcmp_const_time") || 473 !Callee->getName().compare("memcmpct") || 474 !Callee->getName().compare("llvm.memcpy.p0i8.p0i8.i64")) 475 isMemcmp = true; 476 477 if (isMemcmp || !Callee->getName().compare("strncmp") || 478 !Callee->getName().compare("xmlStrncmp") || 479 !Callee->getName().compare("curl_strnequal") || 480 !Callee->getName().compare("strncasecmp") || 481 !Callee->getName().compare("strnicmp") || 482 !Callee->getName().compare("ap_cstr_casecmpn") || 483 !Callee->getName().compare("OPENSSL_strncasecmp") || 484 !Callee->getName().compare("xmlStrncasecmp") || 485 !Callee->getName().compare("g_ascii_strncasecmp") || 486 !Callee->getName().compare("Curl_strncasecompare") || 487 !Callee->getName().compare("g_strncasecmp")) 488 isSizedcmp = true; 489 490 if (!Callee->getName().compare("strcasecmp") || 491 !Callee->getName().compare("stricmp") || 492 !Callee->getName().compare("ap_cstr_casecmp") || 493 !Callee->getName().compare("OPENSSL_strcasecmp") || 494 !Callee->getName().compare("xmlStrcasecmp") || 495 !Callee->getName().compare("g_strcasecmp") || 496 !Callee->getName().compare("g_ascii_strcasecmp") || 497 !Callee->getName().compare("Curl_strcasecompare") || 498 !Callee->getName().compare("Curl_safe_strcasecompare") || 499 !Callee->getName().compare("cmsstrcasecmp") || 500 !Callee->getName().compare("strncasecmp") || 501 !Callee->getName().compare("strnicmp") || 502 !Callee->getName().compare("ap_cstr_casecmpn") || 503 !Callee->getName().compare("OPENSSL_strncasecmp") || 504 !Callee->getName().compare("xmlStrncasecmp") || 505 !Callee->getName().compare("g_ascii_strncasecmp") || 506 !Callee->getName().compare("Curl_strncasecompare") || 507 !Callee->getName().compare("g_strncasecmp")) 508 isCaseInsensitive = true; 509 510 if (!Callee->getName().compare("xmlStrEqual") || 511 !Callee->getName().compare("curl_strequal") || 512 !Callee->getName().compare("strcsequal") || 513 !Callee->getName().compare("curl_strnequal")) 514 success_is_one = true; 515 516 } 517 518 if (!isSizedcmp) needs_null = true; 519 520 Value *sizedValue = isSizedcmp ? callInst->getArgOperand(2) : NULL; 521 bool isConstSized = sizedValue && isa<ConstantInt>(sizedValue); 522 523 if (!(HasStr1 || HasStr2)) { 524 525 // do we have a saved local or global variable initialization? 526 std::string *val = valueMap[Str1P]; 527 if (val && !val->empty()) { 528 529 Str1 = StringRef(*val); 530 HasStr1 = true; 531 532 } else { 533 534 val = valueMap[Str2P]; 535 if (val && !val->empty()) { 536 537 Str2 = StringRef(*val); 538 // HasStr2 = true; 539 540 } 541 542 } 543 544 } 545 546 if (isConstSized) { 547 548 constSizedLen = dyn_cast<ConstantInt>(sizedValue)->getZExtValue(); 549 550 } 551 552 if (HasStr1) { 553 554 TmpConstStr = Str1.str(); 555 VarStr = Str2P; 556 557 } else { 558 559 TmpConstStr = Str2.str(); 560 VarStr = Str1P; 561 562 } 563 564 if (TmpConstStr.length() < 2 || 565 (TmpConstStr.length() == 2 && TmpConstStr[1] == 0)) { 566 567 continue; 568 569 } 570 571 // the following is in general OK, but strncmp is sometimes used in binary 572 // data structures and this can result in crashes :( so it is commented out 573 574 // add null termination character implicit in c strings 575 if (needs_null && TmpConstStr[TmpConstStr.length() - 1] != 0) { 576 577 TmpConstStr.append("\0", 1); 578 579 } 580 581 // in the unusual case the const str has embedded null 582 // characters, the string comparison functions should terminate 583 // at the first null 584 if (!isMemcmp && TmpConstStr.find('\0') != std::string::npos) { 585 586 TmpConstStr.assign(TmpConstStr, 0, TmpConstStr.find('\0') + 1); 587 588 } 589 590 constStrLen = TmpConstStr.length(); 591 // prefer use of StringRef (in comparison to std::string a StringRef has 592 // built-in runtime bounds checking, which makes debugging easier) 593 ConstStr = StringRef(TmpConstStr); 594 595 if (isConstSized) 596 unrollLen = constSizedLen < constStrLen ? constSizedLen : constStrLen; 597 else 598 unrollLen = constStrLen; 599 600 /* split before the call instruction */ 601 BasicBlock *bb = callInst->getParent(); 602 BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(callInst)); 603 604 BasicBlock *next_lenchk_bb = NULL; 605 if (isSizedcmp && !isConstSized) { 606 607 next_lenchk_bb = 608 BasicBlock::Create(C, "len_check", end_bb->getParent(), end_bb); 609 BranchInst::Create(end_bb, next_lenchk_bb); 610 611 } 612 613 BasicBlock *next_cmp_bb = 614 BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb); 615 BranchInst::Create(end_bb, next_cmp_bb); 616 PHINode *PN = PHINode::Create( 617 Int32Ty, (next_lenchk_bb ? 2 : 1) * unrollLen + 1, "cmp_phi"); 618 619 #if LLVM_VERSION_MAJOR >= 8 620 Instruction *term = bb->getTerminator(); 621 #else 622 TerminatorInst *term = bb->getTerminator(); 623 #endif 624 BranchInst::Create(next_lenchk_bb ? next_lenchk_bb : next_cmp_bb, bb); 625 term->eraseFromParent(); 626 627 for (uint64_t i = 0; i < unrollLen; i++) { 628 629 BasicBlock *cur_cmp_bb = next_cmp_bb, *cur_lenchk_bb = next_lenchk_bb; 630 unsigned char c; 631 632 if (cur_lenchk_bb) { 633 634 IRBuilder<> cur_lenchk_IRB(&*(cur_lenchk_bb->getFirstInsertionPt())); 635 Value *icmp = cur_lenchk_IRB.CreateICmpEQ( 636 sizedValue, ConstantInt::get(sizedValue->getType(), i)); 637 cur_lenchk_IRB.CreateCondBr(icmp, end_bb, cur_cmp_bb); 638 cur_lenchk_bb->getTerminator()->eraseFromParent(); 639 640 PN->addIncoming(ConstantInt::get(Int32Ty, 0), cur_lenchk_bb); 641 642 } 643 644 if (isCaseInsensitive) 645 c = (unsigned char)(tolower((int)ConstStr[i]) & 0xff); 646 else 647 c = (unsigned char)ConstStr[i]; 648 649 IRBuilder<> cur_cmp_IRB(&*(cur_cmp_bb->getFirstInsertionPt())); 650 651 Value *v = ConstantInt::get(Int64Ty, i); 652 Value *ele = cur_cmp_IRB.CreateInBoundsGEP( 653 #if LLVM_VERSION_MAJOR >= 14 654 Int8Ty, 655 #endif 656 VarStr, v, "empty"); 657 Value *load = cur_cmp_IRB.CreateLoad( 658 #if LLVM_VERSION_MAJOR >= 14 659 Int8Ty, 660 #endif 661 ele); 662 663 if (isCaseInsensitive) { 664 665 // load >= 'A' && load <= 'Z' ? load | 0x020 : load 666 load = cur_cmp_IRB.CreateZExt(load, Int32Ty); 667 std::vector<Value *> args; 668 args.push_back(load); 669 load = cur_cmp_IRB.CreateCall(tolowerFn, args); 670 load = cur_cmp_IRB.CreateTrunc(load, Int8Ty); 671 672 } 673 674 Value *isub; 675 if (HasStr1) 676 isub = cur_cmp_IRB.CreateSub(ConstantInt::get(Int8Ty, c), load); 677 else 678 isub = cur_cmp_IRB.CreateSub(load, ConstantInt::get(Int8Ty, c)); 679 680 if (success_is_one && i == unrollLen - 1) { 681 682 Value *isubsub = cur_cmp_IRB.CreateTrunc(isub, Int1Ty); 683 isub = cur_cmp_IRB.CreateSelect(isubsub, ConstantInt::get(Int8Ty, 0), 684 ConstantInt::get(Int8Ty, 1)); 685 686 } 687 688 Value *sext = cur_cmp_IRB.CreateSExt(isub, Int32Ty); 689 PN->addIncoming(sext, cur_cmp_bb); 690 691 if (i < unrollLen - 1) { 692 693 if (cur_lenchk_bb) { 694 695 next_lenchk_bb = 696 BasicBlock::Create(C, "len_check", end_bb->getParent(), end_bb); 697 BranchInst::Create(end_bb, next_lenchk_bb); 698 699 } 700 701 next_cmp_bb = 702 BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb); 703 BranchInst::Create(end_bb, next_cmp_bb); 704 705 Value *icmp = 706 cur_cmp_IRB.CreateICmpEQ(isub, ConstantInt::get(Int8Ty, 0)); 707 cur_cmp_IRB.CreateCondBr( 708 icmp, next_lenchk_bb ? next_lenchk_bb : next_cmp_bb, end_bb); 709 cur_cmp_bb->getTerminator()->eraseFromParent(); 710 711 } else { 712 713 // IRB.CreateBr(end_bb); 714 715 } 716 717 // add offset to varstr 718 // create load 719 // create signed isub 720 // create icmp 721 // create jcc 722 // create next_bb 723 724 } 725 726 /* since the call is the first instruction of the bb it is safe to 727 * replace it with a phi instruction */ 728 BasicBlock::iterator ii(callInst); 729 #if LLVM_MAJOR >= 16 730 ReplaceInstWithInst(callInst->getParent(), ii, PN); 731 #else 732 ReplaceInstWithInst(callInst->getParent()->getInstList(), ii, PN); 733 #endif 734 735 } 736 737 return true; 738 739 } 740 741 #if LLVM_MAJOR >= 11 /* use new pass manager */ 742 PreservedAnalyses CompareTransform::run(Module &M, ModuleAnalysisManager &MAM) { 743 744 #else 745 bool CompareTransform::runOnModule(Module &M) { 746 747 #endif 748 749 if ((isatty(2) && getenv("AFL_QUIET") == NULL) || getenv("AFL_DEBUG") != NULL) 750 printf( 751 "Running compare-transform-pass by [email protected], extended by " 752 "[email protected]\n"); 753 else 754 be_quiet = 1; 755 756 #if LLVM_MAJOR >= 11 /* use new pass manager */ 757 auto PA = PreservedAnalyses::all(); 758 #endif 759 760 transformCmps(M, true, true, true, true, true); 761 verifyModule(M); 762 763 #if LLVM_MAJOR >= 11 /* use new pass manager */ 764 /* if (modified) { 765 766 PA.abandon<XX_Manager>(); 767 768 }*/ 769 770 return PA; 771 #else 772 return true; 773 #endif 774 775 } 776 777 #if LLVM_MAJOR < 11 /* use old pass manager */ 778 static void registerCompTransPass(const PassManagerBuilder &, 779 legacy::PassManagerBase &PM) { 780 781 auto p = new CompareTransform(); 782 PM.add(p); 783 784 } 785 786 static RegisterStandardPasses RegisterCompTransPass( 787 PassManagerBuilder::EP_OptimizerLast, registerCompTransPass); 788 789 static RegisterStandardPasses RegisterCompTransPass0( 790 PassManagerBuilder::EP_EnabledOnOptLevel0, registerCompTransPass); 791 792 #if LLVM_VERSION_MAJOR >= 11 793 static RegisterStandardPasses RegisterCompTransPassLTO( 794 PassManagerBuilder::EP_FullLinkTimeOptimizationLast, registerCompTransPass); 795 #endif 796 #endif 797 798