xref: /aosp_15_r20/external/kotlinx.coroutines/kotlinx-coroutines-core/common/src/flow/internal/Combine.kt (revision 7a7160fed73afa6648ef8aa100d4a336fe921d9a)

<lambda>null1 @file:Suppress("UNCHECKED_CAST") // KT-32203
2 
3 package kotlinx.coroutines.flow.internal
4 
5 import kotlinx.coroutines.*
6 import kotlinx.coroutines.channels.*
7 import kotlinx.coroutines.flow.*
8 import kotlinx.coroutines.internal.*
/**
 * A single update from one of the combined source flows: [IndexedValue.index] is the position of the
 * source flow in the input array and [IndexedValue.value] is the value that flow emitted.
 */
9 private typealias Update = IndexedValue<Any?>
10 
/**
 * Collects all [flows] concurrently and calls [transform] with the latest value of every flow
 * (ordered as in [flows]) once each flow has emitted at least one value.
 *
 * Every source flow is collected in its own coroutine launched in a [flowScope]; each value is sent,
 * tagged with its flow index, to a single buffered channel. The receiving loop drains that channel in
 * batches ("epochs"): updates that are already buffered are applied together and followed by a single
 * [transform] call. Consequently, not every emitted value reaches [transform]: a value can be
 * overwritten in `latestValues` by a newer value from the same flow before [transform] runs (within a
 * batch, or before every flow has emitted at least once).
 *
 * The receiving loop ends when the channel is closed, which happens once every source collector has
 * finished (the close is done in a `finally`, so it also runs on failure or cancellation).
 *
 * @param flows the source flows; when empty, returns immediately and [transform] is never called.
 * @param arrayFactory supplies the array passed to [transform]. When it returns `null`, [transform]
 *   receives the internal `latestValues` array itself; that array is overwritten by later batches,
 *   so such a transform must not keep a reference to it.
 * @param transform invoked with the latest values after a batch, once no flow is missing a value.
 */
11 @PublishedApi
12 internal suspend fun <R, T> FlowCollector<R>.combineInternal(
13     flows: Array<out Flow<T>>,
14     arrayFactory: () -> Array<T?>?, // Array factory is required to work around array typing on JVM
15     transform: suspend FlowCollector<R>.(Array<T>) -> Unit
16 ): Unit = flowScope { // flow scope so any cancellation within the source flow will cancel the whole scope
17     val size = flows.size
18     if (size == 0) return@flowScope // bail-out for empty input
    // Latest value per source flow; UNINITIALIZED marks a flow that has not emitted yet.
19     val latestValues = arrayOfNulls<Any?>(size)
20     latestValues.fill(UNINITIALIZED) // Smaller bytecode & faster than Array(size) { UNINITIALIZED }
    // All source collectors funnel their values through this channel (buffer capacity == number of flows).
21     val resultChannel = Channel<Update>(size)
    // Number of source collectors still running; the last one to finish closes resultChannel.
22     val nonClosed = LocalAtomicInt(size)
    // Number of flows that have not emitted yet; transform is only called once this reaches 0.
23     var remainingAbsentValues = size
24     for (i in 0 until size) {
25         // One coroutine per source flow: forwards each value, tagged with the flow index, to resultChannel
26         launch {
27             try {
28                 flows[i].collect { value ->
29                     resultChannel.send(Update(i, value))
30                     yield() // Emulate fairness, giving each flow a chance to emit
31                 }
32             } finally {
33                 // Close the channel once there are no more running flows (runs on completion, failure or cancellation)
34                 if (nonClosed.decrementAndGet() == 0) {
35                     resultChannel.close()
36                 }
37             }
38         }
39     }
40

41     /*
42      * Batch-receive optimization: read updates in batches, but bail out
43      * as soon as we encounter two values from the same source
44      */
    // Epoch in which each flow last delivered a value; detects a second value from the same flow in one batch.
    // NOTE: currentEpoch is a Byte and wraps around after 256 epochs, so a stale entry can coincidentally
    // match. That only ends a batch early (the value has already been stored), so no update is dropped by it.
45     val lastReceivedEpoch = ByteArray(size)
46     var currentEpoch: Byte = 0
47     while (true) {
48         ++currentEpoch
49         // Start batch
50         // The very first receive in epoch should be suspending; the rest of the batch uses non-suspending tryReceive
51         var element = resultChannel.receiveCatching().getOrNull() ?: break // Channel is closed, nothing to do here
52         while (true) {
53             val index = element.index
54             // Update values
55             val previous = latestValues[index]
56             latestValues[index] = element.value
57             if (previous === UNINITIALIZED) --remainingAbsentValues
58             // Check epoch
59             // Received the second value from the same flow in the same epoch -- bail out
            // (the newer value has already overwritten the older one above)
60             if (lastReceivedEpoch[index] == currentEpoch) break
61             lastReceivedEpoch[index] = currentEpoch
            // Continue the batch only with updates that are already buffered
62             element = resultChannel.tryReceive().getOrNull() ?: break
63         }
64 
65         // Process the batch result only once every flow has emitted at least one value
66         if (remainingAbsentValues == 0) {
67             /*
68              * If arrayFactory returns null, then we can avoid array copy because
69              * it's our own safe transformer that immediately deconstructs the array
                * (latestValues is mutated by later batches, so it must not be retained)
70              */
71             val results = arrayFactory()
72             if (results == null) {
73                 transform(latestValues as Array<T>)
74             } else {
75                 (latestValues as Array<T?>).copyInto(results)
76                 transform(results as Array<T>)
77             }
78         }
79     }
80 }
81 
/**
 * Backs the `zip` operator (see the `f1.zip(f2)` example below): emits `transform(a, b)` where `a` is
 * the n-th value of [flow] and `b` is the n-th value of [flow2], and completes as soon as either flow
 * completes.
 *
 * [flow2] is collected by a producer coroutine. [flow] is collected in the current coroutine, but under
 * a separate [Job] (`collectJob`), so its collection can be cancelled as soon as [flow2] is done without
 * cancelling the downstream collector. If the [flow2] channel is closed with a cause, that cause is
 * rethrown when the next value is requested.
 */
