xref: /aosp_15_r20/external/kotlinx.coroutines/test-utils/jvm/src/FieldWalker.kt (revision 7a7160fed73afa6648ef8aa100d4a336fe921d9a)

<lambda>null1 package kotlinx.coroutines.testing
2 
3 import java.lang.ref.*
4 import java.lang.reflect.*
5 import java.text.*
6 import java.util.*
7 import java.util.Collections.*
8 import java.util.concurrent.*
9 import java.util.concurrent.atomic.*
10 import java.util.concurrent.locks.*
11 import kotlin.test.*
12 
/**
 * Reflectively walks an object graph from a root object and reports the objects reachable from it.
 *
 * Instances of JDK classes are not walked reflectively, because that is an illegal reflective access on JDK 9+:
 * - JDK collections and maps, atomics and [Throwable.cause] are traversed through their public APIs instead.
 * - A fixed set of classes is treated as terminal.
 * - Any other JDK class with walkable fields makes the walk fail.
 *
 * The per-class field cache is a plain [HashMap] with no synchronization.
 */
object FieldWalker {
    /** Records how an object was first reached during a walk, so [showPath] can rebuild its path from the root. */
    sealed class Ref {
        /** The root object of the walk. */
        object RootRef : Ref()

        /** Reached through field [name] of [parent]; also used for pseudo-fields such as `keys`, `values`, `cause`. */
        class FieldRef(val parent: Any, val name: String) : Ref()

        /** Reached through the element at [index] of an array, a JDK collection or an [AtomicReferenceArray]. */
        class ArrayRef(val parent: Any, val index: Int) : Ref()
    }

    /** Terminal classes: their instances are recorded as reachable, but their fields are never walked. */
    private val excludedClasses: Set<Class<*>> = listOf(
        Any::class, String::class, Thread::class, Throwable::class, StackTraceElement::class,
        WeakReference::class, ReferenceQueue::class, AbstractMap::class, Enum::class,
        ReentrantLock::class, ReentrantReadWriteLock::class, SimpleDateFormat::class, ThreadPoolExecutor::class,
        CountDownLatch::class,
    ).map { it.java }.toSet()

    /**
     * Walkable *instance* fields per class, including inherited ones.
     *
     * Static fields are never stored here, so a statics-enabled root walk cannot leak them into later walks.
     * Every excluded class is pre-populated with an empty list. This also stops superclass lookups, since
     * every class chain ends at `Any`.
     */
    private val fieldsCache = HashMap<Class<*>, List<Field>>()

    init {
        fieldsCache += excludedClasses.associateWith { emptyList() }
    }

    /**
     * Reflectively walks the object graph starting at [root].
     *
     * @return an identity-based set of all reachable objects, including [root] itself;
     *   empty when [root] is `null`.
     *
     * Use [walkRefs] if you need the path from the root for debugging.
     */
    public fun walk(root: Any?): Set<Any> = walkRefs(root, false).keys

    /**
     * Asserts that exactly [expected] objects reachable from [root] satisfy [predicate].
     *
     * On mismatch, the assertion message lists the path from [root] to each matching object.
     *
     * @param rootStatics when `true`, static fields declared by the root's own class are walked as well.
     */
    public fun assertReachableCount(expected: Int, root: Any?, rootStatics: Boolean = false, predicate: (Any) -> Boolean) {
        val visited = walkRefs(root, rootStatics)
        val actual = visited.keys.filter(predicate)
        if (actual.size != expected) {
            val textDump = actual.joinToString("") { "\n\t" + showPath(it, visited) }
            assertEquals(
                expected, actual.size,
                "Unexpected number of objects. Expected $expected, found ${actual.size}$textDump"
            )
        }
    }

    /**
     * Reflectively walks the object graph depth-first from [root].
     *
     * Each reached object is mapped to the [Ref] through which it was first reached.
     * Use [showPath] to render that path.
     *
     * @param rootStatics when `true`, static fields declared by the root's class are walked
     *   (only for the root element).
     * @throws IllegalStateException if visiting any element fails. The message contains the element's path and
     *   the original exception is attached as the cause.
     */
    private fun walkRefs(root: Any?, rootStatics: Boolean): IdentityHashMap<Any, Ref> {
        val visited = IdentityHashMap<Any, Ref>()
        if (root == null) return visited
        visited[root] = Ref.RootRef
        val stack = ArrayDeque<Any>()
        stack.addLast(root)
        var statics = rootStatics
        while (stack.isNotEmpty()) {
            val element = stack.removeLast()
            try {
                visit(element, visited, stack, statics)
                statics = false // only scan root statics when asked
            } catch (e: Exception) {
                // Keep the original exception as the cause so its stack trace is not lost.
                throw IllegalStateException("Failed to visit element ${showPath(element, visited)}: $e", e)
            }
        }
        return visited
    }

    /**
     * Renders the path from the walk root to [element] using the refs recorded in [visited].
     *
     * Field hops are rendered as `|ParentClass::field` and indexed hops as `[index]`.
     * Returns an empty string for the root itself.
     */
    private fun showPath(element: Any, visited: Map<Any, Ref>): String {
        val segments = ArrayList<String>()
        var cur = element
        while (true) {
            when (val ref = visited.getValue(cur)) {
                Ref.RootRef -> break
                is Ref.FieldRef -> {
                    segments += "|${ref.parent.javaClass.simpleName}::${ref.name}"
                    cur = ref.parent
                }
                is Ref.ArrayRef -> {
                    segments += "[${ref.index}]"
                    cur = ref.parent
                }
            }
        }
        // Segments were collected from the element towards the root; print them root-first.
        return segments.asReversed().joinToString("")
    }

