// Copyright 2021 Code Intelligence GmbH
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package com.code_intelligence.jazzer.instrumentor
16
17 import org.objectweb.asm.ClassReader
18 import org.objectweb.asm.ClassWriter
19 import org.objectweb.asm.Opcodes
20 import org.objectweb.asm.tree.AbstractInsnNode
21 import org.objectweb.asm.tree.ClassNode
22 import org.objectweb.asm.tree.InsnList
23 import org.objectweb.asm.tree.InsnNode
24 import org.objectweb.asm.tree.IntInsnNode
25 import org.objectweb.asm.tree.LdcInsnNode
26 import org.objectweb.asm.tree.LookupSwitchInsnNode
27 import org.objectweb.asm.tree.MethodInsnNode
28 import org.objectweb.asm.tree.MethodNode
29 import org.objectweb.asm.tree.TableSwitchInsnNode
30
/**
 * Rewrites class files so that the operands of integer/long compares, switches, integer/long
 * divisions and constant-index lookups are passed to static callback methods.
 *
 * @param types the kinds of instrumentation to emit; an opcode whose kind is not contained in
 * this set is left untouched.
 * @param callbackInternalClassName internal (slash-separated) name of the class that declares the
 * static `trace*` callback methods invoked by the inserted bytecode.
 */
