1 /*
<lambda>null2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.test.tracing.coroutines
18 
import android.os.Looper
import android.platform.test.flag.junit.SetFlagsRule
import androidx.test.ext.junit.runners.AndroidJUnit4
import com.android.app.tracing.coroutines.CoroutineTraceName
import com.android.test.tracing.coroutines.util.FakeTraceState
import com.android.test.tracing.coroutines.util.FakeTraceState.getOpenTraceSectionsOnCurrentThread
import com.android.test.tracing.coroutines.util.ShadowTrace
import java.io.PrintWriter
import java.io.StringWriter
import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.TimeUnit.MILLISECONDS
import java.util.concurrent.atomic.AtomicInteger
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.EmptyCoroutineContext
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CoroutineExceptionHandler
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import org.junit.After
import org.junit.Assert.assertEquals
import org.junit.Assert.assertTrue
import org.junit.Before
import org.junit.ClassRule
import org.junit.Rule
import org.junit.runner.RunWith
import org.robolectric.Shadows.shadowOf
import org.robolectric.annotation.Config
import org.robolectric.shadows.ShadowLooper
49 
/**
 * Recorded by the test base whenever the observed trace sections or event ordering do not match
 * what the test expected. Instances are collected and reported at tear-down rather than thrown.
 */
class InvalidTraceStateException(
    message: String,
) : Exception(message)
51 
/**
 * Base class for coroutine tracing tests.
 *
 * A test body is run with [runTest], which launches it on [Dispatchers.Main] (plus
 * [extraCoroutineContext]) and then drives the Robolectric main looper. Inside the body, tests call
 * [expect], [expectD], [expectEndsWith] and [finish] to check the trace sections currently open on
 * the calling thread and, optionally, the order in which numbered checkpoints are reached.
 * Mismatches are recorded rather than thrown immediately, and are reported by [tearDown].
 */
@RunWith(AndroidJUnit4::class)
@Config(shadows = [ShadowTrace::class])
open class TestBase {

    companion object {
        @JvmField
        @ClassRule
        val setFlagsClassRule: SetFlagsRule.ClassRule = SetFlagsRule.ClassRule()
    }

    @JvmField @Rule val setFlagsRule = SetFlagsRule()

    /** Number of the last checkpoint reached, or [FINAL_EVENT] once [finish] has been called. */
    private val eventCounter = AtomicInteger(0)

    /** Event number passed to the first [finish] call, or [INVALID_EVENT] before that. */
    private val finalEvent = AtomicInteger(INVALID_EVENT)

    /** True once any [runTest] call declared an expected exception; relaxes [tearDown]'s check. */
    private var expectedExceptions = false

    /**
     * Failures collected during the test. Checkpoints may be hit from coroutines dispatched off the
     * main thread (e.g. via [extraCoroutineContext]), so a thread-safe list is used.
     */
    private lateinit var allExceptions: MutableList<Throwable>
    private lateinit var shadowLooper: ShadowLooper
    private lateinit var mainTraceScope: CoroutineScope

    /** Extra context added to the scope that [runTest] launches into; subclasses may override. */
    open val extraCoroutineContext: CoroutineContext
        get() = EmptyCoroutineContext

    @Before
    fun setup() {
        FakeTraceState.isTracingEnabled = true
        eventCounter.set(0)
        allExceptions = CopyOnWriteArrayList<Throwable>()
        shadowLooper = shadowOf(Looper.getMainLooper())
        mainTraceScope = CoroutineScope(Dispatchers.Main + extraCoroutineContext)
    }

    /**
     * Fails the test if any trace-state mismatch or unexpected exception was recorded, or if the
     * test reached checkpoints without ever calling [finish] (unless exceptions were expected).
     */
    @After
    fun tearDown() {
        val sw = StringWriter()
        val pw = PrintWriter(sw)
        allExceptions.forEach { it.printStackTrace(pw) }
        assertTrue("Test failed due to incorrect trace sections\n$sw", allExceptions.isEmpty())

        val lastEvent = eventCounter.get()
        assertTrue(
            "`finish()` was never called. Last seen event was #$lastEvent",
            lastEvent == FINAL_EVENT || lastEvent == 0 || expectedExceptions,
        )
    }

