xref: /aosp_15_r20/external/AFLplusplus/instrumentation/split-switches-pass.so.cc (revision 08b48e0b10e97b33e7b60c5b6e2243bd915777f2)
1 /*
2  * Copyright 2016 laf-intel
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     https://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <stdio.h>
18 #include <stdlib.h>
19 #include <unistd.h>
20 
21 #include <list>
22 #include <string>
23 #include <fstream>
24 #include <sys/time.h>
25 
26 #include "llvm/Config/llvm-config.h"
27 
28 #include "llvm/ADT/Statistic.h"
29 #include "llvm/IR/IRBuilder.h"
30 #if LLVM_VERSION_MAJOR >= 11                        /* use new pass manager */
31   #include "llvm/Passes/PassPlugin.h"
32   #include "llvm/Passes/PassBuilder.h"
33   #include "llvm/IR/PassManager.h"
34 #else
35   #include "llvm/IR/LegacyPassManager.h"
36   #include "llvm/Transforms/IPO/PassManagerBuilder.h"
37 #endif
38 #include "llvm/IR/Module.h"
39 #include "llvm/Support/Debug.h"
40 #include "llvm/Support/raw_ostream.h"
41 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
42 #include "llvm/Pass.h"
43 #include "llvm/Analysis/ValueTracking.h"
44 #if LLVM_VERSION_MAJOR >= 14                /* how about stable interfaces? */
45   #include "llvm/Passes/OptimizationLevel.h"
46 #endif
47 
48 #include "llvm/IR/IRBuilder.h"
49 #if LLVM_VERSION_MAJOR >= 4 || \
50     (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4)
51   #include "llvm/IR/Verifier.h"
52   #include "llvm/IR/DebugInfo.h"
53 #else
54   #include "llvm/Analysis/Verifier.h"
55   #include "llvm/DebugInfo.h"
56   #define nullptr 0
57 #endif
58 
59 #include <set>
60 #include "afl-llvm-common.h"
61 
62 using namespace llvm;
63 
64 namespace {
65 
66 #if LLVM_VERSION_MAJOR >= 11                        /* use new pass manager */
67 class SplitSwitchesTransform : public PassInfoMixin<SplitSwitchesTransform> {
68 
69  public:
SplitSwitchesTransform()70   SplitSwitchesTransform() {
71 
72 #else
73 class SplitSwitchesTransform : public ModulePass {
74 
75  public:
76   static char ID;
77   SplitSwitchesTransform() : ModulePass(ID) {
78 
79 #endif
80     initInstrumentList();
81 
82   }
83 
84 #if LLVM_VERSION_MAJOR >= 11                        /* use new pass manager */
85   PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
86 #else
87   bool runOnModule(Module &M) override;
88 
89   #if LLVM_VERSION_MAJOR >= 4
90   StringRef getPassName() const override {
91 
92   #else
93   const char *getPassName() const override {
94 
95   #endif
96     return "splits switch constructs";
97 
98   }
99 
100 #endif
101 
102   struct CaseExpr {
103 
104     ConstantInt *Val;
105     BasicBlock  *BB;
106 
107     CaseExpr(ConstantInt *val = nullptr, BasicBlock *bb = nullptr)
108         : Val(val), BB(bb) {
109 
110     }
111 
112   };
113 
114   using CaseVector = std::vector<CaseExpr>;
115 
116  private:
117   bool        splitSwitches(Module &M);
118   bool        transformCmps(Module &M, const bool processStrcmp,
119                             const bool processMemcmp);
120   BasicBlock *switchConvert(CaseVector Cases, std::vector<bool> bytesChecked,
121                             BasicBlock *OrigBlock, BasicBlock *NewDefault,
122                             Value *Val, unsigned level);
123 
124 };
125 
126 }  // namespace
127 
128 #if LLVM_VERSION_MAJOR >= 11                        /* use new pass manager */
129 extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
130 llvmGetPassPluginInfo() {
131 
132   return {LLVM_PLUGIN_API_VERSION, "splitswitches", "v0.1",
133           /* lambda to insert our pass into the pass pipeline. */
134           [](PassBuilder &PB) {
135 
136   #if 1
137     #if LLVM_VERSION_MAJOR <= 13
138             using OptimizationLevel = typename PassBuilder::OptimizationLevel;
139     #endif
140             PB.registerOptimizerLastEPCallback(
141                 [](ModulePassManager &MPM, OptimizationLevel OL) {
142 
143                   MPM.addPass(SplitSwitchesTransform());
144 
145                 });
146 
147   /* TODO LTO registration */
148   #else
149             using PipelineElement = typename PassBuilder::PipelineElement;
150             PB.registerPipelineParsingCallback([](StringRef          Name,
151                                                   ModulePassManager &MPM,
152                                                   ArrayRef<PipelineElement>) {
153 
154               if (Name == "splitswitches") {
155 
156                 MPM.addPass(SplitSwitchesTransform());
157                 return true;
158 
159               } else {
160 
161                 return false;
162 
163               }
164 
165             });
166 
167   #endif
168 
169           }};
170 
171 }
172 
173 #else
174 char SplitSwitchesTransform::ID = 0;
175 #endif
176 
177 /* switchConvert - Transform simple list of Cases into list of CaseRange's */
178 BasicBlock *SplitSwitchesTransform::switchConvert(
179     CaseVector Cases, std::vector<bool> bytesChecked, BasicBlock *OrigBlock,
180     BasicBlock *NewDefault, Value *Val, unsigned level) {
181 
182   unsigned     ValTypeBitWidth = Cases[0].Val->getBitWidth();
183   IntegerType *ValType =
184       IntegerType::get(OrigBlock->getContext(), ValTypeBitWidth);
185   IntegerType         *ByteType = IntegerType::get(OrigBlock->getContext(), 8);
186   unsigned             BytesInValue = bytesChecked.size();
187   std::vector<uint8_t> setSizes;
188   std::vector<std::set<uint8_t> > byteSets(BytesInValue, std::set<uint8_t>());
189 
190   /* for each of the possible cases we iterate over all bytes of the values
191    * build a set of possible values at each byte position in byteSets */
192   for (CaseExpr &Case : Cases) {
193 
194     for (unsigned i = 0; i < BytesInValue; i++) {
195 
196       uint8_t byte = (Case.Val->getZExtValue() >> (i * 8)) & 0xFF;
197       byteSets[i].insert(byte);
198 
199     }
200 
201   }
202 
203   /* find the index of the first byte position that was not yet checked. then
204    * save the number of possible values at that byte position */
205   unsigned smallestIndex = 0;
206   unsigned smallestSize = 257;
207   for (unsigned i = 0; i < byteSets.size(); i++) {
208 
209     if (bytesChecked[i]) continue;
210     if (byteSets[i].size() < smallestSize) {
211 
212       smallestIndex = i;
213       smallestSize = byteSets[i].size();
214 
215     }
216 
217   }
218 
219   assert(bytesChecked[smallestIndex] == false);
220 
221   /* there are only smallestSize different bytes at index smallestIndex */
222 
223   Instruction *Shift, *Trunc;
224   Function    *F = OrigBlock->getParent();
225   BasicBlock  *NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock", F);
226   Shift = BinaryOperator::Create(Instruction::LShr, Val,
227                                  ConstantInt::get(ValType, smallestIndex * 8));
228 #if LLVM_VERSION_MAJOR >= 16
229   Shift->insertInto(NewNode, NewNode->end());
230 #else
231   NewNode->getInstList().push_back(Shift);
232 #endif
233 
234   if (ValTypeBitWidth > 8) {
235 
236     Trunc = new TruncInst(Shift, ByteType);
237 #if LLVM_VERSION_MAJOR >= 16
238     Trunc->insertInto(NewNode, NewNode->end());
239 #else
240     NewNode->getInstList().push_back(Trunc);
241 #endif
242 
243   } else {
244 
245     /* not necessary to trunc */
246     Trunc = Shift;
247 
248   }
249 
250   /* this is a trivial case, we can directly check for the byte,
251    * if the byte is not found go to default. if the byte was found
252    * mark the byte as checked. if this was the last byte to check
253    * we can finally execute the block belonging to this case */
254 
255   if (smallestSize == 1) {
256 
257     uint8_t byte = *(byteSets[smallestIndex].begin());
258 
259     /* insert instructions to check whether the value we are switching on is
260      * equal to byte */
261     ICmpInst *Comp =
262         new ICmpInst(ICmpInst::ICMP_EQ, Trunc, ConstantInt::get(ByteType, byte),
263                      "byteMatch");
264 #if LLVM_VERSION_MAJOR >= 16
265     Comp->insertInto(NewNode, NewNode->end());
266 #else
267     NewNode->getInstList().push_back(Comp);
268 #endif
269 
270     bytesChecked[smallestIndex] = true;
271     bool allBytesAreChecked = true;
272 
273     for (std::vector<bool>::iterator BCI = bytesChecked.begin(),
274                                      E = bytesChecked.end();
275          BCI != E; ++BCI) {
276 
277       if (!*BCI) {
278 
279         allBytesAreChecked = false;
280         break;
281 
282       }
283 
284     }
285 
286     //    if (std::all_of(bytesChecked.begin(), bytesChecked.end(),
287     //                    [](bool b) { return b; })) {
288 
289     if (allBytesAreChecked) {
290 
291       assert(Cases.size() == 1);
292       BranchInst::Create(Cases[0].BB, NewDefault, Comp, NewNode);
293 
294       /* we have to update the phi nodes! */
295       for (BasicBlock::iterator I = Cases[0].BB->begin();
296            I != Cases[0].BB->end(); ++I) {
297 
298         if (!isa<PHINode>(&*I)) { continue; }
299         PHINode *PN = cast<PHINode>(I);
300 
301         /* Only update the first occurrence. */
302         unsigned Idx = 0, E = PN->getNumIncomingValues();
303         for (; Idx != E; ++Idx) {
304 
305           if (PN->getIncomingBlock(Idx) == OrigBlock) {
306 
307             PN->setIncomingBlock(Idx, NewNode);
308             break;
309 
310           }
311 
312         }
313 
314       }
315 
316     } else {
317 
318       BasicBlock *BB = switchConvert(Cases, bytesChecked, OrigBlock, NewDefault,
319                                      Val, level + 1);
320       BranchInst::Create(BB, NewDefault, Comp, NewNode);
321 
322     }
323 
324   }
325 
326   /* there is no byte which we can directly check on, split the tree */
327   else {
328 
329     std::vector<uint8_t> byteVector;
330     std::copy(byteSets[smallestIndex].begin(), byteSets[smallestIndex].end(),
331               std::back_inserter(byteVector));
332     std::sort(byteVector.begin(), byteVector.end());
333     uint8_t pivot = byteVector[byteVector.size() / 2];
334 
335     /* we already chose to divide the cases based on the value of byte at index
336      * smallestIndex the pivot value determines the threshold for the decicion;
337      * if a case value
338      * is smaller at this byte index move it to the LHS vector, otherwise to the
339      * RHS vector */
340 
341     CaseVector LHSCases, RHSCases;
342 
343     for (CaseExpr &Case : Cases) {
344 
345       uint8_t byte = (Case.Val->getZExtValue() >> (smallestIndex * 8)) & 0xFF;
346 
347       if (byte < pivot) {
348 
349         LHSCases.push_back(Case);
350 
351       } else {
352 
353         RHSCases.push_back(Case);
354 
355       }
356 
357     }
358 
359     BasicBlock *LBB, *RBB;
360     LBB = switchConvert(LHSCases, bytesChecked, OrigBlock, NewDefault, Val,
361                         level + 1);
362     RBB = switchConvert(RHSCases, bytesChecked, OrigBlock, NewDefault, Val,
363                         level + 1);
364 
365     /* insert instructions to check whether the value we are switching on is
366      * equal to byte */
367     ICmpInst *Comp =
368         new ICmpInst(ICmpInst::ICMP_ULT, Trunc,
369                      ConstantInt::get(ByteType, pivot), "byteMatch");
370 #if LLVM_VERSION_MAJOR >= 16
371     Comp->insertInto(NewNode, NewNode->end());
372 #else
373     NewNode->getInstList().push_back(Comp);
374 #endif
375     BranchInst::Create(LBB, RBB, Comp, NewNode);
376 
377   }
378 
379   return NewNode;
380 
381 }
382 
383 bool SplitSwitchesTransform::splitSwitches(Module &M) {
384 
385 #if (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR < 7)
386   LLVMContext &C = M.getContext();
387 #endif
388 
389   std::vector<SwitchInst *> switches;
390 
391   /* iterate over all functions, bbs and instruction and add
392    * all switches to switches vector for later processing */
393   for (auto &F : M) {
394 
395     if (!isInInstrumentList(&F, MNAME)) continue;
396 
397     for (auto &BB : F) {
398 
399       SwitchInst *switchInst = nullptr;
400 
401       if ((switchInst = dyn_cast<SwitchInst>(BB.getTerminator()))) {
402 
403         if (switchInst->getNumCases() < 1) continue;
404         switches.push_back(switchInst);
405 
406       }
407 
408     }
409 
410   }
411 
412   if (!switches.size()) return false;
413   /*
414     if (!be_quiet)
415       errs() << "Rewriting " << switches.size() << " switch statements "
416              << "\n";
417   */
418   for (auto &SI : switches) {
419 
420     BasicBlock *CurBlock = SI->getParent();
421     BasicBlock *OrigBlock = CurBlock;
422     Function   *F = CurBlock->getParent();
423     /* this is the value we are switching on */
424     Value      *Val = SI->getCondition();
425     BasicBlock *Default = SI->getDefaultDest();
426     unsigned    bitw = Val->getType()->getIntegerBitWidth();
427 
428     /*
429         if (!be_quiet)
430           errs() << "switch: " << SI->getNumCases() << " cases " << bitw
431                  << " bit\n";
432     */
433 
434     /* If there is only the default destination or the condition checks 8 bit or
435      * less, don't bother with the code below. */
436     if (SI->getNumCases() < 2 || bitw % 8 || bitw > 64) {
437 
438       // if (!be_quiet) errs() << "skip switch..\n";
439       continue;
440 
441     }
442 
443     /* Create a new, empty default block so that the new hierarchy of
444      * if-then statements go to this and the PHI nodes are happy.
445      * if the default block is set as an unreachable we avoid creating one
446      * because will never be a valid target.*/
447     BasicBlock *NewDefault = nullptr;
448     NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault", F, Default);
449     BranchInst::Create(Default, NewDefault);
450 
451     /* Prepare cases vector. */
452     CaseVector Cases;
453     for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e;
454          ++i)
455 #if LLVM_VERSION_MAJOR >= 5
456       Cases.push_back(CaseExpr(i->getCaseValue(), i->getCaseSuccessor()));
457 #else
458       Cases.push_back(CaseExpr(i.getCaseValue(), i.getCaseSuccessor()));
459 #endif
460     /* bugfix thanks to pbst
461      * round up bytesChecked (in case getBitWidth() % 8 != 0) */
462     std::vector<bool> bytesChecked((7 + Cases[0].Val->getBitWidth()) / 8,
463                                    false);
464     BasicBlock       *SwitchBlock =
465         switchConvert(Cases, bytesChecked, OrigBlock, NewDefault, Val, 0);
466 
467     /* Branch to our shiny new if-then stuff... */
468     BranchInst::Create(SwitchBlock, OrigBlock);
469 
470     /* We are now done with the switch instruction, delete it. */
471 #if LLVM_VERSION_MAJOR >= 16
472     SI->eraseFromParent();
473 #else
474     CurBlock->getInstList().erase(SI);
475 #endif
476 
477     /* we have to update the phi nodes! */
478     for (BasicBlock::iterator I = Default->begin(); I != Default->end(); ++I) {
479 
480       if (!isa<PHINode>(&*I)) { continue; }
481       PHINode *PN = cast<PHINode>(I);
482 
483       /* Only update the first occurrence. */
484       unsigned Idx = 0, E = PN->getNumIncomingValues();
485       for (; Idx != E; ++Idx) {
486 
487         if (PN->getIncomingBlock(Idx) == OrigBlock) {
488 
489           PN->setIncomingBlock(Idx, NewDefault);
490           break;
491 
492         }
493 
494       }
495 
496     }
497 
498   }
499 
500   verifyModule(M);
501   return true;
502 
503 }
504 
505 #if LLVM_VERSION_MAJOR >= 11                        /* use new pass manager */
506 PreservedAnalyses SplitSwitchesTransform::run(Module                &M,
507                                               ModuleAnalysisManager &MAM) {
508 
509 #else
510 bool SplitSwitchesTransform::runOnModule(Module &M) {
511 
512 #endif
513 
514   if ((isatty(2) && getenv("AFL_QUIET") == NULL) || getenv("AFL_DEBUG") != NULL)
515     printf("Running split-switches-pass by [email protected]\n");
516   else
517     be_quiet = 1;
518 
519 #if LLVM_VERSION_MAJOR >= 11                        /* use new pass manager */
520   auto PA = PreservedAnalyses::all();
521 #endif
522 
523   splitSwitches(M);
524   verifyModule(M);
525 
526 #if LLVM_VERSION_MAJOR >= 11                        /* use new pass manager */
527                              /*  if (modified) {
528 
529                                  PA.abandon<XX_Manager>();
530 
531                                }*/
532 
533   return PA;
534 #else
535   return true;
536 #endif
537 
538 }
539 
540 #if LLVM_VERSION_MAJOR < 11                         /* use old pass manager */
541 static void registerSplitSwitchesTransPass(const PassManagerBuilder &,
542                                            legacy::PassManagerBase &PM) {
543 
544   auto p = new SplitSwitchesTransform();
545   PM.add(p);
546 
547 }
548 
549 static RegisterStandardPasses RegisterSplitSwitchesTransPass(
550     PassManagerBuilder::EP_OptimizerLast, registerSplitSwitchesTransPass);
551 
552 static RegisterStandardPasses RegisterSplitSwitchesTransPass0(
553     PassManagerBuilder::EP_EnabledOnOptLevel0, registerSplitSwitchesTransPass);
554 
555   #if LLVM_VERSION_MAJOR >= 11
556 static RegisterStandardPasses RegisterSplitSwitchesTransPassLTO(
557     PassManagerBuilder::EP_FullLinkTimeOptimizationLast,
558     registerSplitSwitchesTransPass);
559   #endif
560 #endif
561 
562