// Copyright 2021 Code Intelligence GmbH
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 package com.code_intelligence.jazzer.instrumentor
16 
17 import org.objectweb.asm.ClassReader
18 import org.objectweb.asm.ClassWriter
19 import org.objectweb.asm.Opcodes
20 import org.objectweb.asm.tree.AbstractInsnNode
21 import org.objectweb.asm.tree.ClassNode
22 import org.objectweb.asm.tree.InsnList
23 import org.objectweb.asm.tree.InsnNode
24 import org.objectweb.asm.tree.IntInsnNode
25 import org.objectweb.asm.tree.LdcInsnNode
26 import org.objectweb.asm.tree.LookupSwitchInsnNode
27 import org.objectweb.asm.tree.MethodInsnNode
28 import org.objectweb.asm.tree.MethodNode
29 import org.objectweb.asm.tree.TableSwitchInsnNode
30 
/**
 * Inserts calls to static callbacks on [callbackInternalClassName] that report the operands of
 * compares, switches, integer divisions and constant-index loads found in a class's bytecode.
 *
 * @param types the kinds of instrumentation to apply; instructions belonging to a kind that is not
 *   contained in this set are left untouched.
 * @param callbackInternalClassName internal (slash-separated) name of the class declaring the static
 *   `trace*` callback methods invoked by the inserted instructions.
 */