internal class TraceDataFlowInstrumentor(
    private val types: Set<InstrumentationType>,
    private val callbackInternalClassName: String = "com/code_intelligence/jazzer/runtime/TraceDataFlowNativeCallbacks",
) : Instrumentor {

    // Source of the fake PCs passed to the callbacks. Re-created for every class in [instrument]
    // from the tag "trace" and the class name.
    private lateinit var random: DeterministicRandom

    /**
     * Parses [bytecode], instruments every method accepted by `shouldInstrument` and returns the
     * re-serialized class with recomputed max stack/locals.
     *
     * Note that the class name used to seed [random] is read from the parsed class itself, not
     * from [internalClassName].
     */
    override fun instrument(internalClassName: String, bytecode: ByteArray): ByteArray {
        val node = ClassNode()
        ClassReader(bytecode).accept(node, 0)
        random = DeterministicRandom("trace", node.name)
        node.methods
            .filter { shouldInstrument(it) }
            .forEach { addDataFlowInstrumentation(it) }

        // The inserted callbacks need additional operand stack slots, so let ASM recompute maxs.
        val writer = ClassWriter(ClassWriter.COMPUTE_MAXS)
        node.accept(writer)
        return writer.toByteArray()
    }

    /**
     * Walks over all instructions of [method] and inserts the callback invocation matching each
     * supported opcode, subject to the enabled [types].
     */
    private fun addDataFlowInstrumentation(method: MethodNode) {
        val instructions = method.instructions
        // Iterate over a snapshot as the instruction list is modified along the way.
        loop@ for (inst in instructions.toArray()) {
            when (inst.opcode) {
                Opcodes.LCMP -> {
                    if (InstrumentationType.CMP !in types) continue@loop
                    // The callback consumes both longs and returns the comparison result, so it
                    // replaces the LCMP instruction rather than preceding it.
                    instructions.insertBefore(inst, longCmpInstrumentation())
                    instructions.remove(inst)
                }
                Opcodes.IF_ICMPEQ, Opcodes.IF_ICMPNE,
                Opcodes.IF_ICMPLT, Opcodes.IF_ICMPLE,
                Opcodes.IF_ICMPGT, Opcodes.IF_ICMPGE,
                -> {
                    if (InstrumentationType.CMP !in types) continue@loop
                    instructions.insertBefore(inst, intCmpInstrumentation())
                }
                Opcodes.IFEQ, Opcodes.IFNE,
                Opcodes.IFLT, Opcodes.IFLE,
                Opcodes.IFGT, Opcodes.IFGE,
                -> {
                    if (InstrumentationType.CMP !in types) continue@loop
                    // The IF* opcodes are often used to branch based on the result of a compare
                    // instruction for a type other than int. The operands of this compare will
                    // already be reported via the instrumentation above (for non-floating point
                    // numbers) and the follow-up compare does not provide a good signal as all
                    // operands will be in {-1, 0, 1}. Skip instrumentation for it.
                    val previous = inst.previous
                    val followsFloatingPointCompare = previous?.opcode in FLOATING_POINT_COMPARE_OPCODES
                    // An LCMP visited earlier in this loop has already been replaced by this call.
                    val followsLongCompare = (previous as? MethodInsnNode)?.name == "traceCmpLongWrapper"
                    if (followsFloatingPointCompare || followsLongCompare) continue@loop
                    instructions.insertBefore(inst, ifInstrumentation())
                }
                Opcodes.LOOKUPSWITCH, Opcodes.TABLESWITCH -> {
                    if (InstrumentationType.CMP !in types) continue@loop
                    val caseValues = switchCaseValues(inst) ?: continue@loop
                    instructions.insertBefore(inst, switchInstrumentation(caseValues))
                }
                Opcodes.IDIV -> {
                    if (InstrumentationType.DIV !in types) continue@loop
                    instructions.insertBefore(inst, intDivInstrumentation())
                }
                Opcodes.LDIV -> {
                    if (InstrumentationType.DIV !in types) continue@loop
                    instructions.insertBefore(inst, longDivInstrumentation())
                }
                Opcodes.AALOAD, Opcodes.BALOAD,
                Opcodes.CALOAD, Opcodes.DALOAD,
                Opcodes.FALOAD, Opcodes.IALOAD,
                Opcodes.LALOAD, Opcodes.SALOAD,
                -> {
                    if (InstrumentationType.GEP !in types) continue@loop
                    if (!isConstantIntegerPushInsn(inst.previous)) continue@loop
                    instructions.insertBefore(inst, gepLoadInstrumentation())
                }
                Opcodes.INVOKEINTERFACE, Opcodes.INVOKESPECIAL, Opcodes.INVOKESTATIC, Opcodes.INVOKEVIRTUAL -> {
                    if (InstrumentationType.GEP !in types) continue@loop
                    if (!isGepLoadMethodInsn(inst as MethodInsnNode)) continue@loop
                    if (!isConstantIntegerPushInsn(inst.previous)) continue@loop
                    instructions.insertBefore(inst, gepLoadInstrumentation())
                }
            }
        }
    }

    /**
     * Returns the case values of the switch instruction [inst] that should be reported, sorted by
     * their unsigned value, or `null` if the switch should not be instrumented.
     *
     * Mimics the exclusion logic for small label values in libFuzzer:
     * https://github.com/llvm-mirror/compiler-rt/blob/69445f095c22aac2388f939bedebf224a6efcdaf/lib/fuzzer/FuzzerTracePC.cpp#L520
     * Case values are reported to libFuzzer via an array of unsigned long values and thus need to
     * be sorted by unsigned value.
     */
    private fun switchCaseValues(inst: AbstractInsnNode): LongArray? {
        val cases: List<Int> = when (inst) {
            is LookupSwitchInsnNode -> {
                if (inst.keys.isEmpty() || (0 <= inst.keys.first() && inst.keys.last() < 256)) return null
                inst.keys
            }
            is TableSwitchInsnNode -> {
                if (0 <= inst.min && inst.max < 256) return null
                (inst.min..inst.max).filter { caseValue ->
                    val index = caseValue - inst.min
                    // Filter out "gap cases", i.e. table entries that jump to the default label.
                    inst.labels[index].label != inst.dflt.label
                }
            }
            // Not reached for LOOKUPSWITCH/TABLESWITCH opcodes.
            else -> return null
        }
        return cases.sortedBy { it.toUInt() }.map { it.toLong() }.toLongArray()
    }

    /** Appends an instruction pushing a fake PC, an int drawn via `random.nextInt(512)`. */
    private fun InsnList.pushFakePc() {
        add(LdcInsnNode(random.nextInt(512)))
    }

    /** Builds the replacement for an LCMP instruction: `(long, long) -> int` via the callback. */
    private fun longCmpInstrumentation() = InsnList().apply {
        pushFakePc()
        // traceCmpLongWrapper returns the result of the comparison as duplicating two longs on the
        // stack is not possible without local variables.
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceCmpLongWrapper", "(JJI)I", false))
    }

    /** Builds the instructions reporting both int operands of an IF_ICMP* instruction. */
    private fun intCmpInstrumentation() = InsnList().apply {
        add(InsnNode(Opcodes.DUP2))
        pushFakePc()
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceCmpInt", "(III)V", false))
    }

    /** Builds the instructions reporting the int operand of an IF* instruction compared to 0. */
    private fun ifInstrumentation() = InsnList().apply {
        add(InsnNode(Opcodes.DUP))
        // All if* instructions are compares to the constant 0.
        add(InsnNode(Opcodes.ICONST_0))
        // The constant is passed as the first argument, the actual operand as the second.
        add(InsnNode(Opcodes.SWAP))
        pushFakePc()
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceConstCmpInt", "(III)V", false))
    }

    /** Builds the instructions reporting the divisor (top of stack) of an IDIV instruction. */
    private fun intDivInstrumentation() = InsnList().apply {
        add(InsnNode(Opcodes.DUP))
        pushFakePc()
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceDivInt", "(II)V", false))
    }

    /** Builds the instructions reporting the divisor (top of stack) of an LDIV instruction. */
    private fun longDivInstrumentation() = InsnList().apply {
        // A long occupies two stack slots, hence DUP2.
        add(InsnNode(Opcodes.DUP2))
        pushFakePc()
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceDivLong", "(JI)V", false))
    }

    /**
     * Builds the instructions reporting the key of a switch instruction together with its
     * [caseValues].
     */
    private fun switchInstrumentation(caseValues: LongArray) = InsnList().apply {
        // duplicate {lookup,table}switch key for use as first function argument
        add(InsnNode(Opcodes.DUP))
        add(InsnNode(Opcodes.I2L))
        // Set up array with switch case values. The format libfuzzer expects is created here directly, i.e., the first
        // two entries are the number of cases and the bit size of values (always 32).
        add(IntInsnNode(Opcodes.SIPUSH, caseValues.size + 2))
        add(IntInsnNode(Opcodes.NEWARRAY, Opcodes.T_LONG))
        // Store number of cases
        add(InsnNode(Opcodes.DUP))
        add(IntInsnNode(Opcodes.SIPUSH, 0))
        add(LdcInsnNode(caseValues.size.toLong()))
        add(InsnNode(Opcodes.LASTORE))
        // Store bit size of keys
        add(InsnNode(Opcodes.DUP))
        add(IntInsnNode(Opcodes.SIPUSH, 1))
        add(LdcInsnNode(32.toLong()))
        add(InsnNode(Opcodes.LASTORE))
        // Store {lookup,table}switch case values
        for ((i, caseValue) in caseValues.withIndex()) {
            add(InsnNode(Opcodes.DUP))
            add(IntInsnNode(Opcodes.SIPUSH, 2 + i))
            add(LdcInsnNode(caseValue))
            add(InsnNode(Opcodes.LASTORE))
        }
        pushFakePc()
        // call the native callback function
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceSwitch", "(J[JI)V", false))
    }

    /**
     * Returns true if the opcode of [node] is one of [CONSTANT_INTEGER_PUSH_OPCODES], i.e. if it
     * represents an instruction that possibly pushes a valid, non-zero, constant array index onto
     * the stack. Returns false for `null`.
     */
    private fun isConstantIntegerPushInsn(node: AbstractInsnNode?) = node?.opcode in CONSTANT_INTEGER_PUSH_OPCODES

    /**
     * Returns true if [node] represents a call to a method that performs an indexed lookup into an array-like
     * structure, i.e. a method taking a single int argument that is listed in [GEP_LOAD_METHODS].
     */
    private fun isGepLoadMethodInsn(node: MethodInsnNode): Boolean {
        if (!node.desc.startsWith("(I)")) return false
        val returnType = node.desc.removePrefix("(I)")
        return MethodInfo(node.owner, node.name, returnType) in GEP_LOAD_METHODS
    }

    /** Builds the instructions reporting the int index on top of the stack as a long. */
    private fun gepLoadInstrumentation() = InsnList().apply {
        // Duplicate the index and convert to long.
        add(InsnNode(Opcodes.DUP))
        add(InsnNode(Opcodes.I2L))
        pushFakePc()
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceGep", "(JI)V", false))
    }

    companion object {
        // Opcodes that compare two floats or doubles and push the result (-1, 0 or 1) as an int.
        private val FLOATING_POINT_COMPARE_OPCODES = setOf(
            Opcodes.DCMPG,
            Opcodes.DCMPL,
            Opcodes.FCMPG,
            Opcodes.FCMPL,
        )

        // Low constants (0, 1) are omitted as they create a lot of noise.
        val CONSTANT_INTEGER_PUSH_OPCODES = listOf(
            Opcodes.BIPUSH,
            Opcodes.SIPUSH,
            Opcodes.LDC,
            Opcodes.ICONST_2,
            Opcodes.ICONST_3,
            Opcodes.ICONST_4,
            Opcodes.ICONST_5,
        )

        /**
         * Identifies a method taking a single int argument by the internal name of its owner, its
         * name and the descriptor of its return type.
         */
        data class MethodInfo(val internalClassName: String, val name: String, val returnType: String)

        // Methods treated like array loads when invoked with a constant index.
        val GEP_LOAD_METHODS = setOf(
            MethodInfo("java/util/AbstractList", "get", "Ljava/lang/Object;"),
            MethodInfo("java/util/ArrayList", "get", "Ljava/lang/Object;"),
            MethodInfo("java/util/List", "get", "Ljava/lang/Object;"),
            MethodInfo("java/util/Stack", "get", "Ljava/lang/Object;"),
            MethodInfo("java/util/Vector", "get", "Ljava/lang/Object;"),
            MethodInfo("java/lang/CharSequence", "charAt", "C"),
            MethodInfo("java/lang/String", "charAt", "C"),
            MethodInfo("java/lang/StringBuffer", "charAt", "C"),
            MethodInfo("java/lang/StringBuilder", "charAt", "C"),
            MethodInfo("java/lang/String", "codePointAt", "I"),
            MethodInfo("java/lang/String", "codePointBefore", "I"),
            MethodInfo("java/nio/ByteBuffer", "get", "B"),
            MethodInfo("java/nio/ByteBuffer", "getChar", "C"),
            MethodInfo("java/nio/ByteBuffer", "getDouble", "D"),
            MethodInfo("java/nio/ByteBuffer", "getFloat", "F"),
            MethodInfo("java/nio/ByteBuffer", "getInt", "I"),
            MethodInfo("java/nio/ByteBuffer", "getLong", "J"),
            MethodInfo("java/nio/ByteBuffer", "getShort", "S"),
        )
    }
}
269