1 /*
<lambda>null2 * Copyright (C) 2024 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 package com.android.test.tracing.coroutines
18
19 import android.os.Looper
20 import android.platform.test.flag.junit.SetFlagsRule
21 import androidx.test.ext.junit.runners.AndroidJUnit4
22 import com.android.app.tracing.coroutines.CoroutineTraceName
23 import com.android.test.tracing.coroutines.util.FakeTraceState
24 import com.android.test.tracing.coroutines.util.FakeTraceState.getOpenTraceSectionsOnCurrentThread
25 import com.android.test.tracing.coroutines.util.ShadowTrace
26 import java.io.PrintWriter
27 import java.io.StringWriter
28 import java.util.concurrent.TimeUnit.MILLISECONDS
29 import java.util.concurrent.atomic.AtomicInteger
30 import kotlin.coroutines.CoroutineContext
31 import kotlin.coroutines.EmptyCoroutineContext
32 import kotlinx.coroutines.CancellationException
33 import kotlinx.coroutines.CoroutineExceptionHandler
34 import kotlinx.coroutines.CoroutineScope
35 import kotlinx.coroutines.Dispatchers
36 import kotlinx.coroutines.Job
37 import kotlinx.coroutines.delay
38 import kotlinx.coroutines.launch
39 import org.junit.After
40 import org.junit.Assert.assertEquals
41 import org.junit.Assert.assertTrue
42 import org.junit.Before
43 import org.junit.ClassRule
44 import org.junit.Rule
45 import org.junit.runner.RunWith
46 import org.robolectric.Shadows.shadowOf
47 import org.robolectric.annotation.Config
48 import org.robolectric.shadows.ShadowLooper
49
/**
 * Failure recorded when the observed trace sections or event ordering do not match what a test
 * expected. Instances are collected by [TestBase] rather than thrown at the point of detection.
 */
class InvalidTraceStateException(
    message: String,
) : Exception(message)
51
/**
 * Base class for tests that check which trace sections are open while coroutines run.
 *
 * Test bodies are launched with [runTest] on a scope that uses [Dispatchers.Main] plus
 * [extraCoroutineContext]. Inside a test body, [expect], [expectD], [expectEndsWith] and [finish]
 * compare the trace sections reported by [FakeTraceState] for the current thread against the
 * expected ones. Mismatches are not thrown immediately: they are collected as
 * [InvalidTraceStateException]s and reported together in [tearDown].
 */
@RunWith(AndroidJUnit4::class)
@Config(shadows = [ShadowTrace::class])
open class TestBase {

    companion object {
        @JvmField
        @ClassRule
        val setFlagsClassRule: SetFlagsRule.ClassRule = SetFlagsRule.ClassRule()
    }

    @JvmField @Rule val setFlagsRule = SetFlagsRule()

    /** Number of the most recent numbered event; negative once [finish] has been called. */
    private val eventCounter = AtomicInteger(0)

    /** Event number passed to the first [finish] call, or [INVALID_EVENT] if there was none. */
    private val finalEvent = AtomicInteger(INVALID_EVENT)

    /** Set once any [runTest] call in the current test supplied an exception predicate. */
    private var expectedExceptions = false

    /** Failures collected during the test; asserted to be empty in [tearDown]. */
    private lateinit var allExceptions: MutableList<Throwable>
    private lateinit var shadowLooper: ShadowLooper
    private lateinit var mainTraceScope: CoroutineScope

    /** Additional context elements added to the scope used by [runTest]. Empty by default. */
    open val extraCoroutineContext: CoroutineContext
        get() = EmptyCoroutineContext

    @Before
    fun setup() {
        FakeTraceState.isTracingEnabled = true
        // Reset every piece of per-test bookkeeping, not only the counter, so the state is
        // consistent regardless of how the instance is reused.
        eventCounter.set(0)
        finalEvent.set(INVALID_EVENT)
        expectedExceptions = false
        allExceptions = mutableListOf()
        shadowLooper = shadowOf(Looper.getMainLooper())
        mainTraceScope = CoroutineScope(Dispatchers.Main + extraCoroutineContext)
    }

    @After
    fun tearDown() {
        val sw = StringWriter()
        val pw = PrintWriter(sw)
        allExceptions.forEach { it.printStackTrace(pw) }
        assertTrue("Test failed due to incorrect trace sections\n$sw", allExceptions.isEmpty())

        // A test that used numbered events must have ended with finish(), unless it expected an
        // exception (in which case the body may legitimately never reach finish()).
        val lastEvent = eventCounter.get()
        assertTrue(
            "`finish()` was never called. Last seen event was #$lastEvent",
            isFinished(lastEvent) || lastEvent == 0 || expectedExceptions,
        )
    }