    /**
     * Launches [block] on the main test scope, advances the main looper, and asserts that no child
     * jobs are left running afterwards.
     *
     * @param expectedException if non-null, a predicate identifying an exception the test expects
     *   to be thrown. Matching exceptions are swallowed, and this call fails if none was seen.
     * @param block the test body.
     */
    protected fun runTest(
        expectedException: ((Throwable) -> Boolean)? = null,
        block: suspend CoroutineScope.() -> Unit,
    ) {
        var foundExpectedException = false
        if (expectedException != null) expectedExceptions = true
        mainTraceScope.launch(
            block = block,
            context =
                CoroutineExceptionHandler { _, e ->
                    if (e is CancellationException) return@CoroutineExceptionHandler // ignore
                    if (expectedException != null && expectedException(e)) {
                        foundExpectedException = true
                        return@CoroutineExceptionHandler // ignore
                    }
                    allExceptions.add(e)
                },
        )

        // Step the main looper 1ms at a time (1001 steps) so delays and posted work get to run.
        repeat(1001) { shadowLooper.idleFor(1, MILLISECONDS) }

        assertNoRunningChildren()
        // Check against this call's own expectation, not the class-wide flag, so that a later
        // runTest() without an expected exception is not failed by an earlier one that had one.
        if (expectedException != null) {
            assertTrue("Expected exceptions, but none were thrown", foundExpectedException)
        }
    }

    /** Fails if [mainTraceScope] still has child jobs, listing them by [CoroutineTraceName]. */
    private fun assertNoRunningChildren() {
        val children = mainTraceScope.coroutineContext[Job]?.children?.toList().orEmpty()
        val allNames =
            children.joinToString(prefix = "{ ", separator = ", ", postfix = " }") { child ->
                child[CoroutineTraceName]?.name?.let { "\"$it\"" } ?: "unnamed"
            }
        assertEquals(
            "The main test scope still has ${children.size} running jobs: $allNames.",
            0,
            children.size,
        )
    }

    /** Records a trace-state failure to be reported by [tearDown]. */
    private fun logInvalidTraceState(message: String) {
        allExceptions.add(InvalidTraceStateException(message))
    }

    /**
     * Same as [expect], but also call [delay] for 1ms, calling [expect] before and after the
     * suspension point.
     */
    protected suspend fun expectD(vararg expectedOpenTraceSections: String) {
        expect(*expectedOpenTraceSections)
        delay(1)
        expect(*expectedOpenTraceSections)
    }

    /**
     * Same as [expect], but also call [delay] for 1ms, calling [expect] before and after the
     * suspension point. Only the check before the delay counts as event [expectedEvent].
     */
    protected suspend fun expectD(expectedEvent: Int, vararg expectedOpenTraceSections: String) {
        expect(expectedEvent, *expectedOpenTraceSections)
        delay(1)
        expect(*expectedOpenTraceSections)
    }

    /**
     * Checks that the innermost trace sections open on the current thread end with
     * [expectedOpenTraceSections], ignoring any outer sections. Does not advance the event counter.
     */
    protected fun expectEndsWith(vararg expectedOpenTraceSections: String) {
        // Inspect trace output to the fake used for recording android.os.Trace API calls:
        val actualSections = getOpenTraceSectionsOnCurrentThread()
        if (expectedOpenTraceSections.size <= actualSections.size) {
            val lastSections =
                actualSections.takeLast(expectedOpenTraceSections.size).toTypedArray()
            assertTraceSectionsEquals(expectedOpenTraceSections, null, lastSections, null)
        } else {
            logInvalidTraceState(
                "Invalid length: expected size (${expectedOpenTraceSections.size}) <= actual size (${actualSections.size})"
            )
        }
    }

    /**
     * Advances the event counter by one and records a failure if the new value is not one of
     * [expectedEvent].
     *
     * @return the event number that was reached.
     */
    protected fun expectEvent(expectedEvent: Collection<Int>): Int {
        val previousEvent = eventCounter.getAndAdd(1)
        val currentEvent = previousEvent + 1
        if (!expectedEvent.contains(currentEvent)) {
            logInvalidTraceState(
                if (previousEvent == FINAL_EVENT) {
                    "Expected event ${expectedEvent.prettyPrintList()}, but finish() was already called"
                } else {
                    "Expected event ${expectedEvent.prettyPrintList()}," +
                        " but the event counter is currently at #$currentEvent"
                }
            )
        }
        return currentEvent
    }

    /** Checks the open trace sections on the current thread without touching the event counter. */
    internal fun expect(vararg expectedOpenTraceSections: String) {
        expect(null, *expectedOpenTraceSections)
    }

    /** Checks the open trace sections and that this is exactly event number [expectedEvent]. */
    internal fun expect(expectedEvent: Int, vararg expectedOpenTraceSections: String) {
        expect(listOf(expectedEvent), *expectedOpenTraceSections)
    }