zipImplnull82 internal fun <T1, T2, R> zipImpl(flow: Flow<T1>, flow2: Flow<T2>, transform: suspend (T1, T2) -> R): Flow<R> =
83     unsafeFlow {
84         coroutineScope {
            // The channel element type is non-null Any, so null values of flow2 are boxed into the
            // NULL sentinel here and unboxed again with NULL.unbox when paired below.
85             val second = produce<Any> {
86                 flow2.collect { value ->
87                     return@collect channel.send(value ?: NULL)
88                 }
89             }
90 
91             /*
92              * This approach only works with a rendezvous channel and is required to enforce correctness
93              * in the following scenario:
94              * ```
95              * val f1 = flow { emit(1); delay(Long.MAX_VALUE) }
96              * val f2 = flowOf(1)
97              * f1.zip(f2) { ... }
98              * ```
99              *
100              * Invariant: this clause is invoked only when all elements from the channel were processed (=> rendezvous restriction).
101              */
            // Standalone Job (created without a parent) that the collection of `flow` runs under.
102             val collectJob = Job()
            // NOTE(review): produce returns a ReceiveChannel; this cast relies on the underlying channel also
            // implementing SendChannel, which is where invokeOnClose is declared.
103             (second as SendChannel<*>).invokeOnClose {
104                 // Optimization to avoid AFE (AbortFlowException) allocation when the other flow is done
105                 if (collectJob.isActive) collectJob.cancel(AbortFlowException(collectJob))
106             }
107 
108             try {
109                 /*
110                  * Non-trivial undispatched (because we are in the right context and there is no structured concurrency)
111                  * hierarchy:
112                  * - Outer coroutineScope that owns the whole zip process
113                  * - First flow is collected by the child of coroutineScope, collectJob.
114                  *    So it can be safely cancelled as soon as the second flow is done
115                  * - **But** the downstream MUST NOT be cancelled when the second flow is done,
116                  *    so we emit to downstream from the coroutineScope job.
117                  * Typically, such a hierarchy requires a separate coroutine for the collector that communicates
118                  * with the coroutine scope via a channel, but it's way too expensive, so
119                  * we are using this trick instead.
120                  */
121                 val scopeContext = coroutineContext
                // NOTE(review): presumably the ThreadContextElement count of scopeContext, computed once so the
                // per-value context switch below does not recompute it -- confirm against threadContextElements.
122                 val cnt = threadContextElements(scopeContext)
123                 withContextUndispatched(coroutineContext + collectJob, Unit) {
124                     flow.collect { value ->
                        // Switch back to the scope context to pair the value and emit downstream
125                         withContextUndispatched(scopeContext, Unit, cnt) {
                            // Closed with a cause: rethrow it. Closed normally (flow2 is exhausted):
                            // abort the collection of `flow` with an exception owned by collectJob.
126                             val otherValue = second.receiveCatching().getOrElse {
127                                 throw it ?: AbortFlowException(collectJob)
128                             }
129                             emit(transform(value, NULL.unbox(otherValue)))
130                         }
131                     }
132                 }
133             } catch (e: AbortFlowException) {
                // NOTE(review): checkOwnership is defined elsewhere; presumably it rethrows an exception that
                // belongs to a different owner, so only this zip's own abort is swallowed here.
134                 e.checkOwnership(owner = collectJob)
135             } finally {
                // Stop the flow2 producer whenever zipping ends (normally, by abort, or with an error)
136                 second.cancel()
137             }
138         }
139     }
140