1 package kotlinx.coroutines.test.internal
2 
3 import kotlinx.atomicfu.*
4 import kotlinx.coroutines.*
5 import kotlinx.coroutines.test.*
6 import kotlin.coroutines.*
7 
/**
 * The testable main dispatcher used by kotlinx-coroutines-test.
 * It is a [MainCoroutineDispatcher] that delegates all actions to a settable delegate.
 *
 * Dispatching calls are forwarded to the current delegate unchanged. [Delay] calls are forwarded to the
 * current delegate when it implements [Delay], and to the default delay otherwise.
 *
 * @param delegate the dispatcher that is used initially and that [resetDispatcher] restores.
 */
internal class TestMainDispatcher(delegate: CoroutineDispatcher):
    MainCoroutineDispatcher(),
    Delay
{
    /** The dispatcher passed at construction; [resetDispatcher] switches back to it. */
    private val mainDispatcher = delegate

    // The holder itself is never reassigned (only its `value` is), hence `val` rather than `var`.
    private val delegate = NonConcurrentlyModifiable(mainDispatcher, "Dispatchers.Main")

    /** The [Delay] to use right now: the current delegate if it is a [Delay], else the default one. */
    private val delay
        get() = delegate.value as? Delay ?: defaultDelay

    /**
     * The `immediate` dispatcher of the current delegate when the delegate is a [MainCoroutineDispatcher];
     * otherwise this dispatcher itself.
     */
    override val immediate: MainCoroutineDispatcher
        get() = (delegate.value as? MainCoroutineDispatcher)?.immediate ?: this

    /** Forwards to the current delegate. */
    override fun dispatch(context: CoroutineContext, block: Runnable) = delegate.value.dispatch(context, block)

    /** Forwards to the current delegate. */
    override fun isDispatchNeeded(context: CoroutineContext): Boolean = delegate.value.isDispatchNeeded(context)

    /** Forwards to the current delegate. */
    override fun dispatchYield(context: CoroutineContext, block: Runnable) = delegate.value.dispatchYield(context, block)

    /**
     * Replaces the current delegate with [dispatcher].
     *
     * @throws IllegalStateException if concurrent use of the delegate was detected
     * (see [NonConcurrentlyModifiable]).
     */
    fun setDispatcher(dispatcher: CoroutineDispatcher) {
        delegate.value = dispatcher
    }

    /**
     * Restores the dispatcher that was passed to the constructor.
     *
     * @throws IllegalStateException if concurrent use of the delegate was detected
     * (see [NonConcurrentlyModifiable]).
     */
    fun resetDispatcher() {
        delegate.value = mainDispatcher
    }

    /** Forwards to the currently selected [Delay] (see the class documentation). */
    override fun scheduleResumeAfterDelay(timeMillis: Long, continuation: CancellableContinuation<Unit>) =
        delay.scheduleResumeAfterDelay(timeMillis, continuation)

    /** Forwards to the currently selected [Delay] (see the class documentation). */
    override fun invokeOnTimeout(timeMillis: Long, block: Runnable, context: CoroutineContext): DisposableHandle =
        delay.invokeOnTimeout(timeMillis, block, context)

    companion object {
        /**
         * The [TestDispatcher] currently acting as the delegate of `Dispatchers.Main`, or `null` when
         * `Dispatchers.Main` is not a [TestMainDispatcher] or its delegate is not a [TestDispatcher].
         */
        internal val currentTestDispatcher
            get() = (Dispatchers.Main as? TestMainDispatcher)?.delegate?.value as? TestDispatcher

        /** The scheduler of [currentTestDispatcher], or `null` when there is none. */
        internal val currentTestScheduler
            get() = currentTestDispatcher?.scheduler
    }

    /**
     * A wrapper around a value that attempts to throw when writing happens concurrently with reading.
     *
     * The read operations never throw. Instead, the failures detected inside them will be remembered and thrown on the
     * next modification.
     *
     * @param initialValue the value held until the first successful write.
     * @param name the name of the guarded entity, used in the messages of the thrown exceptions.
     */
    private class NonConcurrentlyModifiable<T>(initialValue: T, private val name: String) {
        private val reader: AtomicRef<Throwable?> = atomic(null) // last reader to attempt access
        private val readers = atomic(0) // number of concurrent readers
        private val writer: AtomicRef<Throwable?> = atomic(null) // writer currently performing value modification
        private val exceptionWhenReading: AtomicRef<Throwable?> = atomic(null) // exception from reading
        private val _value = atomic(initialValue) // the backing field for the value

        /** Builds the failure for two overlapping writes; [location] becomes the cause. */
        private fun concurrentWW(location: Throwable) = IllegalStateException("$name is modified concurrently", location)

        /** Builds the failure for a read overlapping a write; [location] becomes the cause. */
        private fun concurrentRW(location: Throwable) = IllegalStateException("$name is used concurrently with setting it", location)

        /**
         * The guarded value.
         *
         * Reading never throws: a write observed to be in progress is recorded and reported by the next write.
         * Writing throws [IllegalStateException] when an earlier read recorded a failure, when readers are
         * observed to be active right before or right after the write, or when another write is in progress.
         *
         * The statements below are order-sensitive; do not reorder them.
         */
        var value: T
            get() {
                // Remember where the read came from so that a writer can report it as the cause.
                reader.value = Throwable("reader location")
                readers.incrementAndGet()
                // A write is in progress: record the failure instead of throwing from the read.
                writer.value?.let { exceptionWhenReading.value = concurrentRW(it) }
                val result = _value.value
                readers.decrementAndGet()
                return result
            }
            set(value) {
                // Report (and clear) a failure that a previous read recorded.
                exceptionWhenReading.getAndSet(null)?.let { throw it }
                // Readers active before the write.
                if (readers.value != 0) reader.value?.let { throw concurrentRW(it) }
                val writerLocation = Throwable("other writer location")
                // Announce this writer; a non-null previous value means another write is in progress.
                writer.getAndSet(writerLocation)?.let { throw concurrentWW(it) }
                _value.value = value
                // Clear the marker only if it is still ours.
                writer.compareAndSet(writerLocation, null)
                // Readers that showed up while the write was happening.
                if (readers.value != 0) reader.value?.let { throw concurrentRW(it) }
            }
    }
}
89 
/**
 * File-private shorthand for `DefaultDelay`.
 *
 * The accessor is inline, so each use is replaced with a direct reference to `DefaultDelay`.
 * NOTE(review): `DefaultDelay` is presumably not visible from this module without the suppression below.
 */
@Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE") // do not remove the INVISIBLE_REFERENCE suppression: required in K2
private inline val defaultDelay
    get() = DefaultDelay
93 
/**
 * Platform-specific way of obtaining a [TestMainDispatcher] from [Dispatchers].
 *
 * This is an `expect` declaration: each platform supplies its own `actual` implementation.
 * NOTE(review): presumably returns the [TestMainDispatcher] installed as `Dispatchers.Main` — confirm
 * against the platform actuals.
 */
@Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE") // do not remove the INVISIBLE_REFERENCE suppression: required in K2
internal expect fun Dispatchers.getTestMainDispatcher(): TestMainDispatcher
96