1 /* 2 * Copyright 2016 laf-intel 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * https://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include <stdio.h> 18 #include <stdlib.h> 19 #include <unistd.h> 20 21 #include <list> 22 #include <string> 23 #include <fstream> 24 #include <sys/time.h> 25 26 #include "llvm/Config/llvm-config.h" 27 28 #include "llvm/ADT/Statistic.h" 29 #include "llvm/IR/IRBuilder.h" 30 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 31 #include "llvm/Passes/PassPlugin.h" 32 #include "llvm/Passes/PassBuilder.h" 33 #include "llvm/IR/PassManager.h" 34 #else 35 #include "llvm/IR/LegacyPassManager.h" 36 #include "llvm/Transforms/IPO/PassManagerBuilder.h" 37 #endif 38 #include "llvm/IR/Module.h" 39 #include "llvm/Support/Debug.h" 40 #include "llvm/Support/raw_ostream.h" 41 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 42 #include "llvm/Pass.h" 43 #include "llvm/Analysis/ValueTracking.h" 44 #if LLVM_VERSION_MAJOR >= 14 /* how about stable interfaces? */ 45 #include "llvm/Passes/OptimizationLevel.h" 46 #endif 47 48 #include "llvm/IR/IRBuilder.h" 49 #if LLVM_VERSION_MAJOR >= 4 || \ 50 (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4) 51 #include "llvm/IR/Verifier.h" 52 #include "llvm/IR/DebugInfo.h" 53 #else 54 #include "llvm/Analysis/Verifier.h" 55 #include "llvm/DebugInfo.h" 56 #define nullptr 0 57 #endif 58 59 #include <set> 60 #include "afl-llvm-common.h" 61 62 using namespace llvm; 63 64 namespace { 65 66 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 67 class SplitSwitchesTransform : public PassInfoMixin<SplitSwitchesTransform> { 68 69 public: SplitSwitchesTransform()70 SplitSwitchesTransform() { 71 72 #else 73 class SplitSwitchesTransform : public ModulePass { 74 75 public: 76 static char ID; 77 SplitSwitchesTransform() : ModulePass(ID) { 78 79 #endif 80 initInstrumentList(); 81 82 } 83 84 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 85 PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM); 86 #else 87 bool runOnModule(Module &M) override; 88 89 #if LLVM_VERSION_MAJOR >= 4 90 StringRef getPassName() const override { 91 92 #else 93 const char *getPassName() const override { 94 95 #endif 96 return "splits switch constructs"; 97 98 } 99 100 #endif 101 102 struct CaseExpr { 103 104 ConstantInt *Val; 105 BasicBlock *BB; 106 107 CaseExpr(ConstantInt *val = nullptr, BasicBlock *bb = nullptr) 108 : Val(val), BB(bb) { 109 110 } 111 112 }; 113 114 using CaseVector = std::vector<CaseExpr>; 115 116 private: 117 bool splitSwitches(Module &M); 118 bool transformCmps(Module &M, const bool processStrcmp, 119 const bool processMemcmp); 120 BasicBlock *switchConvert(CaseVector Cases, std::vector<bool> bytesChecked, 121 BasicBlock *OrigBlock, BasicBlock *NewDefault, 122 Value *Val, unsigned level); 123 124 }; 125 126 } // namespace 127 128 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 129 extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK 130 llvmGetPassPluginInfo() { 131 132 return {LLVM_PLUGIN_API_VERSION, "splitswitches", "v0.1", 133 /* lambda to insert our pass into the pass pipeline. */ 134 [](PassBuilder &PB) { 135 136 #if 1 137 #if LLVM_VERSION_MAJOR <= 13 138 using OptimizationLevel = typename PassBuilder::OptimizationLevel; 139 #endif 140 PB.registerOptimizerLastEPCallback( 141 [](ModulePassManager &MPM, OptimizationLevel OL) { 142 143 MPM.addPass(SplitSwitchesTransform()); 144 145 }); 146 147 /* TODO LTO registration */ 148 #else 149 using PipelineElement = typename PassBuilder::PipelineElement; 150 PB.registerPipelineParsingCallback([](StringRef Name, 151 ModulePassManager &MPM, 152 ArrayRef<PipelineElement>) { 153 154 if (Name == "splitswitches") { 155 156 MPM.addPass(SplitSwitchesTransform()); 157 return true; 158 159 } else { 160 161 return false; 162 163 } 164 165 }); 166 167 #endif 168 169 }}; 170 171 } 172 173 #else 174 char SplitSwitchesTransform::ID = 0; 175 #endif 176 177 /* switchConvert - Transform simple list of Cases into list of CaseRange's */ 178 BasicBlock *SplitSwitchesTransform::switchConvert( 179 CaseVector Cases, std::vector<bool> bytesChecked, BasicBlock *OrigBlock, 180 BasicBlock *NewDefault, Value *Val, unsigned level) { 181 182 unsigned ValTypeBitWidth = Cases[0].Val->getBitWidth(); 183 IntegerType *ValType = 184 IntegerType::get(OrigBlock->getContext(), ValTypeBitWidth); 185 IntegerType *ByteType = IntegerType::get(OrigBlock->getContext(), 8); 186 unsigned BytesInValue = bytesChecked.size(); 187 std::vector<uint8_t> setSizes; 188 std::vector<std::set<uint8_t> > byteSets(BytesInValue, std::set<uint8_t>()); 189 190 /* for each of the possible cases we iterate over all bytes of the values 191 * build a set of possible values at each byte position in byteSets */ 192 for (CaseExpr &Case : Cases) { 193 194 for (unsigned i = 0; i < BytesInValue; i++) { 195 196 uint8_t byte = (Case.Val->getZExtValue() >> (i * 8)) & 0xFF; 197 byteSets[i].insert(byte); 198 199 } 200 201 } 202 203 /* find the index of the first byte position that was not yet checked. then 204 * save the number of possible values at that byte position */ 205 unsigned smallestIndex = 0; 206 unsigned smallestSize = 257; 207 for (unsigned i = 0; i < byteSets.size(); i++) { 208 209 if (bytesChecked[i]) continue; 210 if (byteSets[i].size() < smallestSize) { 211 212 smallestIndex = i; 213 smallestSize = byteSets[i].size(); 214 215 } 216 217 } 218 219 assert(bytesChecked[smallestIndex] == false); 220 221 /* there are only smallestSize different bytes at index smallestIndex */ 222 223 Instruction *Shift, *Trunc; 224 Function *F = OrigBlock->getParent(); 225 BasicBlock *NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock", F); 226 Shift = BinaryOperator::Create(Instruction::LShr, Val, 227 ConstantInt::get(ValType, smallestIndex * 8)); 228 #if LLVM_VERSION_MAJOR >= 16 229 Shift->insertInto(NewNode, NewNode->end()); 230 #else 231 NewNode->getInstList().push_back(Shift); 232 #endif 233 234 if (ValTypeBitWidth > 8) { 235 236 Trunc = new TruncInst(Shift, ByteType); 237 #if LLVM_VERSION_MAJOR >= 16 238 Trunc->insertInto(NewNode, NewNode->end()); 239 #else 240 NewNode->getInstList().push_back(Trunc); 241 #endif 242 243 } else { 244 245 /* not necessary to trunc */ 246 Trunc = Shift; 247 248 } 249 250 /* this is a trivial case, we can directly check for the byte, 251 * if the byte is not found go to default. if the byte was found 252 * mark the byte as checked. if this was the last byte to check 253 * we can finally execute the block belonging to this case */ 254 255 if (smallestSize == 1) { 256 257 uint8_t byte = *(byteSets[smallestIndex].begin()); 258 259 /* insert instructions to check whether the value we are switching on is 260 * equal to byte */ 261 ICmpInst *Comp = 262 new ICmpInst(ICmpInst::ICMP_EQ, Trunc, ConstantInt::get(ByteType, byte), 263 "byteMatch"); 264 #if LLVM_VERSION_MAJOR >= 16 265 Comp->insertInto(NewNode, NewNode->end()); 266 #else 267 NewNode->getInstList().push_back(Comp); 268 #endif 269 270 bytesChecked[smallestIndex] = true; 271 bool allBytesAreChecked = true; 272 273 for (std::vector<bool>::iterator BCI = bytesChecked.begin(), 274 E = bytesChecked.end(); 275 BCI != E; ++BCI) { 276 277 if (!*BCI) { 278 279 allBytesAreChecked = false; 280 break; 281 282 } 283 284 } 285 286 // if (std::all_of(bytesChecked.begin(), bytesChecked.end(), 287 // [](bool b) { return b; })) { 288 289 if (allBytesAreChecked) { 290 291 assert(Cases.size() == 1); 292 BranchInst::Create(Cases[0].BB, NewDefault, Comp, NewNode); 293 294 /* we have to update the phi nodes! */ 295 for (BasicBlock::iterator I = Cases[0].BB->begin(); 296 I != Cases[0].BB->end(); ++I) { 297 298 if (!isa<PHINode>(&*I)) { continue; } 299 PHINode *PN = cast<PHINode>(I); 300 301 /* Only update the first occurrence. */ 302 unsigned Idx = 0, E = PN->getNumIncomingValues(); 303 for (; Idx != E; ++Idx) { 304 305 if (PN->getIncomingBlock(Idx) == OrigBlock) { 306 307 PN->setIncomingBlock(Idx, NewNode); 308 break; 309 310 } 311 312 } 313 314 } 315 316 } else { 317 318 BasicBlock *BB = switchConvert(Cases, bytesChecked, OrigBlock, NewDefault, 319 Val, level + 1); 320 BranchInst::Create(BB, NewDefault, Comp, NewNode); 321 322 } 323 324 } 325 326 /* there is no byte which we can directly check on, split the tree */ 327 else { 328 329 std::vector<uint8_t> byteVector; 330 std::copy(byteSets[smallestIndex].begin(), byteSets[smallestIndex].end(), 331 std::back_inserter(byteVector)); 332 std::sort(byteVector.begin(), byteVector.end()); 333 uint8_t pivot = byteVector[byteVector.size() / 2]; 334 335 /* we already chose to divide the cases based on the value of byte at index 336 * smallestIndex the pivot value determines the threshold for the decicion; 337 * if a case value 338 * is smaller at this byte index move it to the LHS vector, otherwise to the 339 * RHS vector */ 340 341 CaseVector LHSCases, RHSCases; 342 343 for (CaseExpr &Case : Cases) { 344 345 uint8_t byte = (Case.Val->getZExtValue() >> (smallestIndex * 8)) & 0xFF; 346 347 if (byte < pivot) { 348 349 LHSCases.push_back(Case); 350 351 } else { 352 353 RHSCases.push_back(Case); 354 355 } 356 357 } 358 359 BasicBlock *LBB, *RBB; 360 LBB = switchConvert(LHSCases, bytesChecked, OrigBlock, NewDefault, Val, 361 level + 1); 362 RBB = switchConvert(RHSCases, bytesChecked, OrigBlock, NewDefault, Val, 363 level + 1); 364 365 /* insert instructions to check whether the value we are switching on is 366 * equal to byte */ 367 ICmpInst *Comp = 368 new ICmpInst(ICmpInst::ICMP_ULT, Trunc, 369 ConstantInt::get(ByteType, pivot), "byteMatch"); 370 #if LLVM_VERSION_MAJOR >= 16 371 Comp->insertInto(NewNode, NewNode->end()); 372 #else 373 NewNode->getInstList().push_back(Comp); 374 #endif 375 BranchInst::Create(LBB, RBB, Comp, NewNode); 376 377 } 378 379 return NewNode; 380 381 } 382 383 bool SplitSwitchesTransform::splitSwitches(Module &M) { 384 385 #if (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR < 7) 386 LLVMContext &C = M.getContext(); 387 #endif 388 389 std::vector<SwitchInst *> switches; 390 391 /* iterate over all functions, bbs and instruction and add 392 * all switches to switches vector for later processing */ 393 for (auto &F : M) { 394 395 if (!isInInstrumentList(&F, MNAME)) continue; 396 397 for (auto &BB : F) { 398 399 SwitchInst *switchInst = nullptr; 400 401 if ((switchInst = dyn_cast<SwitchInst>(BB.getTerminator()))) { 402 403 if (switchInst->getNumCases() < 1) continue; 404 switches.push_back(switchInst); 405 406 } 407 408 } 409 410 } 411 412 if (!switches.size()) return false; 413 /* 414 if (!be_quiet) 415 errs() << "Rewriting " << switches.size() << " switch statements " 416 << "\n"; 417 */ 418 for (auto &SI : switches) { 419 420 BasicBlock *CurBlock = SI->getParent(); 421 BasicBlock *OrigBlock = CurBlock; 422 Function *F = CurBlock->getParent(); 423 /* this is the value we are switching on */ 424 Value *Val = SI->getCondition(); 425 BasicBlock *Default = SI->getDefaultDest(); 426 unsigned bitw = Val->getType()->getIntegerBitWidth(); 427 428 /* 429 if (!be_quiet) 430 errs() << "switch: " << SI->getNumCases() << " cases " << bitw 431 << " bit\n"; 432 */ 433 434 /* If there is only the default destination or the condition checks 8 bit or 435 * less, don't bother with the code below. */ 436 if (SI->getNumCases() < 2 || bitw % 8 || bitw > 64) { 437 438 // if (!be_quiet) errs() << "skip switch..\n"; 439 continue; 440 441 } 442 443 /* Create a new, empty default block so that the new hierarchy of 444 * if-then statements go to this and the PHI nodes are happy. 445 * if the default block is set as an unreachable we avoid creating one 446 * because will never be a valid target.*/ 447 BasicBlock *NewDefault = nullptr; 448 NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault", F, Default); 449 BranchInst::Create(Default, NewDefault); 450 451 /* Prepare cases vector. */ 452 CaseVector Cases; 453 for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; 454 ++i) 455 #if LLVM_VERSION_MAJOR >= 5 456 Cases.push_back(CaseExpr(i->getCaseValue(), i->getCaseSuccessor())); 457 #else 458 Cases.push_back(CaseExpr(i.getCaseValue(), i.getCaseSuccessor())); 459 #endif 460 /* bugfix thanks to pbst 461 * round up bytesChecked (in case getBitWidth() % 8 != 0) */ 462 std::vector<bool> bytesChecked((7 + Cases[0].Val->getBitWidth()) / 8, 463 false); 464 BasicBlock *SwitchBlock = 465 switchConvert(Cases, bytesChecked, OrigBlock, NewDefault, Val, 0); 466 467 /* Branch to our shiny new if-then stuff... */ 468 BranchInst::Create(SwitchBlock, OrigBlock); 469 470 /* We are now done with the switch instruction, delete it. */ 471 #if LLVM_VERSION_MAJOR >= 16 472 SI->eraseFromParent(); 473 #else 474 CurBlock->getInstList().erase(SI); 475 #endif 476 477 /* we have to update the phi nodes! */ 478 for (BasicBlock::iterator I = Default->begin(); I != Default->end(); ++I) { 479 480 if (!isa<PHINode>(&*I)) { continue; } 481 PHINode *PN = cast<PHINode>(I); 482 483 /* Only update the first occurrence. */ 484 unsigned Idx = 0, E = PN->getNumIncomingValues(); 485 for (; Idx != E; ++Idx) { 486 487 if (PN->getIncomingBlock(Idx) == OrigBlock) { 488 489 PN->setIncomingBlock(Idx, NewDefault); 490 break; 491 492 } 493 494 } 495 496 } 497 498 } 499 500 verifyModule(M); 501 return true; 502 503 } 504 505 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 506 PreservedAnalyses SplitSwitchesTransform::run(Module &M, 507 ModuleAnalysisManager &MAM) { 508 509 #else 510 bool SplitSwitchesTransform::runOnModule(Module &M) { 511 512 #endif 513 514 if ((isatty(2) && getenv("AFL_QUIET") == NULL) || getenv("AFL_DEBUG") != NULL) 515 printf("Running split-switches-pass by [email protected]\n"); 516 else 517 be_quiet = 1; 518 519 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 520 auto PA = PreservedAnalyses::all(); 521 #endif 522 523 splitSwitches(M); 524 verifyModule(M); 525 526 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 527 /* if (modified) { 528 529 PA.abandon<XX_Manager>(); 530 531 }*/ 532 533 return PA; 534 #else 535 return true; 536 #endif 537 538 } 539 540 #if LLVM_VERSION_MAJOR < 11 /* use old pass manager */ 541 static void registerSplitSwitchesTransPass(const PassManagerBuilder &, 542 legacy::PassManagerBase &PM) { 543 544 auto p = new SplitSwitchesTransform(); 545 PM.add(p); 546 547 } 548 549 static RegisterStandardPasses RegisterSplitSwitchesTransPass( 550 PassManagerBuilder::EP_OptimizerLast, registerSplitSwitchesTransPass); 551 552 static RegisterStandardPasses RegisterSplitSwitchesTransPass0( 553 PassManagerBuilder::EP_EnabledOnOptLevel0, registerSplitSwitchesTransPass); 554 555 #if LLVM_VERSION_MAJOR >= 11 556 static RegisterStandardPasses RegisterSplitSwitchesTransPassLTO( 557 PassManagerBuilder::EP_FullLinkTimeOptimizationLast, 558 registerSplitSwitchesTransPass); 559 #endif 560 #endif 561 562