    /**
     * Launches [block] on the main test scope, drives the main looper forward, and then asserts
     * that no child job of the scope is still running.
     *
     * @param expectedException optional predicate; an uncaught exception for which it returns
     *   `true` is treated as expected instead of being recorded as a failure. When a predicate is
     *   supplied, this call fails unless at least one such exception was seen.
     * @param block the test body.
     */
    protected fun runTest(
        expectedException: ((Throwable) -> Boolean)? = null,
        block: suspend CoroutineScope.() -> Unit,
    ) {
        var foundExpectedException = false
        if (expectedException != null) expectedExceptions = true
        mainTraceScope.launch(
            block = block,
            context =
                CoroutineExceptionHandler { _, e ->
                    if (e is CancellationException) return@CoroutineExceptionHandler // ignore
                    if (expectedException != null && expectedException(e)) {
                        foundExpectedException = true
                        return@CoroutineExceptionHandler // ignore
                    }
                    allExceptions.add(e)
                },
        )

        // Advance the main looper in 1 ms steps so delayed continuations get a chance to run.
        val idleSteps = 1001
        repeat(idleSteps) { shadowLooper.idleFor(1, MILLISECONDS) }

        // NOTE(review): a Job looked up as a context element normally only answers to its own
        // key, so this lookup may always yield null ("unnamed") - confirm against
        // CoroutineTraceName.
        val childNames: List<String?> =
            mainTraceScope.coroutineContext[Job]
                ?.children
                ?.map { child -> child[CoroutineTraceName]?.name }
                ?.toList() ?: emptyList()
        val numChildren = childNames.size

        val allNames =
            childNames.joinToString(prefix = "{ ", separator = ", ", postfix = " }") { name ->
                if (name != null) "\"$name\"" else "unnamed"
            }
        assertEquals(
            "The main test scope still has $numChildren running jobs: $allNames.",
            0,
            numChildren,
        )
        // Check only this call's expectation: the class-level flag stays set after an earlier
        // runTest call that supplied a predicate, which must not fail a later call without one.
        if (expectedException != null) {
            assertTrue("Expected exceptions, but none were thrown", foundExpectedException)
        }
    }

    /** Records a failure to be reported in [tearDown] instead of throwing immediately. */
    private fun logInvalidTraceState(message: String) {
        allExceptions.add(InvalidTraceStateException(message))
    }

    /**
     * Whether [counterValue], read from [eventCounter], shows that [finish] has been called.
     * [finish] stores [FINAL_EVENT] (`Int.MIN_VALUE`); later increments keep the value negative.
     */
    private fun isFinished(counterValue: Int): Boolean = counterValue < 0

    /**
     * Same as [expect] without an event number, but also calls [delay] for 1ms, checking the open
     * trace sections both before and after the suspension point.
     */
    protected suspend fun expectD(vararg expectedOpenTraceSections: String) {
        expect(*expectedOpenTraceSections)
        delay(1)
        expect(*expectedOpenTraceSections)
    }

    /**
     * Same as [expect] with an event number, but also calls [delay] for 1ms. [expectedEvent] is
     * only checked before the suspension point; after it, only the open trace sections are
     * checked.
     */
    protected suspend fun expectD(expectedEvent: Int, vararg expectedOpenTraceSections: String) {
        expect(expectedEvent, *expectedOpenTraceSections)
        delay(1)
        expect(*expectedOpenTraceSections)
    }

    /**
     * Checks that the innermost open trace sections on the current thread match
     * [expectedOpenTraceSections], ignoring any additional outer sections.
     */
    protected fun expectEndsWith(vararg expectedOpenTraceSections: String) {
        // Inspect trace output to the fake used for recording android.os.Trace API calls:
        val actualSections = getOpenTraceSectionsOnCurrentThread()
        if (expectedOpenTraceSections.size <= actualSections.size) {
            val lastSections =
                actualSections.takeLast(expectedOpenTraceSections.size).toTypedArray()
            assertTraceSectionsEquals(expectedOpenTraceSections, null, lastSections, null)
        } else {
            logInvalidTraceState(
                "Invalid length: expected size (${expectedOpenTraceSections.size}) <= actual size (${actualSections.size})"
            )
        }
    }

    /**
     * Advances the event counter by one and records a failure if the new value is not one of
     * [expectedEvent].
     *
     * @return the new value of the event counter.
     */
    protected fun expectEvent(expectedEvent: Collection<Int>): Int {
        val previousEvent = eventCounter.getAndAdd(1)
        val currentEvent = previousEvent + 1
        if (!expectedEvent.contains(currentEvent)) {
            logInvalidTraceState(
                // Any negative counter means finish() already ran, not only the exact
                // FINAL_EVENT value: each expectEvent() after finish() increments it further.
                if (isFinished(previousEvent)) {
                    "Expected event ${expectedEvent.prettyPrintList()}, but finish() was already called"
                } else {
                    "Expected event ${expectedEvent.prettyPrintList()}," +
                        " but the event counter is currently at #$currentEvent"
                }
            )
        }
        return currentEvent
    }

    /** Checks the currently open trace sections without checking the order of operations. */
    internal fun expect(vararg expectedOpenTraceSections: String) {
        expect(null, *expectedOpenTraceSections)
    }

    /** Checks the currently open trace sections and that this is event number [expectedEvent]. */
    internal fun expect(expectedEvent: Int, vararg expectedOpenTraceSections: String) {
        expect(listOf(expectedEvent), *expectedOpenTraceSections)
    }