internal class TraceDataFlowInstrumentor(
    private val types: Set<InstrumentationType>,
    private val callbackInternalClassName: String = "com/code_intelligence/jazzer/runtime/TraceDataFlowNativeCallbacks",
) : Instrumentor {

    // Source of the fake PCs passed to the callbacks. Re-created at the start of every instrument() call
    // from the fixed tag "trace" and the name of the class being instrumented.
    private lateinit var random: DeterministicRandom

    /**
     * Parses [bytecode], adds data flow instrumentation to every method accepted by `shouldInstrument`
     * and returns the re-serialized class.
     *
     * @param internalClassName not used by this instrumentor; the class name is read from the parsed bytecode.
     * @param bytecode the class file contents to instrument.
     * @return the instrumented class file contents.
     */
    override fun instrument(internalClassName: String, bytecode: ByteArray): ByteArray {
        val node = ClassNode()
        val reader = ClassReader(bytecode)
        reader.accept(node, 0)
        random = DeterministicRandom("trace", node.name)
        for (method in node.methods) {
            if (shouldInstrument(method)) {
                addDataFlowInstrumentation(method)
            }
        }

        // The inserted instructions push additional values onto the operand stack, so let ASM recompute
        // the maximum stack size.
        val writer = ClassWriter(ClassWriter.COMPUTE_MAXS)
        node.accept(writer)
        return writer.toByteArray()
    }

    /**
     * Walks over all instructions of [method] and inserts the callback invocation matching each
     * instrumentable instruction, honoring the enabled [types].
     */
    private fun addDataFlowInstrumentation(method: MethodNode) {
        // Iterate over a snapshot as the instruction list is modified while looping.
        loop@ for (inst in method.instructions.toArray()) {
            when (inst.opcode) {
                Opcodes.LCMP -> {
                    if (InstrumentationType.CMP !in types) continue@loop
                    // The wrapper call replaces LCMP as it returns the result of the comparison itself.
                    method.instructions.insertBefore(inst, longCmpInstrumentation())
                    method.instructions.remove(inst)
                }
                Opcodes.IF_ICMPEQ, Opcodes.IF_ICMPNE,
                Opcodes.IF_ICMPLT, Opcodes.IF_ICMPLE,
                Opcodes.IF_ICMPGT, Opcodes.IF_ICMPGE,
                -> {
                    if (InstrumentationType.CMP !in types) continue@loop
                    method.instructions.insertBefore(inst, intCmpInstrumentation())
                }
                Opcodes.IFEQ, Opcodes.IFNE,
                Opcodes.IFLT, Opcodes.IFLE,
                Opcodes.IFGT, Opcodes.IFGE,
                -> {
                    if (InstrumentationType.CMP !in types) continue@loop
                    // The IF* opcodes are often used to branch based on the result of a compare
                    // instruction for a type other than int. The operands of this compare will
                    // already be reported via the instrumentation above (for non-floating point
                    // numbers) and the follow-up compare does not provide a good signal as all
                    // operands will be in {-1, 0, 1}. Skip instrumentation for it.
                    if (followsNonIntCompare(inst)) continue@loop
                    method.instructions.insertBefore(inst, ifInstrumentation())
                }
                Opcodes.LOOKUPSWITCH, Opcodes.TABLESWITCH -> {
                    if (InstrumentationType.CMP !in types) continue@loop
                    val caseValues = reportedCaseValues(inst) ?: continue@loop
                    method.instructions.insertBefore(inst, switchInstrumentation(caseValues))
                }
                Opcodes.IDIV -> {
                    if (InstrumentationType.DIV !in types) continue@loop
                    method.instructions.insertBefore(inst, intDivInstrumentation())
                }
                Opcodes.LDIV -> {
                    if (InstrumentationType.DIV !in types) continue@loop
                    method.instructions.insertBefore(inst, longDivInstrumentation())
                }
                Opcodes.AALOAD, Opcodes.BALOAD,
                Opcodes.CALOAD, Opcodes.DALOAD,
                Opcodes.FALOAD, Opcodes.IALOAD,
                Opcodes.LALOAD, Opcodes.SALOAD,
                -> {
                    if (InstrumentationType.GEP !in types) continue@loop
                    // Only loads whose index is pushed by the directly preceding constant push are reported.
                    if (!isConstantIntegerPushInsn(inst.previous)) continue@loop
                    method.instructions.insertBefore(inst, gepLoadInstrumentation())
                }
                Opcodes.INVOKEINTERFACE, Opcodes.INVOKESPECIAL, Opcodes.INVOKESTATIC, Opcodes.INVOKEVIRTUAL -> {
                    if (InstrumentationType.GEP !in types) continue@loop
                    if (!isGepLoadMethodInsn(inst as MethodInsnNode)) continue@loop
                    if (!isConstantIntegerPushInsn(inst.previous)) continue@loop
                    method.instructions.insertBefore(inst, gepLoadInstrumentation())
                }
            }
        }
    }

    /**
     * Returns true if the instruction directly preceding [inst] is a floating point compare
     * (`FCMPL`, `FCMPG`, `DCMPL`, `DCMPG`) or the call to the long compare wrapper that this
     * instrumentor emits in place of `LCMP`.
     */
    private fun followsNonIntCompare(inst: AbstractInsnNode): Boolean {
        val previous = inst.previous
        return previous?.opcode in FLOATING_POINT_CMP_OPCODES ||
            (previous as? MethodInsnNode)?.name == TRACE_CMP_LONG_WRAPPER
    }

    /**
     * Returns the case values of the switch instruction [inst] that should be reported, sorted by their
     * unsigned 32-bit value, or null if the switch should not be instrumented.
     */
    private fun reportedCaseValues(inst: AbstractInsnNode): LongArray? {
        // Mimic the exclusion logic for small label values in libFuzzer:
        // https://github.com/llvm-mirror/compiler-rt/blob/69445f095c22aac2388f939bedebf224a6efcdaf/lib/fuzzer/FuzzerTracePC.cpp#L520
        val caseValues: List<Int> = when (inst) {
            is LookupSwitchInsnNode -> {
                if (inst.keys.isEmpty() || (0 <= inst.keys.first() && inst.keys.last() < 256)) {
                    return null
                }
                inst.keys
            }
            is TableSwitchInsnNode -> {
                if (0 <= inst.min && inst.max < 256) {
                    return null
                }
                (inst.min..inst.max).filter { caseValue ->
                    val index = caseValue - inst.min
                    // Filter out "gap cases", i.e. values that just jump to the default label.
                    inst.labels[index].label != inst.dflt.label
                }
            }
            // Not reached when called with a LOOKUPSWITCH or TABLESWITCH instruction.
            else -> return null
        }
        // Case values are reported to libFuzzer via an array of unsigned long values and thus need to be
        // sorted by unsigned value.
        return caseValues.sortedBy { it.toUInt() }.map { it.toLong() }.toLongArray()
    }

    /** Pushes an int in `[0, 512)` drawn from [random] that is passed to the callbacks as their last argument. */
    private fun InsnList.pushFakePc() {
        add(LdcInsnNode(random.nextInt(512)))
    }

    /** Instructions that consume the two long operands of an `LCMP` and leave its int result on the stack. */
    private fun longCmpInstrumentation() = InsnList().apply {
        pushFakePc()
        // traceCmpLong returns the result of the comparison as duplicating two longs on the stack
        // is not possible without local variables.
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, TRACE_CMP_LONG_WRAPPER, "(JJI)I", false))
    }

    /** Instructions that report both int operands of an `IF_ICMP*` while leaving them on the stack. */
    private fun intCmpInstrumentation() = InsnList().apply {
        // Duplicate both int operands so that the originals remain for the branch instruction.
        add(InsnNode(Opcodes.DUP2))
        pushFakePc()
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceCmpInt", "(III)V", false))
    }

    /** Instructions that report the int operand of an `IF*` as a compare against the constant 0. */
    private fun ifInstrumentation() = InsnList().apply {
        add(InsnNode(Opcodes.DUP))
        // All if* instructions are compares to the constant 0.
        add(InsnNode(Opcodes.ICONST_0))
        // Reorder to (0, value) so that the constant is passed as the first argument.
        add(InsnNode(Opcodes.SWAP))
        pushFakePc()
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceConstCmpInt", "(III)V", false))
    }

    /** Instructions that report the int on top of the stack (the divisor of an `IDIV`). */
    private fun intDivInstrumentation() = InsnList().apply {
        add(InsnNode(Opcodes.DUP))
        pushFakePc()
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceDivInt", "(II)V", false))
    }

    /** Instructions that report the long on top of the stack (the divisor of an `LDIV`). */
    private fun longDivInstrumentation() = InsnList().apply {
        // A long occupies two stack slots, hence DUP2.
        add(InsnNode(Opcodes.DUP2))
        pushFakePc()
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceDivLong", "(JI)V", false))
    }

    /**
     * Instructions that report the key of a switch together with its case values.
     *
     * @param caseValues the case values to report, in the order in which they are stored into the array.
     */
    private fun switchInstrumentation(caseValues: LongArray) = InsnList().apply {
        // duplicate {lookup,table}switch key for use as first function argument
        add(InsnNode(Opcodes.DUP))
        add(InsnNode(Opcodes.I2L))
        // Set up array with switch case values. The format libfuzzer expects is created here directly, i.e., the first
        // two entries are the number of cases and the bit size of values (always 32).
        add(IntInsnNode(Opcodes.SIPUSH, caseValues.size + 2))
        add(IntInsnNode(Opcodes.NEWARRAY, Opcodes.T_LONG))
        // Store number of cases
        add(InsnNode(Opcodes.DUP))
        add(IntInsnNode(Opcodes.SIPUSH, 0))
        add(LdcInsnNode(caseValues.size.toLong()))
        add(InsnNode(Opcodes.LASTORE))
        // Store bit size of keys
        add(InsnNode(Opcodes.DUP))
        add(IntInsnNode(Opcodes.SIPUSH, 1))
        add(LdcInsnNode(32.toLong()))
        add(InsnNode(Opcodes.LASTORE))
        // Store {lookup,table}switch case values
        for ((i, caseValue) in caseValues.withIndex()) {
            add(InsnNode(Opcodes.DUP))
            add(IntInsnNode(Opcodes.SIPUSH, 2 + i))
            add(LdcInsnNode(caseValue))
            add(InsnNode(Opcodes.LASTORE))
        }
        pushFakePc()
        // call the native callback function
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceSwitch", "(J[JI)V", false))
    }

    /**
     * Returns true if [node] represents an instruction that possibly pushes a valid, non-zero, constant array index
     * onto the stack. Only the opcode is inspected, not the pushed value; a null [node] yields false.
     */
    private fun isConstantIntegerPushInsn(node: AbstractInsnNode?) = node?.opcode in CONSTANT_INTEGER_PUSH_OPCODES

    /**
     * Returns true if [node] represents a call to a method that performs an indexed lookup into an array-like
     * structure, i.e. a method listed in [GEP_LOAD_METHODS] that takes a single int argument.
     */
    private fun isGepLoadMethodInsn(node: MethodInsnNode): Boolean {
        if (!node.desc.startsWith("(I)")) return false
        val returnType = node.desc.removePrefix("(I)")
        return MethodInfo(node.owner, node.name, returnType) in GEP_LOAD_METHODS
    }

    /** Instructions that report the int index on top of the stack, widened to long. */
    private fun gepLoadInstrumentation() = InsnList().apply {
        // Duplicate the index and convert to long.
        add(InsnNode(Opcodes.DUP))
        add(InsnNode(Opcodes.I2L))
        pushFakePc()
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceGep", "(JI)V", false))
    }

    companion object {
        // Name of the callback that is emitted in place of LCMP and checked for when deciding whether to
        // skip the IF* instruction that follows it.
        private const val TRACE_CMP_LONG_WRAPPER = "traceCmpLongWrapper"

        // All four float/double compare opcodes. Their int result is in {-1, 0, 1}, so an IF* directly
        // following one of them is not instrumented.
        private val FLOATING_POINT_CMP_OPCODES = setOf(
            Opcodes.DCMPG,
            Opcodes.DCMPL,
            Opcodes.FCMPG,
            Opcodes.FCMPL,
        )

        // Low constants (0, 1) are omitted as they create a lot of noise.
        val CONSTANT_INTEGER_PUSH_OPCODES = listOf(
            Opcodes.BIPUSH,
            Opcodes.SIPUSH,
            Opcodes.LDC,
            Opcodes.ICONST_2,
            Opcodes.ICONST_3,
            Opcodes.ICONST_4,
            Opcodes.ICONST_5,
        )

        /**
         * Identifies a method taking a single int argument by its owner, name and return type.
         *
         * @property internalClassName internal (slash-separated) name of the class owning the method.
         * @property name the method name.
         * @property returnType the return type in descriptor form, e.g. `C` or `Ljava/lang/Object;`.
         */
        data class MethodInfo(val internalClassName: String, val name: String, val returnType: String)

        // Methods treated like array loads: each takes a single int argument (see isGepLoadMethodInsn).
        val GEP_LOAD_METHODS = setOf(
            MethodInfo("java/util/AbstractList", "get", "Ljava/lang/Object;"),
            MethodInfo("java/util/ArrayList", "get", "Ljava/lang/Object;"),
            MethodInfo("java/util/List", "get", "Ljava/lang/Object;"),
            MethodInfo("java/util/Stack", "get", "Ljava/lang/Object;"),
            MethodInfo("java/util/Vector", "get", "Ljava/lang/Object;"),
            MethodInfo("java/lang/CharSequence", "charAt", "C"),
            MethodInfo("java/lang/String", "charAt", "C"),
            MethodInfo("java/lang/StringBuffer", "charAt", "C"),
            MethodInfo("java/lang/StringBuilder", "charAt", "C"),
            MethodInfo("java/lang/String", "codePointAt", "I"),
            MethodInfo("java/lang/String", "codePointBefore", "I"),
            MethodInfo("java/nio/ByteBuffer", "get", "B"),
            MethodInfo("java/nio/ByteBuffer", "getChar", "C"),
            MethodInfo("java/nio/ByteBuffer", "getDouble", "D"),
            MethodInfo("java/nio/ByteBuffer", "getFloat", "F"),
            MethodInfo("java/nio/ByteBuffer", "getInt", "I"),
            MethodInfo("java/nio/ByteBuffer", "getLong", "J"),
            MethodInfo("java/nio/ByteBuffer", "getShort", "S"),
        )
    }
}
269