    /**
     * Pushes every object directly referenced by [element] onto [stack], recording how it was reached.
     *
     * @param statics whether static fields declared by [element]'s own class should be included.
     */
    private fun visit(element: Any, visited: IdentityHashMap<Any, Ref>, stack: ArrayDeque<Any>, statics: Boolean) {
        val type = element.javaClass
        when {
            // Special code for arrays
            type.isArray && !type.componentType.isPrimitive -> {
                @Suppress("UNCHECKED_CAST")
                val array = element as Array<Any?>
                array.forEachIndexed { index, value ->
                    push(value, visited, stack) { Ref.ArrayRef(element, index) }
                }
            }
            // Special code for platform types that cannot be reflectively accessed on modern JDKs
            type.name.startsWith("java.") && element is Collection<*> -> {
                element.forEachIndexed { index, value ->
                    push(value, visited, stack) { Ref.ArrayRef(element, index) }
                }
            }
            type.name.startsWith("java.") && element is Map<*, *> -> {
                push(element.keys, visited, stack) { Ref.FieldRef(element, "keys") }
                push(element.values, visited, stack) { Ref.FieldRef(element, "values") }
            }
            element is AtomicReference<*> -> {
                push(element.get(), visited, stack) { Ref.FieldRef(element, "value") }
            }
            element is AtomicReferenceArray<*> -> {
                for (index in 0 until element.length()) {
                    push(element[index], visited, stack) { Ref.ArrayRef(element, index) }
                }
            }
            element is AtomicLongFieldUpdater<*> -> {
                /* filter it out here to suppress its subclasses too */
            }
            // All the other classes are reflectively scanned
            else -> {
                fields(type, statics).forEach { field ->
                    push(field.get(element), visited, stack) { Ref.FieldRef(element, field.name) }
                }
                // Special case to scan Throwable cause (cannot get it reflectively).
                // This must run once, outside the field loop, so exceptions whose classes have no
                // walkable fields (e.g. a plain RuntimeException) still have their cause scanned.
                if (element is Throwable) {
                    push(element.cause, visited, stack) { Ref.FieldRef(element, "cause") }
                }
            }
        }
    }

    /**
     * Records [value] as reached via the ref produced by [ref] and schedules it for visiting.
     *
     * Does nothing if [value] is `null` or was already visited.
     * [ref] is only evaluated for newly seen values.
     */
    private inline fun push(value: Any?, visited: IdentityHashMap<Any, Ref>, stack: ArrayDeque<Any>, ref: () -> Ref) {
        if (value != null && !visited.containsKey(value)) {
            visited[value] = ref()
            stack.addLast(value)
        }
    }

    /**
     * Returns the fields to walk for an instance of [type].
     *
     * The result is the cached instance fields of [type] and its superclasses.
     * When [rootStatics] is `true`, the static fields declared directly by [type] (not by its
     * superclasses) are added first.
     * Excluded classes never contribute any fields.
     */
    private fun fields(type: Class<*>, rootStatics: Boolean): List<Field> {
        val instanceFields = instanceFields(type)
        if (!rootStatics || type in excludedClasses) return instanceFields
        // Static fields are computed per call and never cached, so they cannot affect non-root walks.
        return walkableDeclaredFields(type, statics = true) + instanceFields
    }

    /**
     * Returns the walkable instance fields of [type], including inherited ones.
     *
     * Results are memoized in [fieldsCache]. Recursion ends at the pre-cached `Any` (or an earlier cached class).
     */
    private fun instanceFields(type: Class<*>): List<Field> = fieldsCache.getOrPut(type) {
        walkableDeclaredFields(type, statics = false) + instanceFields(type.superclass)
    }

    /**
     * Returns the walkable fields declared directly by [type], made accessible.
     *
     * If [statics] is `true`, only static fields are returned; otherwise only instance fields.
     * Fields of primitive or primitive-array type are skipped, since they cannot reference objects.
     *
     * @throws IllegalStateException if [type] is a JDK class with walkable fields, because reading them
     *   would be an illegal reflective access on JDK 9+.
     */
    private fun walkableDeclaredFields(type: Class<*>, statics: Boolean): List<Field> {
        val fields = type.declaredFields.filter {
            Modifier.isStatic(it.modifiers) == statics
                && !it.type.isPrimitive
                && !(it.type.isArray && it.type.componentType.isPrimitive)
                && it.name != "previousOut" // System.out from TestBase that we store in a field to restore later
        }
        check(fields.isEmpty() || !type.name.startsWith("java.")) {
            """
                Trying to walk through JDK's '$type' will get into illegal reflective access on JDK 9+.
                Either modify your test to avoid usage of this class or update FieldWalker code to retrieve
                the captured state of this class without going through reflection (see how collections are handled).
            """.trimIndent()
        }
        fields.forEach { it.isAccessible = true } // make them all accessible
        return fields
    }
}
181