// xref: /aosp_15_r20/external/jazzer-api/src/main/java/com/code_intelligence/jazzer/runtime/TraceCmpHooks.java (revision 33edd6723662ea34453766bfdca85dbfdd5342b8)
// Copyright 2021 Code Intelligence GmbH
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 
15 package com.code_intelligence.jazzer.runtime;
16 
17 import com.code_intelligence.jazzer.api.HookType;
18 import com.code_intelligence.jazzer.api.MethodHook;
19 import java.lang.invoke.MethodHandle;
20 import java.util.*;
21 
22 @SuppressWarnings("unused")
23 final public class TraceCmpHooks {
24   @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Byte", targetMethod = "compare",
25       targetMethodDescriptor = "(BB)I")
26   @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Byte",
27       targetMethod = "compareUnsigned", targetMethodDescriptor = "(BB)I")
28   @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Short", targetMethod = "compare",
29       targetMethodDescriptor = "(SS)I")
30   @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Short",
31       targetMethod = "compareUnsigned", targetMethodDescriptor = "(SS)I")
32   @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Integer",
33       targetMethod = "compare", targetMethodDescriptor = "(II)I")
34   @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Integer",
35       targetMethod = "compareUnsigned", targetMethodDescriptor = "(II)I")
36   @MethodHook(type = HookType.BEFORE, targetClassName = "kotlin.jvm.internal.Intrinsics ",
37       targetMethod = "compare", targetMethodDescriptor = "(II)I")
38   public static void
integerCompare(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId)39   integerCompare(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId) {
40     TraceDataFlowNativeCallbacks.traceCmpInt((int) arguments[0], (int) arguments[1], hookId);
41   }
42 
43   @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Byte",
44       targetMethod = "compareTo", targetMethodDescriptor = "(Ljava/lang/Byte;)I")
45   @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Short",
46       targetMethod = "compareTo", targetMethodDescriptor = "(Ljava/lang/Short;)I")
47   @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Integer",
48       targetMethod = "compareTo", targetMethodDescriptor = "(Ljava/lang/Integer;)I")
49   public static void
integerCompareTo(MethodHandle method, Object thisObject, Object[] arguments, int hookId)50   integerCompareTo(MethodHandle method, Object thisObject, Object[] arguments, int hookId) {
51     TraceDataFlowNativeCallbacks.traceCmpInt((int) thisObject, (int) arguments[0], hookId);
52   }
53 
54   @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Long", targetMethod = "compare",
55       targetMethodDescriptor = "(JJ)I")
56   @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Long",
57       targetMethod = "compareUnsigned", targetMethodDescriptor = "(JJ)I")
58   public static void
longCompare(MethodHandle method, Object thisObject, Object[] arguments, int hookId)59   longCompare(MethodHandle method, Object thisObject, Object[] arguments, int hookId) {
60     TraceDataFlowNativeCallbacks.traceCmpLong((long) arguments[0], (long) arguments[1], hookId);
61   }
62 
63   @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Long",
64       targetMethod = "compareTo", targetMethodDescriptor = "(Ljava/lang/Long;)I")
65   public static void
longCompareTo(MethodHandle method, Long thisObject, Object[] arguments, int hookId)66   longCompareTo(MethodHandle method, Long thisObject, Object[] arguments, int hookId) {
67     TraceDataFlowNativeCallbacks.traceCmpLong(thisObject, (long) arguments[0], hookId);
68   }
69 
70   @MethodHook(type = HookType.BEFORE, targetClassName = "kotlin.jvm.internal.Intrinsics ",
71       targetMethod = "compare", targetMethodDescriptor = "(JJ)I")
72   public static void
longCompareKt(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId)73   longCompareKt(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId) {
74     TraceDataFlowNativeCallbacks.traceCmpLong((long) arguments[0], (long) arguments[1], hookId);
75   }
76 
77   @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "equals")
78   @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String",
79       targetMethod = "equalsIgnoreCase")
80   public static void
equals(MethodHandle method, String thisObject, Object[] arguments, int hookId, Boolean areEqual)81   equals(MethodHandle method, String thisObject, Object[] arguments, int hookId, Boolean areEqual) {
82     if (!areEqual && arguments.length == 1 && arguments[0] instanceof String) {
83       // The precise value of the result of the comparison is not used by libFuzzer as long as it is
84       // non-zero.
85       TraceDataFlowNativeCallbacks.traceStrcmp(thisObject, (String) arguments[0], 1, hookId);
86     }
87   }
88 
89   @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.Object", targetMethod = "equals")
90   @MethodHook(
91       type = HookType.AFTER, targetClassName = "java.lang.CharSequence", targetMethod = "equals")
92   @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.Number", targetMethod = "equals")
93   public static void
genericEquals( MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean areEqual)94   genericEquals(
95       MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean areEqual) {
96     if (!areEqual && arguments.length == 1 && arguments[0] != null
97         && thisObject.getClass() == arguments[0].getClass()) {
98       TraceDataFlowNativeCallbacks.traceGenericCmp(thisObject, arguments[0], hookId);
99     }
100   }
101 
102   @MethodHook(type = HookType.AFTER, targetClassName = "clojure.lang.Util", targetMethod = "equiv")
genericStaticEquals( MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean areEqual)103   public static void genericStaticEquals(
104       MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean areEqual) {
105     if (!areEqual && arguments.length == 2 && arguments[0] != null && arguments[1] != null
106         && arguments[1].getClass() == arguments[0].getClass()) {
107       TraceDataFlowNativeCallbacks.traceGenericCmp(arguments[0], arguments[1], hookId);
108     }
109   }
110 
111   @MethodHook(
112       type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "compareTo")
113   @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String",
114       targetMethod = "compareToIgnoreCase")
115   public static void
compareTo( MethodHandle method, String thisObject, Object[] arguments, int hookId, Integer returnValue)116   compareTo(
117       MethodHandle method, String thisObject, Object[] arguments, int hookId, Integer returnValue) {
118     if (returnValue != 0 && arguments.length == 1 && arguments[0] instanceof String) {
119       TraceDataFlowNativeCallbacks.traceStrcmp(
120           thisObject, (String) arguments[0], returnValue, hookId);
121     }
122   }
123 
124   @MethodHook(
125       type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "contentEquals")
126   public static void
contentEquals(MethodHandle method, String thisObject, Object[] arguments, int hookId, Boolean areEqualContents)127   contentEquals(MethodHandle method, String thisObject, Object[] arguments, int hookId,
128       Boolean areEqualContents) {
129     if (!areEqualContents && arguments.length == 1 && arguments[0] instanceof CharSequence) {
130       TraceDataFlowNativeCallbacks.traceStrcmp(
131           thisObject, ((CharSequence) arguments[0]).toString(), 1, hookId);
132     }
133   }
134 
135   @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String",
136       targetMethod = "regionMatches", targetMethodDescriptor = "(ZILjava/lang/String;II)Z")
137   public static void
regionsMatches5( MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean returnValue)138   regionsMatches5(
139       MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean returnValue) {
140     if (!returnValue) {
141       int toffset = (int) arguments[1];
142       String other = (String) arguments[2];
143       int ooffset = (int) arguments[3];
144       int len = (int) arguments[4];
145       regionMatchesInternal((String) thisObject, toffset, other, ooffset, len, hookId);
146     }
147   }
148 
149   @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String",
150       targetMethod = "regionMatches", targetMethodDescriptor = "(ILjava/lang/String;II)Z")
151   public static void
regionMatches4( MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean returnValue)152   regionMatches4(
153       MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean returnValue) {
154     if (!returnValue) {
155       int toffset = (int) arguments[0];
156       String other = (String) arguments[1];
157       int ooffset = (int) arguments[2];
158       int len = (int) arguments[3];
159       regionMatchesInternal((String) thisObject, toffset, other, ooffset, len, hookId);
160     }
161   }
162 
regionMatchesInternal( String thisString, int toffset, String other, int ooffset, int len, int hookId)163   private static void regionMatchesInternal(
164       String thisString, int toffset, String other, int ooffset, int len, int hookId) {
165     if (toffset < 0 || ooffset < 0)
166       return;
167     int cappedThisStringEnd = Math.min(toffset + len, thisString.length());
168     int cappedOtherStringEnd = Math.min(ooffset + len, other.length());
169     String thisPart = thisString.substring(toffset, cappedThisStringEnd);
170     String otherPart = other.substring(ooffset, cappedOtherStringEnd);
171     TraceDataFlowNativeCallbacks.traceStrcmp(thisPart, otherPart, 1, hookId);
172   }
173 
174   @MethodHook(
175       type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "contains")
176   public static void
contains( MethodHandle method, String thisObject, Object[] arguments, int hookId, Boolean doesContain)177   contains(
178       MethodHandle method, String thisObject, Object[] arguments, int hookId, Boolean doesContain) {
179     if (!doesContain && arguments.length == 1 && arguments[0] instanceof CharSequence) {
180       TraceDataFlowNativeCallbacks.traceStrstr(
181           thisObject, ((CharSequence) arguments[0]).toString(), hookId);
182     }
183   }
184 
185   @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "indexOf")
186   @MethodHook(
187       type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "lastIndexOf")
188   @MethodHook(
189       type = HookType.AFTER, targetClassName = "java.lang.StringBuffer", targetMethod = "indexOf")
190   @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.StringBuffer",
191       targetMethod = "lastIndexOf")
192   @MethodHook(
193       type = HookType.AFTER, targetClassName = "java.lang.StringBuilder", targetMethod = "indexOf")
194   @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.StringBuilder",
195       targetMethod = "lastIndexOf")
196   public static void
indexOf( MethodHandle method, Object thisObject, Object[] arguments, int hookId, Integer returnValue)197   indexOf(
198       MethodHandle method, Object thisObject, Object[] arguments, int hookId, Integer returnValue) {
199     if (returnValue == -1 && arguments.length >= 1 && arguments[0] instanceof String) {
200       TraceDataFlowNativeCallbacks.traceStrstr(
201           thisObject.toString(), (String) arguments[0], hookId);
202     }
203   }
204 
205   @MethodHook(
206       type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "startsWith")
207   @MethodHook(
208       type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "endsWith")
209   public static void
startsWith(MethodHandle method, String thisObject, Object[] arguments, int hookId, Boolean doesStartOrEndsWith)210   startsWith(MethodHandle method, String thisObject, Object[] arguments, int hookId,
211       Boolean doesStartOrEndsWith) {
212     if (!doesStartOrEndsWith && arguments.length >= 1 && arguments[0] instanceof String) {
213       TraceDataFlowNativeCallbacks.traceStrstr(thisObject, (String) arguments[0], hookId);
214     }
215   }
216 
217   @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "replace",
218       targetMethodDescriptor =
219           "(Ljava/lang/CharSequence;Ljava/lang/CharSequence;)Ljava/lang/String;")
220   public static void
replace( MethodHandle method, Object thisObject, Object[] arguments, int hookId, String returnValue)221   replace(
222       MethodHandle method, Object thisObject, Object[] arguments, int hookId, String returnValue) {
223     String original = (String) thisObject;
224     // Report only if the replacement was not successful.
225     if (original.equals(returnValue)) {
226       String target = arguments[0].toString();
227       TraceDataFlowNativeCallbacks.traceStrstr(original, target, hookId);
228     }
229   }
230 
231   // For standard Kotlin packages, which are named according to the pattern kotlin.*, we append a
232   // whitespace to the package name of the target class so that they are not mangled due to shading.
233   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.jvm.internal.Intrinsics ",
234       targetMethod = "areEqual")
235   @MethodHook(
236       type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "equals")
237   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
238       targetMethod = "equals$default")
239   public static void
equalsKt(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Boolean equalStrings)240   equalsKt(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId,
241       Boolean equalStrings) {
242     if (!equalStrings && arguments.length >= 2 && arguments[0] instanceof String
243         && arguments[1] instanceof String) {
244       TraceDataFlowNativeCallbacks.traceStrcmp(
245           (String) arguments[0], (String) arguments[1], 1, hookId);
246     }
247   }
248 
249   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
250       targetMethod = "contentEquals")
251   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
252       targetMethod = "contentEquals$default")
253   public static void
contentEqualKt(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Boolean equalStrings)254   contentEqualKt(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId,
255       Boolean equalStrings) {
256     if (!equalStrings && arguments.length >= 2 && arguments[0] instanceof CharSequence
257         && arguments[1] instanceof CharSequence) {
258       TraceDataFlowNativeCallbacks.traceStrcmp(
259           arguments[0].toString(), arguments[1].toString(), 1, hookId);
260     }
261   }
262 
263   @MethodHook(
264       type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "compareTo")
265   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
266       targetMethod = "compareTo$default")
267   public static void
compareToKt( MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Integer returnValue)268   compareToKt(
269       MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Integer returnValue) {
270     if (returnValue != 0 && arguments.length >= 2 && arguments[0] instanceof String
271         && arguments[1] instanceof String) {
272       TraceDataFlowNativeCallbacks.traceStrcmp(
273           (String) arguments[0], (String) arguments[1], 1, hookId);
274     }
275   }
276 
277   @MethodHook(
278       type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "endsWith")
279   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
280       targetMethod = "endsWith$default")
281   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
282       targetMethod = "startsWith")
283   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
284       targetMethod = "startsWith$default")
285   public static void
startsWithKt(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Boolean doesStartOrEndsWith)286   startsWithKt(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId,
287       Boolean doesStartOrEndsWith) {
288     if (!doesStartOrEndsWith && arguments.length >= 2 && arguments[0] instanceof CharSequence
289         && arguments[1] instanceof CharSequence) {
290       TraceDataFlowNativeCallbacks.traceStrstr(
291           arguments[0].toString(), arguments[1].toString(), hookId);
292     }
293   }
294 
295   @MethodHook(
296       type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "contains")
297   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
298       targetMethod = "contains$default")
299   public static void
containsKt( MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Boolean doesContain)300   containsKt(
301       MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Boolean doesContain) {
302     if (!doesContain && arguments.length >= 2 && arguments[0] instanceof CharSequence
303         && arguments[1] instanceof CharSequence) {
304       TraceDataFlowNativeCallbacks.traceStrstr(
305           arguments[0].toString(), arguments[1].toString(), hookId);
306     }
307   }
308 
309   @MethodHook(
310       type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "indexOf")
311   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
312       targetMethod = "indexOf$default")
313   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
314       targetMethod = "lastIndexOf")
315   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
316       targetMethod = "lastIndexOf$default")
317   public static void
indexOfKt( MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Integer returnValue)318   indexOfKt(
319       MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Integer returnValue) {
320     if (returnValue != -1 || arguments.length < 2 || !(arguments[0] instanceof CharSequence)) {
321       return;
322     }
323     if (arguments[1] instanceof String) {
324       TraceDataFlowNativeCallbacks.traceStrstr(
325           arguments[0].toString(), (String) arguments[1], hookId);
326     } else if (arguments[1] instanceof Character) {
327       TraceDataFlowNativeCallbacks.traceStrstr(
328           arguments[0].toString(), ((Character) arguments[1]).toString(), hookId);
329     }
330   }
331 
332   @MethodHook(
333       type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "replace")
334   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
335       targetMethod = "replace$default")
336   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
337       targetMethod = "replaceAfter")
338   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
339       targetMethod = "replaceAfter$default")
340   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
341       targetMethod = "replaceAfterLast")
342   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
343       targetMethod = "replaceAfterLast$default")
344   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
345       targetMethod = "replaceBefore")
346   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
347       targetMethod = "replaceBefore$default")
348   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
349       targetMethod = "replaceBeforeLast")
350   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
351       targetMethod = "replaceBeforeLast$default")
352   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
353       targetMethod = "replaceFirst")
354   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
355       targetMethod = "replaceFirst$default")
356   public static void
replaceKt( MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, String returnValue)357   replaceKt(
358       MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, String returnValue) {
359     if (arguments.length < 2 || !(arguments[0] instanceof String)) {
360       return;
361     }
362     String original = (String) arguments[0];
363     if (!original.equals(returnValue)) {
364       return;
365     }
366 
367     // We currently don't handle the overloads that take a regex as a second argument.
368     if (arguments[1] instanceof String || arguments[1] instanceof Character) {
369       TraceDataFlowNativeCallbacks.traceStrstr(original, arguments[1].toString(), hookId);
370     }
371   }
372 
373   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
374       targetMethod = "regionMatches",
375       targetMethodDescriptor = "(Ljava/lang/String;ILjava/lang/String;IIZ)Z")
376   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
377       targetMethod = "regionMatches$default",
378       targetMethodDescriptor = "(Ljava/lang/String;ILjava/lang/String;IIZILjava/lang/Object;)Z")
379   public static void
regionMatchesKt(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Boolean doesRegionMatch)380   regionMatchesKt(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId,
381       Boolean doesRegionMatch) {
382     if (!doesRegionMatch) {
383       String thisString = arguments[0].toString();
384       int thisOffset = (int) arguments[1];
385       String other = arguments[2].toString();
386       int otherOffset = (int) arguments[3];
387       int length = (int) arguments[4];
388       regionMatchesInternal(thisString, thisOffset, other, otherOffset, length, hookId);
389     }
390   }
391 
392   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
393       targetMethod = "indexOfAny")
394   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
395       targetMethod = "indexOfAny$default")
396   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
397       targetMethod = "lastIndexOfAny")
398   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
399       targetMethod = "lastIndexOfAny$default")
400   public static void
indexOfAnyKt( MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Integer returnValue)401   indexOfAnyKt(
402       MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Integer returnValue) {
403     if (returnValue == -1 && arguments.length >= 2 && arguments[0] instanceof CharSequence) {
404       guideTowardContainmentOfFirstElement(arguments[0].toString(), arguments[1], hookId);
405     }
406   }
407 
408   @MethodHook(
409       type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "findAnyOf")
410   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
411       targetMethod = "findAnyOf$default")
412   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
413       targetMethod = "findLastAnyOf")
414   @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
415       targetMethod = "findLastAnyOf$default")
416   public static void
findAnyKt( MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Object returnValue)417   findAnyKt(
418       MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Object returnValue) {
419     if (returnValue == null && arguments.length >= 2 && arguments[0] instanceof CharSequence) {
420       guideTowardContainmentOfFirstElement(arguments[0].toString(), arguments[1], hookId);
421     }
422   }
423 
guideTowardContainmentOfFirstElement( String containingString, Object candidateCollectionObj, int hookId)424   private static void guideTowardContainmentOfFirstElement(
425       String containingString, Object candidateCollectionObj, int hookId) {
426     if (candidateCollectionObj instanceof Collection<?>) {
427       Collection<?> strings = (Collection<?>) candidateCollectionObj;
428       if (strings.isEmpty()) {
429         return;
430       }
431       Object firstElementObj = strings.iterator().next();
432       if (firstElementObj instanceof CharSequence) {
433         TraceDataFlowNativeCallbacks.traceStrstr(
434             containingString, firstElementObj.toString(), hookId);
435       }
436     } else if (candidateCollectionObj.getClass().isArray()) {
437       if (candidateCollectionObj.getClass().getComponentType() == char.class) {
438         char[] chars = (char[]) candidateCollectionObj;
439         if (chars.length > 0) {
440           TraceDataFlowNativeCallbacks.traceStrstr(
441               containingString, Character.toString(chars[0]), hookId);
442         }
443       }
444     }
445   }
446 
447   @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Arrays", targetMethod = "equals",
448       targetMethodDescriptor = "([B[B)Z")
449   public static void
arraysEquals( MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean returnValue)450   arraysEquals(
451       MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean returnValue) {
452     if (returnValue)
453       return;
454     byte[] first = (byte[]) arguments[0];
455     byte[] second = (byte[]) arguments[1];
456     TraceDataFlowNativeCallbacks.traceMemcmp(first, second, 1, hookId);
457   }
458 
459   @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Arrays", targetMethod = "equals",
460       targetMethodDescriptor = "([BII[BII)Z")
461   public static void
arraysEqualsRange( MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean returnValue)462   arraysEqualsRange(
463       MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean returnValue) {
464     if (returnValue)
465       return;
466     byte[] first =
467         Arrays.copyOfRange((byte[]) arguments[0], (int) arguments[1], (int) arguments[2]);
468     byte[] second =
469         Arrays.copyOfRange((byte[]) arguments[3], (int) arguments[4], (int) arguments[5]);
470     TraceDataFlowNativeCallbacks.traceMemcmp(first, second, 1, hookId);
471   }
472 
473   @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Arrays", targetMethod = "compare",
474       targetMethodDescriptor = "([B[B)I")
475   @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Arrays",
476       targetMethod = "compareUnsigned", targetMethodDescriptor = "([B[B)I")
477   public static void
arraysCompare( MethodHandle method, Object thisObject, Object[] arguments, int hookId, Integer returnValue)478   arraysCompare(
479       MethodHandle method, Object thisObject, Object[] arguments, int hookId, Integer returnValue) {
480     if (returnValue == 0)
481       return;
482     byte[] first = (byte[]) arguments[0];
483     byte[] second = (byte[]) arguments[1];
484     TraceDataFlowNativeCallbacks.traceMemcmp(first, second, returnValue, hookId);
485   }
486 
487   @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Arrays", targetMethod = "compare",
488       targetMethodDescriptor = "([BII[BII)I")
489   @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Arrays",
490       targetMethod = "compareUnsigned", targetMethodDescriptor = "([BII[BII)I")
491   public static void
arraysCompareRange( MethodHandle method, Object thisObject, Object[] arguments, int hookId, Integer returnValue)492   arraysCompareRange(
493       MethodHandle method, Object thisObject, Object[] arguments, int hookId, Integer returnValue) {
494     if (returnValue == 0)
495       return;
496     byte[] first =
497         Arrays.copyOfRange((byte[]) arguments[0], (int) arguments[1], (int) arguments[2]);
498     byte[] second =
499         Arrays.copyOfRange((byte[]) arguments[3], (int) arguments[4], (int) arguments[5]);
500     TraceDataFlowNativeCallbacks.traceMemcmp(first, second, returnValue, hookId);
501   }
502 
503   // The maximal number of elements of a non-TreeMap Map that will be sorted and searched for the
504   // key closest to the current lookup key in the mapGet hook.
505   private static final int MAX_NUM_KEYS_TO_ENUMERATE = 100;
506 
507   @SuppressWarnings({"rawtypes", "unchecked"})
508   @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Map", targetMethod = "get")
mapGet( MethodHandle method, Object thisObject, Object[] arguments, int hookId, Object returnValue)509   public static void mapGet(
510       MethodHandle method, Object thisObject, Object[] arguments, int hookId, Object returnValue) {
511     if (returnValue != null)
512       return;
513     if (arguments.length != 1) {
514       return;
515     }
516     if (thisObject == null)
517       return;
518     final Map map = (Map) thisObject;
519     if (map.size() == 0)
520       return;
521     final Object currentKey = arguments[0];
522     if (currentKey == null)
523       return;
524     // Find two valid map keys that bracket currentKey.
525     // This is a generalization of libFuzzer's __sanitizer_cov_trace_switch:
526     // https://github.com/llvm/llvm-project/blob/318942de229beb3b2587df09e776a50327b5cef0/compiler-rt/lib/fuzzer/FuzzerTracePC.cpp#L564
527     Object lowerBoundKey = null;
528     Object upperBoundKey = null;
529     try {
530       if (map instanceof TreeMap) {
531         final TreeMap treeMap = (TreeMap) map;
532         try {
533           lowerBoundKey = treeMap.floorKey(currentKey);
534           upperBoundKey = treeMap.ceilingKey(currentKey);
535         } catch (ClassCastException ignored) {
536           // Can be thrown by floorKey and ceilingKey if currentKey is of a type that can't be
537           // compared to the maps keys.
538         }
539       } else if (currentKey instanceof Comparable) {
540         final Comparable comparableCurrentKey = (Comparable) currentKey;
541         // Find two keys that bracket currentKey.
542         // Note: This is not deterministic if map.size() > MAX_NUM_KEYS_TO_ENUMERATE.
543         int enumeratedKeys = 0;
544         for (Object validKey : map.keySet()) {
545           if (!(validKey instanceof Comparable))
546             continue;
547           final Comparable comparableValidKey = (Comparable) validKey;
548           // If the key sorts lower than the non-existing key, but higher than the current lower
549           // bound, update the lower bound and vice versa for the upper bound.
550           try {
551             if (comparableValidKey.compareTo(comparableCurrentKey) < 0
552                 && (lowerBoundKey == null || comparableValidKey.compareTo(lowerBoundKey) > 0)) {
553               lowerBoundKey = validKey;
554             }
555             if (comparableValidKey.compareTo(comparableCurrentKey) > 0
556                 && (upperBoundKey == null || comparableValidKey.compareTo(upperBoundKey) < 0)) {
557               upperBoundKey = validKey;
558             }
559           } catch (ClassCastException ignored) {
560             // Can be thrown by floorKey and ceilingKey if currentKey is of a type that can't be
561             // compared to the maps keys.
562           }
563           if (enumeratedKeys++ > MAX_NUM_KEYS_TO_ENUMERATE)
564             break;
565         }
566       }
567     } catch (ConcurrentModificationException ignored) {
568       // map was modified by another thread, skip this invocation
569       return;
570     }
571     // Modify the hook ID so that compares against distinct valid keys are traced separately.
572     if (lowerBoundKey != null) {
573       TraceDataFlowNativeCallbacks.traceGenericCmp(
574           currentKey, lowerBoundKey, hookId + lowerBoundKey.hashCode());
575     }
576     if (upperBoundKey != null) {
577       TraceDataFlowNativeCallbacks.traceGenericCmp(
578           currentKey, upperBoundKey, hookId + upperBoundKey.hashCode());
579     }
580   }
581 
582   @MethodHook(type = HookType.AFTER, targetClassName = "org.junit.jupiter.api.Assertions",
583       targetMethod = "assertNotEquals",
584       targetMethodDescriptor = "(Ljava/lang/Object;Ljava/lang/Object;)V")
585   @MethodHook(type = HookType.AFTER, targetClassName = "org.junit.jupiter.api.Assertions",
586       targetMethod = "assertNotEquals",
587       targetMethodDescriptor = "(Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/String;)V")
588   @MethodHook(type = HookType.AFTER, targetClassName = "org.junit.jupiter.api.Assertions",
589       targetMethod = "assertNotEquals",
590       targetMethodDescriptor =
591           "(Ljava/lang/Object;Ljava/lang/Object;Ljava/util/function/Supplier;)V")
592   public static void
assertEquals(MethodHandle method, Object node, Object[] args, int hookId, Object alwaysNull)593   assertEquals(MethodHandle method, Object node, Object[] args, int hookId, Object alwaysNull) {
594     if (args[0] != null && args[1] != null && args[0].getClass() == args[1].getClass()) {
595       TraceDataFlowNativeCallbacks.traceGenericCmp(args[0], args[1], hookId);
596     }
597   }
598 }
599