    /**
     * Checks the currently active trace sections on the current thread, and optionally checks the
     * order of operations if [possibleEventPos] is not null.
     *
     * @param possibleEventPos event numbers this checkpoint may legitimately be; when non-null the
     *   event counter is advanced and checked against it.
     * @param expectedOpenTraceSections the full, outermost-first list of expected open sections.
     */
    internal fun expect(possibleEventPos: List<Int>?, vararg expectedOpenTraceSections: String) {
        val currentEvent: Int? = possibleEventPos?.let { expectEvent(it) }
        val actualOpenSections = getOpenTraceSectionsOnCurrentThread()
        assertTraceSectionsEquals(
            expectedOpenTraceSections,
            possibleEventPos,
            actualOpenSections,
            currentEvent,
        )
    }

    /**
     * Compares expected and actual trace sections by name (ignoring anything after the first `;` in
     * each section) and records at most one failure: either a size mismatch, or the first index at
     * which the names differ.
     *
     * @param expectedEvent acceptable event numbers; used only to annotate the failure message.
     * @param actualEvent the event number actually reached; used only for the failure message.
     */
    private fun assertTraceSectionsEquals(
        expectedOpenTraceSections: Array<out String>,
        expectedEvent: List<Int>?,
        actualOpenSections: Array<String>,
        actualEvent: Int?,
    ) {
        val expectedSize = expectedOpenTraceSections.size
        val actualSize = actualOpenSections.size
        if (expectedSize != actualSize) {
            logInvalidTraceState(
                createFailureMessage(
                    expectedOpenTraceSections,
                    expectedEvent,
                    actualOpenSections,
                    actualEvent,
                    "Size mismatch, expected size $expectedSize but was size $actualSize",
                )
            )
            return
        }
        for (n in expectedOpenTraceSections.indices) {
            val expected = expectedOpenTraceSections[n].substringBefore(";")
            val actual = actualOpenSections[n].substringBefore(";")
            if (expected != actual) {
                logInvalidTraceState(
                    createFailureMessage(
                        expectedOpenTraceSections,
                        expectedEvent,
                        actualOpenSections,
                        actualEvent,
                        "Differed at index #$n, expected \"$expected\" but was \"$actual\"",
                    )
                )
                // The first difference is enough to diagnose; don't log one failure per index.
                break
            }
        }
    }

    /** Builds a multi-line failure report showing both expected and actual trace sections. */
    private fun createFailureMessage(
        expectedOpenTraceSections: Array<out String>,
        expectedEventNumber: List<Int>?,
        actualOpenSections: Array<String>,
        actualEventNumber: Int?,
        extraMessage: String,
    ): String {
        val locationMarker =
            when {
                expectedEventNumber == null || actualEventNumber == null -> ""
                actualEventNumber in expectedEventNumber -> " at event #$actualEventNumber"
                else ->
                    ", expected event ${expectedEventNumber.prettyPrintList()}, actual event #$actualEventNumber"
            }
        return """
                Incorrect trace$locationMarker. $extraMessage
                  Expected : {${expectedOpenTraceSections.prettyPrintList()}}
                  Actual   : {${actualOpenSections.prettyPrintList()}}
                """
            .trimIndent()
    }

    /**
     * Same as [expect], except that no more [expect] statements can be called after it. Marks the
     * event counter as final so that [tearDown] knows the test ran to completion.
     */
    protected fun finish(expectedEvent: Int, vararg expectedOpenTraceSections: String) {
        // Only the first finish() call is remembered, for reporting duplicate calls.
        finalEvent.compareAndSet(INVALID_EVENT, expectedEvent)
        val previousEvent = eventCounter.getAndSet(FINAL_EVENT)
        val currentEvent = previousEvent + 1
        if (expectedEvent != currentEvent) {
            logInvalidTraceState(
                "Expected to finish with event #$expectedEvent, but " +
                    if (previousEvent == FINAL_EVENT)
                        "finish() was already called with event #${finalEvent.get()}"
                    else "the event counter is currently at #$currentEvent"
            )
        }
        assertTraceSectionsEquals(
            expectedOpenTraceSections,
            listOf(expectedEvent),
            getOpenTraceSectionsOnCurrentThread(),
            currentEvent,
        )
    }
}
302 
303 private const val INVALID_EVENT = -1
304 
305 private const val FINAL_EVENT = Int.MIN_VALUE
306 
/**
 * Formats event numbers for failure messages: `""` when empty, `#n` for a single event, and
 * `{#a, #b, ...}` for several.
 */
private fun Collection<Int>.prettyPrintList(): String =
    when (size) {
        0 -> ""
        1 -> "#${first()}"
        else -> joinToString(separator = ", #", prefix = "{#", postfix = "}")
    }
320 
/**
 * Formats trace section names as a comma-separated list of quoted strings, keeping only the part of
 * each name before the first `;`. Returns `""` for an empty array.
 */
private fun Array<out String>.prettyPrintList(): String {
    if (isEmpty()) return ""
    return joinToString(separator = "\", \"", prefix = "\"", postfix = "\"") { section ->
        section.substringBefore(";")
    }
}
328