    /**
     * Checks the currently active trace sections on the current thread, and optionally checks the
     * order of operations if [possibleEventPos] is not null.
     */
    internal fun expect(possibleEventPos: List<Int>?, vararg expectedOpenTraceSections: String) {
        val currentEvent: Int? = possibleEventPos?.let { expectEvent(it) }
        val actualOpenSections = getOpenTraceSectionsOnCurrentThread()
        assertTraceSectionsEquals(
            expectedOpenTraceSections,
            possibleEventPos,
            actualOpenSections,
            currentEvent,
        )
    }

    /**
     * Compares expected and actual trace sections, considering only the part of each name before
     * the first `;`. Records a single failure for a size mismatch, or for the first index at which
     * the names differ.
     */
    private fun assertTraceSectionsEquals(
        expectedOpenTraceSections: Array<out String>,
        expectedEvent: List<Int>?,
        actualOpenSections: Array<String>,
        actualEvent: Int?,
    ) {
        val expectedSize = expectedOpenTraceSections.size
        val actualSize = actualOpenSections.size
        if (expectedSize != actualSize) {
            logInvalidTraceState(
                createFailureMessage(
                    expectedOpenTraceSections,
                    expectedEvent,
                    actualOpenSections,
                    actualEvent,
                    "Size mismatch, expected size $expectedSize but was size $actualSize",
                )
            )
            return
        }

        // Stop at the first difference: the failure message already prints both full lists, so
        // one entry per differing index would only add noise.
        val mismatchIndex =
            expectedOpenTraceSections.indices.firstOrNull { n ->
                expectedOpenTraceSections[n].substringBefore(";") !=
                    actualOpenSections[n].substringBefore(";")
            } ?: return

        val expected = expectedOpenTraceSections[mismatchIndex].substringBefore(";")
        val actual = actualOpenSections[mismatchIndex].substringBefore(";")
        logInvalidTraceState(
            createFailureMessage(
                expectedOpenTraceSections,
                expectedEvent,
                actualOpenSections,
                actualEvent,
                "Differed at index #$mismatchIndex, expected \"$expected\" but was \"$actual\"",
            )
        )
    }

    /**
     * Builds the text of a trace mismatch failure: a headline containing [extraMessage] (and the
     * event position when both event arguments are non-null), followed by the expected and actual
     * section lists.
     */
    private fun createFailureMessage(
        expectedOpenTraceSections: Array<out String>,
        expectedEventNumber: List<Int>?,
        actualOpenSections: Array<String>,
        actualEventNumber: Int?,
        extraMessage: String,
    ): String {
        val locationMarker =
            when {
                expectedEventNumber == null || actualEventNumber == null -> ""
                expectedEventNumber.contains(actualEventNumber) -> " at event #$actualEventNumber"
                else ->
                    ", expected event ${expectedEventNumber.prettyPrintList()}, actual event #$actualEventNumber"
            }
        return """
            Incorrect trace$locationMarker. $extraMessage
            Expected : {${expectedOpenTraceSections.prettyPrintList()}}
            Actual : {${actualOpenSections.prettyPrintList()}}
            """
            .trimIndent()
    }

    /** Same as [expect], except that no more [expect] statements can be called after it. */
    protected fun finish(expectedEvent: Int, vararg expectedOpenTraceSections: String) {
        // Only the first finish() call records its event number.
        finalEvent.compareAndSet(INVALID_EVENT, expectedEvent)
        val previousEvent = eventCounter.getAndSet(FINAL_EVENT)
        val currentEvent = previousEvent + 1
        if (expectedEvent != currentEvent) {
            val reason =
                if (isFinished(previousEvent)) {
                    "finish() was already called with event #${finalEvent.get()}"
                } else {
                    "the event counter is currently at #$currentEvent"
                }
            logInvalidTraceState("Expected to finish with event #$expectedEvent, but $reason")
        }
        assertTraceSectionsEquals(
            expectedOpenTraceSections,
            listOf(expectedEvent),
            getOpenTraceSectionsOnCurrentThread(),
            currentEvent,
        )
    }
}
302
/** Placeholder for "no final event recorded yet"; the initial value of `TestBase.finalEvent`. */
private const val INVALID_EVENT: Int = -1

/** Value that `TestBase.finish` stores into the event counter. */
private const val FINAL_EVENT: Int = Int.MIN_VALUE
306
/**
 * Formats event numbers for failure messages: an empty string for no events, `#n` for a single
 * event, and `{#a, #b, ...}` for several.
 */
private fun Collection<Int>.prettyPrintList(): String =
    when (size) {
        0 -> ""
        1 -> "#${first()}"
        else -> joinToString(prefix = "{#", separator = ", #", postfix = "}")
    }
320
/**
 * Formats trace section names for failure messages as a comma-separated list of double-quoted
 * names, keeping only the part of each name before the first `;`. An empty array yields an empty
 * string.
 */
private fun Array<out String>.prettyPrintList(): String {
    if (isEmpty()) return ""
    val sectionNames = map { section -> section.substringBefore(";") }
    return sectionNames.joinToString(separator = "\", \"", prefix = "\"", postfix = "\"")
}
328