1 package kotlinx.coroutines.test.internal
2
3 import kotlinx.atomicfu.*
4 import kotlinx.coroutines.*
5 import kotlinx.coroutines.test.*
6 import kotlin.coroutines.*
7
8 /**
9 * The testable main dispatcher used by kotlinx-coroutines-test.
10 * It is a [MainCoroutineDispatcher] that delegates all actions to a settable delegate.
11 */
internal class TestMainDispatcher(delegate: CoroutineDispatcher):
    MainCoroutineDispatcher(),
    Delay
{
    /** The dispatcher passed at construction; [resetDispatcher] restores it as the delegate. */
    private val mainDispatcher = delegate

    /**
     * Holder for the dispatcher that all calls are currently forwarded to.
     *
     * The holder object itself is never replaced — only its [NonConcurrentlyModifiable.value] changes, via
     * [setDispatcher] and [resetDispatcher] — so the reference is a `val`.
     */
    private val delegate = NonConcurrentlyModifiable(mainDispatcher, "Dispatchers.Main")

    /**
     * The [Delay] used by [scheduleResumeAfterDelay] and [invokeOnTimeout]: the current delegate if it implements
     * [Delay], otherwise `defaultDelay`.
     */
    private val delay
        get() = delegate.value as? Delay ?: defaultDelay

    /**
     * The current delegate's [MainCoroutineDispatcher.immediate] when the delegate is a [MainCoroutineDispatcher];
     * otherwise this dispatcher itself.
     */
    override val immediate: MainCoroutineDispatcher
        get() = (delegate.value as? MainCoroutineDispatcher)?.immediate ?: this

    /** Forwards to the current delegate's [CoroutineDispatcher.dispatch]. */
    override fun dispatch(context: CoroutineContext, block: Runnable) = delegate.value.dispatch(context, block)

    /** Forwards to the current delegate's [CoroutineDispatcher.isDispatchNeeded]. */
    override fun isDispatchNeeded(context: CoroutineContext): Boolean = delegate.value.isDispatchNeeded(context)

    /** Forwards to the current delegate's [CoroutineDispatcher.dispatchYield]. */
    override fun dispatchYield(context: CoroutineContext, block: Runnable) = delegate.value.dispatchYield(context, block)

    /**
     * Makes [dispatcher] the target of all subsequently forwarded calls.
     *
     * @throws IllegalStateException if a read or another write of the delegate is detected concurrently with this
     *   write, or if an earlier read observed a write in progress (see [NonConcurrentlyModifiable]).
     */
    fun setDispatcher(dispatcher: CoroutineDispatcher) {
        delegate.value = dispatcher
    }

    /**
     * Restores the dispatcher this instance was constructed with as the delegate.
     *
     * @throws IllegalStateException under the same conditions as [setDispatcher].
     */
    fun resetDispatcher() {
        delegate.value = mainDispatcher
    }

    /** Forwards to [delay]: the current delegate if it is a [Delay], otherwise `defaultDelay`. */
    override fun scheduleResumeAfterDelay(timeMillis: Long, continuation: CancellableContinuation<Unit>) =
        delay.scheduleResumeAfterDelay(timeMillis, continuation)

    /** Forwards to [delay]: the current delegate if it is a [Delay], otherwise `defaultDelay`. */
    override fun invokeOnTimeout(timeMillis: Long, block: Runnable, context: CoroutineContext): DisposableHandle =
        delay.invokeOnTimeout(timeMillis, block, context)

    companion object {
        /**
         * The current delegate of [Dispatchers.Main], if `Dispatchers.Main` is a [TestMainDispatcher] and its
         * delegate is a [TestDispatcher]; `null` otherwise.
         */
        internal val currentTestDispatcher
            get() = (Dispatchers.Main as? TestMainDispatcher)?.delegate?.value as? TestDispatcher

        /** The scheduler of [currentTestDispatcher], or `null` when there is no such dispatcher. */
        internal val currentTestScheduler
            get() = currentTestDispatcher?.scheduler
    }

    /**
     * A wrapper around a value that attempts to throw when writing happens concurrently with reading.
     *
     * The read operations never throw. Instead, the failures detected inside them will be remembered and thrown on the
     * next modification.
     *
     * Detection is best-effort: the individual checks are separate atomic operations, so some interleavings of reads
     * and writes go unreported.
     */
    private class NonConcurrentlyModifiable<T>(initialValue: T, private val name: String) {
        private val reader: AtomicRef<Throwable?> = atomic(null) // last reader to attempt access
        private val readers = atomic(0) // number of concurrent readers
        private val writer: AtomicRef<Throwable?> = atomic(null) // writer currently performing value modification
        private val exceptionWhenReading: AtomicRef<Throwable?> = atomic(null) // exception from reading
        private val _value = atomic(initialValue) // the backing field for the value

        private fun concurrentWW(location: Throwable) = IllegalStateException("$name is modified concurrently", location)
        private fun concurrentRW(location: Throwable) = IllegalStateException("$name is used concurrently with setting it", location)

        var value: T
            get() {
                // Capture this read's stack trace so a writer that sees active readers can report where they are.
                reader.value = Throwable("reader location")
                readers.incrementAndGet()
                // Reads never throw: if a write is in progress, remember the conflict for the next `set` to throw.
                writer.value?.let { exceptionWhenReading.value = concurrentRW(it) }
                val result = _value.value
                readers.decrementAndGet()
                return result
            }
            set(value) {
                // Surface (and clear) a conflict recorded by an earlier read.
                exceptionWhenReading.getAndSet(null)?.let { throw it }
                // Fail if a read is in progress right now.
                if (readers.value != 0) reader.value?.let { throw concurrentRW(it) }
                val writerLocation = Throwable("other writer location")
                // Claim the writer slot; if another writer already holds it, report that writer's location.
                // NOTE(review): on this path `writer` is left holding `writerLocation` and the other writer's
                // compareAndSet below no longer matches, so `writer` stays non-null: every later read records a
                // conflict and every later write throws. Presumably intended as a sticky failure — confirm.
                writer.getAndSet(writerLocation)?.let { throw concurrentWW(it) }
                _value.value = value
                // Release the slot only if it is still ours.
                writer.compareAndSet(writerLocation, null)
                // Catch reads that started while the value was being written.
                if (readers.value != 0) reader.value?.let { throw concurrentRW(it) }
            }
    }
}
89
/**
 * Fallback [Delay] used by [TestMainDispatcher] when its current delegate does not implement [Delay].
 *
 * `DefaultDelay` is reached through the `INVISIBLE_MEMBER` / `INVISIBLE_REFERENCE` suppression, i.e. it is not
 * normally visible from this module. The getter has no backing field, so `DefaultDelay` is read on every access
 * rather than captured once.
 */
@Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE") // do not remove the INVISIBLE_REFERENCE suppression: required in K2
private val defaultDelay
    inline get() = DefaultDelay
93
/**
 * Platform-specific extension on [Dispatchers] that yields a [TestMainDispatcher]; each target supplies the
 * `actual` implementation.
 *
 * NOTE(review): presumably returns (installing it if necessary) the [TestMainDispatcher] backing `Dispatchers.Main`
 * — confirm against the `actual` declarations.
 */
@Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE") // do not remove the INVISIBLE_REFERENCE suppression: required in K2
internal expect fun Dispatchers.getTestMainDispatcher(): TestMainDispatcher
96