1 /*
2 american fuzzy lop++ - LLVM CmpLog instrumentation
3 --------------------------------------------------
4
5 Written by Andrea Fioraldi <[email protected]>
6
7 Copyright 2015, 2016 Google Inc. All rights reserved.
8 Copyright 2019-2024 AFLplusplus Project. All rights reserved.
9
10 Licensed under the Apache License, Version 2.0 (the "License");
11 you may not use this file except in compliance with the License.
12 You may obtain a copy of the License at:
13
14 https://www.apache.org/licenses/LICENSE-2.0
15
16 */
17
18 #include <stdio.h>
19 #include <stdlib.h>
20 #include <unistd.h>
21
22 #include <iostream>
23 #include <list>
24 #include <string>
25 #include <fstream>
26 #include <sys/time.h>
27
28 #include "llvm/Config/llvm-config.h"
29 #include "llvm/ADT/Statistic.h"
30 #include "llvm/IR/IRBuilder.h"
31 #include "llvm/IR/Module.h"
32 #include "llvm/Support/Debug.h"
33 #include "llvm/Support/raw_ostream.h"
34 #if LLVM_MAJOR >= 11
35 #include "llvm/Passes/PassPlugin.h"
36 #include "llvm/Passes/PassBuilder.h"
37 #include "llvm/IR/PassManager.h"
38 #else
39 #include "llvm/IR/LegacyPassManager.h"
40 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
41 #endif
42 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
43 #include "llvm/Pass.h"
44 #include "llvm/Analysis/ValueTracking.h"
45 #if LLVM_VERSION_MAJOR >= 14 /* how about stable interfaces? */
46 #include "llvm/Passes/OptimizationLevel.h"
47 #endif
48
49 #if LLVM_VERSION_MAJOR >= 4 || \
50 (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4)
51 #include "llvm/IR/Verifier.h"
52 #include "llvm/IR/DebugInfo.h"
53 #include "llvm/Support/raw_ostream.h"
54 #else
55 #include "llvm/Analysis/Verifier.h"
56 #include "llvm/DebugInfo.h"
57 #define nullptr 0
58 #endif
59
60 #include <set>
61 #include "afl-llvm-common.h"
62
63 using namespace llvm;
64
65 namespace {
66
67 #if LLVM_MAJOR >= 11 /* use new pass manager */
68 class CmpLogInstructions : public PassInfoMixin<CmpLogInstructions> {
69
70 public:
CmpLogInstructions()71 CmpLogInstructions() {
72
73 initInstrumentList();
74
75 }
76
77 #else
78 class CmpLogInstructions : public ModulePass {
79
80 public:
81 static char ID;
82 CmpLogInstructions() : ModulePass(ID) {
83
84 initInstrumentList();
85
86 }
87
88 #endif
89
90 #if LLVM_MAJOR >= 11 /* use new pass manager */
91 PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
92 #else
93 bool runOnModule(Module &M) override;
94
95 #if LLVM_VERSION_MAJOR >= 4
getPassName() const96 StringRef getPassName() const override {
97
98 #else
99 const char *getPassName() const override {
100
101 #endif
102 return "cmplog instructions";
103
104 }
105
106 #endif
107
108 private:
109 bool hookInstrs(Module &M);
110
111 };
112
113 } // namespace
114
115 #if LLVM_MAJOR >= 11
116 extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
llvmGetPassPluginInfo()117 llvmGetPassPluginInfo() {
118
119 return {LLVM_PLUGIN_API_VERSION, "cmploginstructions", "v0.1",
120 /* lambda to insert our pass into the pass pipeline. */
121 [](PassBuilder &PB) {
122
123 #if LLVM_VERSION_MAJOR <= 13
124 using OptimizationLevel = typename PassBuilder::OptimizationLevel;
125 #endif
126 PB.registerOptimizerLastEPCallback(
127 [](ModulePassManager &MPM, OptimizationLevel OL) {
128
129 MPM.addPass(CmpLogInstructions());
130
131 });
132
133 }};
134
135 }
136
137 #else
138 char CmpLogInstructions::ID = 0;
139 #endif
140
/* Remove every later duplicate from [first, last), keeping the first
   occurrence of each value and the relative order of the kept elements.
   Only operator== is required (the range need not be sorted); the cost is
   quadratic.  As with std::remove, the returned iterator is the new logical
   end and elements past it are in a valid but unspecified state. */
template <class Iterator>
Iterator Unique(Iterator first, Iterator last) {

  for (Iterator keep = first; keep != last; ++keep) {

    Iterator rest = keep;
    ++rest;
    /* Squeeze out all copies of *keep that follow it. */
    last = std::remove(rest, last, *keep);

  }

  return last;

}
155
hookInstrs(Module & M)156 bool CmpLogInstructions::hookInstrs(Module &M) {
157
158 std::vector<Instruction *> icomps;
159 LLVMContext &C = M.getContext();
160
161 Type *VoidTy = Type::getVoidTy(C);
162 IntegerType *Int8Ty = IntegerType::getInt8Ty(C);
163 IntegerType *Int16Ty = IntegerType::getInt16Ty(C);
164 IntegerType *Int32Ty = IntegerType::getInt32Ty(C);
165 IntegerType *Int64Ty = IntegerType::getInt64Ty(C);
166 IntegerType *Int128Ty = IntegerType::getInt128Ty(C);
167
168 /*
169 #if LLVM_VERSION_MAJOR >= 9
170 FunctionCallee
171 #else
172 Constant *
173 #endif
174 c1 = M.getOrInsertFunction("__cmplog_ins_hook1", VoidTy, Int8Ty, Int8Ty,
175 Int8Ty
176 #if LLVM_VERSION_MAJOR < 5
177 ,
178 NULL
179 #endif
180 );
181 #if LLVM_VERSION_MAJOR >= 9
182 FunctionCallee cmplogHookIns1 = c1;
183 #else
184 Function *cmplogHookIns1 = cast<Function>(c1);
185 #endif
186 */
187
188 #if LLVM_VERSION_MAJOR >= 9
189 FunctionCallee
190 #else
191 Constant *
192 #endif
193 c2 = M.getOrInsertFunction("__cmplog_ins_hook2", VoidTy, Int16Ty, Int16Ty,
194 Int8Ty
195 #if LLVM_VERSION_MAJOR < 5
196 ,
197 NULL
198 #endif
199 );
200 #if LLVM_VERSION_MAJOR >= 9
201 FunctionCallee cmplogHookIns2 = c2;
202 #else
203 Function *cmplogHookIns2 = cast<Function>(c2);
204 #endif
205
206 #if LLVM_VERSION_MAJOR >= 9
207 FunctionCallee
208 #else
209 Constant *
210 #endif
211 c4 = M.getOrInsertFunction("__cmplog_ins_hook4", VoidTy, Int32Ty, Int32Ty,
212 Int8Ty
213 #if LLVM_VERSION_MAJOR < 5
214 ,
215 NULL
216 #endif
217 );
218 #if LLVM_VERSION_MAJOR >= 9
219 FunctionCallee cmplogHookIns4 = c4;
220 #else
221 Function *cmplogHookIns4 = cast<Function>(c4);
222 #endif
223
224 #if LLVM_VERSION_MAJOR >= 9
225 FunctionCallee
226 #else
227 Constant *
228 #endif
229 c8 = M.getOrInsertFunction("__cmplog_ins_hook8", VoidTy, Int64Ty, Int64Ty,
230 Int8Ty
231 #if LLVM_VERSION_MAJOR < 5
232 ,
233 NULL
234 #endif
235 );
236 #if LLVM_VERSION_MAJOR >= 9
237 FunctionCallee cmplogHookIns8 = c8;
238 #else
239 Function *cmplogHookIns8 = cast<Function>(c8);
240 #endif
241
242 #if LLVM_VERSION_MAJOR >= 9
243 FunctionCallee
244 #else
245 Constant *
246 #endif
247 c16 = M.getOrInsertFunction("__cmplog_ins_hook16", VoidTy, Int128Ty,
248 Int128Ty, Int8Ty
249 #if LLVM_VERSION_MAJOR < 5
250 ,
251 NULL
252 #endif
253 );
254 #if LLVM_VERSION_MAJOR < 9
255 Function *cmplogHookIns16 = cast<Function>(c16);
256 #else
257 FunctionCallee cmplogHookIns16 = c16;
258 #endif
259
260 #if LLVM_VERSION_MAJOR >= 9
261 FunctionCallee
262 #else
263 Constant *
264 #endif
265 cN = M.getOrInsertFunction("__cmplog_ins_hookN", VoidTy, Int128Ty,
266 Int128Ty, Int8Ty, Int8Ty
267 #if LLVM_VERSION_MAJOR < 5
268 ,
269 NULL
270 #endif
271 );
272 #if LLVM_VERSION_MAJOR >= 9
273 FunctionCallee cmplogHookInsN = cN;
274 #else
275 Function *cmplogHookInsN = cast<Function>(cN);
276 #endif
277
278 GlobalVariable *AFLCmplogPtr = M.getNamedGlobal("__afl_cmp_map");
279
280 if (!AFLCmplogPtr) {
281
282 AFLCmplogPtr = new GlobalVariable(M, PointerType::get(Int8Ty, 0), false,
283 GlobalValue::ExternalWeakLinkage, 0,
284 "__afl_cmp_map");
285
286 }
287
288 Constant *Null = Constant::getNullValue(PointerType::get(Int8Ty, 0));
289
290 /* iterate over all functions, bbs and instruction and add suitable calls */
291 for (auto &F : M) {
292
293 if (!isInInstrumentList(&F, MNAME)) continue;
294
295 for (auto &BB : F) {
296
297 for (auto &IN : BB) {
298
299 CmpInst *selectcmpInst = nullptr;
300 if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
301
302 icomps.push_back(selectcmpInst);
303
304 }
305
306 }
307
308 }
309
310 }
311
312 if (icomps.size()) {
313
314 // if (!be_quiet) errs() << "Hooking " << icomps.size() <<
315 // " cmp instructions\n";
316
317 for (auto &selectcmpInst : icomps) {
318
319 IRBuilder<> IRB2(selectcmpInst->getParent());
320 IRB2.SetInsertPoint(selectcmpInst);
321 LoadInst *CmpPtr = IRB2.CreateLoad(
322 #if LLVM_VERSION_MAJOR >= 14
323 PointerType::get(Int8Ty, 0),
324 #endif
325 AFLCmplogPtr);
326 CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None));
327 auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null);
328 auto ThenTerm =
329 SplitBlockAndInsertIfThen(is_not_null, selectcmpInst, false);
330
331 IRBuilder<> IRB(ThenTerm);
332
333 Value *op0 = selectcmpInst->getOperand(0);
334 Value *op1 = selectcmpInst->getOperand(1);
335 Value *op0_saved = op0, *op1_saved = op1;
336 auto ty0 = op0->getType();
337 auto ty1 = op1->getType();
338
339 IntegerType *intTyOp0 = NULL;
340 IntegerType *intTyOp1 = NULL;
341 unsigned max_size = 0, cast_size = 0;
342 unsigned attr = 0, vector_cnt = 0, is_fp = 0;
343 CmpInst *cmpInst = dyn_cast<CmpInst>(selectcmpInst);
344
345 if (!cmpInst) { continue; }
346
347 switch (cmpInst->getPredicate()) {
348
349 case CmpInst::ICMP_NE:
350 case CmpInst::FCMP_UNE:
351 case CmpInst::FCMP_ONE:
352 break;
353 case CmpInst::ICMP_EQ:
354 case CmpInst::FCMP_UEQ:
355 case CmpInst::FCMP_OEQ:
356 attr += 1;
357 break;
358 case CmpInst::ICMP_UGT:
359 case CmpInst::ICMP_SGT:
360 case CmpInst::FCMP_OGT:
361 case CmpInst::FCMP_UGT:
362 attr += 2;
363 break;
364 case CmpInst::ICMP_UGE:
365 case CmpInst::ICMP_SGE:
366 case CmpInst::FCMP_OGE:
367 case CmpInst::FCMP_UGE:
368 attr += 3;
369 break;
370 case CmpInst::ICMP_ULT:
371 case CmpInst::ICMP_SLT:
372 case CmpInst::FCMP_OLT:
373 case CmpInst::FCMP_ULT:
374 attr += 4;
375 break;
376 case CmpInst::ICMP_ULE:
377 case CmpInst::ICMP_SLE:
378 case CmpInst::FCMP_OLE:
379 case CmpInst::FCMP_ULE:
380 attr += 5;
381 break;
382 default:
383 break;
384
385 }
386
387 if (selectcmpInst->getOpcode() == Instruction::FCmp) {
388
389 if (ty0->isVectorTy()) {
390
391 VectorType *tt = dyn_cast<VectorType>(ty0);
392 if (!tt) {
393
394 fprintf(stderr, "Warning: cmplog cmp vector is not a vector!\n");
395 continue;
396
397 }
398
399 #if (LLVM_VERSION_MAJOR >= 12)
400 vector_cnt = tt->getElementCount().getKnownMinValue();
401 ty0 = tt->getElementType();
402 #endif
403
404 }
405
406 if (ty0->isHalfTy()
407 #if LLVM_VERSION_MAJOR >= 11
408 || ty0->isBFloatTy()
409 #endif
410 )
411 max_size = 16;
412 else if (ty0->isFloatTy())
413 max_size = 32;
414 else if (ty0->isDoubleTy())
415 max_size = 64;
416 else if (ty0->isX86_FP80Ty())
417 max_size = 80;
418 else if (ty0->isFP128Ty() || ty0->isPPC_FP128Ty())
419 max_size = 128;
420 #if (LLVM_VERSION_MAJOR >= 12)
421 else if (ty0->getTypeID() != llvm::Type::PointerTyID && !be_quiet)
422 fprintf(stderr, "Warning: unsupported cmp type for cmplog: %u!\n",
423 ty0->getTypeID());
424 #endif
425
426 attr += 8;
427 is_fp = 1;
428 // fprintf(stderr, "HAVE FP %u!\n", vector_cnt);
429
430 } else {
431
432 if (ty0->isVectorTy()) {
433
434 #if (LLVM_VERSION_MAJOR >= 12)
435 VectorType *tt = dyn_cast<VectorType>(ty0);
436 if (!tt) {
437
438 fprintf(stderr, "Warning: cmplog cmp vector is not a vector!\n");
439 continue;
440
441 }
442
443 vector_cnt = tt->getElementCount().getKnownMinValue();
444 ty1 = ty0 = tt->getElementType();
445 #endif
446
447 }
448
449 intTyOp0 = dyn_cast<IntegerType>(ty0);
450 intTyOp1 = dyn_cast<IntegerType>(ty1);
451
452 if (intTyOp0 && intTyOp1) {
453
454 max_size = intTyOp0->getBitWidth() > intTyOp1->getBitWidth()
455 ? intTyOp0->getBitWidth()
456 : intTyOp1->getBitWidth();
457
458 } else {
459
460 #if (LLVM_VERSION_MAJOR >= 12)
461 if (ty0->getTypeID() != llvm::Type::PointerTyID && !be_quiet) {
462
463 fprintf(stderr, "Warning: unsupported cmp type for cmplog: %u\n",
464 ty0->getTypeID());
465
466 }
467
468 #endif
469
470 }
471
472 }
473
474 if (!max_size || max_size < 16) {
475
476 // fprintf(stderr, "too small\n");
477 continue;
478
479 }
480
481 if (max_size % 8) { max_size = (((max_size / 8) + 1) * 8); }
482
483 if (max_size > 128) {
484
485 if (!be_quiet) {
486
487 fprintf(stderr,
488 "Cannot handle this compare bit size: %u (truncating)\n",
489 max_size);
490
491 }
492
493 max_size = 128;
494
495 }
496
497 // do we need to cast?
498 switch (max_size) {
499
500 case 8:
501 case 16:
502 case 32:
503 case 64:
504 case 128:
505 cast_size = max_size;
506 break;
507 default:
508 cast_size = 128;
509
510 }
511
512 // XXX FIXME BUG TODO
513 if (is_fp && vector_cnt) { continue; }
514
515 uint64_t cur = 0, last_val0 = 0, last_val1 = 0, cur_val;
516
517 while (1) {
518
519 std::vector<Value *> args;
520 bool skip = false;
521
522 if (vector_cnt) {
523
524 op0 = IRB.CreateExtractElement(op0_saved, cur);
525 op1 = IRB.CreateExtractElement(op1_saved, cur);
526 /*
527 std::string errMsg;
528 raw_string_ostream os(errMsg);
529 op0_saved->print(os);
530 fprintf(stderr, "X: %s\n", os.str().c_str());
531 */
532 if (is_fp) {
533
534 /*
535 ConstantFP *i0 = dyn_cast<ConstantFP>(op0);
536 ConstantFP *i1 = dyn_cast<ConstantFP>(op1);
537 // BUG FIXME TODO: this is null ... but why?
538 // fprintf(stderr, "%p %p\n", i0, i1);
539 if (i0) {
540
541 cur_val = (uint64_t)i0->getValue().convertToDouble();
542 if (last_val0 && last_val0 == cur_val) { skip = true;
543
544 } last_val0 = cur_val;
545
546 }
547
548 if (i1) {
549
550 cur_val = (uint64_t)i1->getValue().convertToDouble();
551 if (last_val1 && last_val1 == cur_val) { skip = true;
552
553 } last_val1 = cur_val;
554
555 }
556
557 */
558
559 } else {
560
561 ConstantInt *i0 = dyn_cast<ConstantInt>(op0);
562 ConstantInt *i1 = dyn_cast<ConstantInt>(op1);
563 if (i0 && i0->uge(0xffffffffffffffff) == false) {
564
565 cur_val = i0->getZExtValue();
566 if (last_val0 && last_val0 == cur_val) { skip = true; }
567 last_val0 = cur_val;
568
569 }
570
571 if (i1 && i1->uge(0xffffffffffffffff) == false) {
572
573 cur_val = i1->getZExtValue();
574 if (last_val1 && last_val1 == cur_val) { skip = true; }
575 last_val1 = cur_val;
576
577 }
578
579 }
580
581 }
582
583 if (!skip) {
584
585 // errs() << "[CMPLOG] cmp " << *cmpInst << "(in function " <<
586 // cmpInst->getFunction()->getName() << ")\n";
587
588 // first bitcast to integer type of the same bitsize as the original
589 // type (this is a nop, if already integer)
590 Value *op0_i = IRB.CreateBitCast(
591 op0, IntegerType::get(C, ty0->getPrimitiveSizeInBits()));
592 // then create a int cast, which does zext, trunc or bitcast. In our
593 // case usually zext to the next larger supported type (this is a nop
594 // if already the right type)
595 Value *V0 =
596 IRB.CreateIntCast(op0_i, IntegerType::get(C, cast_size), false);
597 args.push_back(V0);
598 Value *op1_i = IRB.CreateBitCast(
599 op1, IntegerType::get(C, ty1->getPrimitiveSizeInBits()));
600 Value *V1 =
601 IRB.CreateIntCast(op1_i, IntegerType::get(C, cast_size), false);
602 args.push_back(V1);
603
604 // errs() << "[CMPLOG] casted parameters:\n0: " << *V0 << "\n1: " <<
605 // *V1
606 // << "\n";
607
608 ConstantInt *attribute = ConstantInt::get(Int8Ty, attr);
609 args.push_back(attribute);
610
611 if (cast_size != max_size) {
612
613 ConstantInt *bitsize = ConstantInt::get(Int8Ty, (max_size / 8) - 1);
614 args.push_back(bitsize);
615
616 }
617
618 // fprintf(stderr, "_ExtInt(%u) castTo %u with attr %u didcast %u\n",
619 // max_size, cast_size, attr);
620
621 switch (cast_size) {
622
623 case 8:
624 // IRB.CreateCall(cmplogHookIns1, args);
625 break;
626 case 16:
627 IRB.CreateCall(cmplogHookIns2, args);
628 break;
629 case 32:
630 IRB.CreateCall(cmplogHookIns4, args);
631 break;
632 case 64:
633 IRB.CreateCall(cmplogHookIns8, args);
634 break;
635 case 128:
636 if (max_size == 128) {
637
638 IRB.CreateCall(cmplogHookIns16, args);
639
640 } else {
641
642 IRB.CreateCall(cmplogHookInsN, args);
643
644 }
645
646 break;
647
648 }
649
650 }
651
652 /* else fprintf(stderr, "skipped\n"); */
653
654 ++cur;
655 if (cur >= vector_cnt) { break; }
656
657 }
658
659 }
660
661 }
662
663 if (icomps.size())
664 return true;
665 else
666 return false;
667
668 }
669
670 #if LLVM_MAJOR >= 11 /* use new pass manager */
run(Module & M,ModuleAnalysisManager & MAM)671 PreservedAnalyses CmpLogInstructions::run(Module &M,
672 ModuleAnalysisManager &MAM) {
673
674 #else
675 bool CmpLogInstructions::runOnModule(Module &M) {
676
677 #endif
678
679 if (getenv("AFL_QUIET") == NULL)
680 printf("Running cmplog-instructions-pass by [email protected]\n");
681 else
682 be_quiet = 1;
683 hookInstrs(M);
684 verifyModule(M);
685
686 #if LLVM_MAJOR >= 11 /* use new pass manager */
687 return PreservedAnalyses::all();
688 #else
689 return true;
690 #endif
691
692 }
693
694 #if LLVM_MAJOR < 11 /* use old pass manager */
695 static void registerCmpLogInstructionsPass(const PassManagerBuilder &,
696 legacy::PassManagerBase &PM) {
697
698 auto p = new CmpLogInstructions();
699 PM.add(p);
700
701 }
702
703 static RegisterStandardPasses RegisterCmpLogInstructionsPass(
704 PassManagerBuilder::EP_OptimizerLast, registerCmpLogInstructionsPass);
705
706 static RegisterStandardPasses RegisterCmpLogInstructionsPass0(
707 PassManagerBuilder::EP_EnabledOnOptLevel0, registerCmpLogInstructionsPass);
708
709 #if LLVM_VERSION_MAJOR >= 11
710 static RegisterStandardPasses RegisterCmpLogInstructionsPassLTO(
711 PassManagerBuilder::EP_FullLinkTimeOptimizationLast,
712 registerCmpLogInstructionsPass);
713 #endif
714 #endif
715
716