<lambda>null1 package kotlinx.coroutines.flow.internal
2 
3 import kotlinx.coroutines.*
4 import kotlinx.coroutines.channels.*
5 import kotlinx.coroutines.flow.*
6 import kotlinx.coroutines.internal.*
7 import kotlin.coroutines.*
8 import kotlin.coroutines.intrinsics.*
9 import kotlin.jvm.*
10 
/**
 * Views this flow as a [ChannelFlow].
 *
 * A flow that already is a [ChannelFlow] is returned unchanged; any other flow is wrapped into a
 * [ChannelFlowOperatorImpl] constructed with its default context, capacity and overflow settings.
 */
internal fun <T> Flow<T>.asChannelFlow(): ChannelFlow<T> =
    when (this) {
        is ChannelFlow -> this
        else -> ChannelFlowOperatorImpl(this)
    }
13 
/**
 * Implemented by flow operators that are able to fuse with **downstream** [buffer] and [flowOn] operators.
 *
 * @suppress **This is an internal API and should not be used from general code.**
 */
@InternalCoroutinesApi
public interface FusibleFlow<T> : Flow<T> {
    /**
     * Invoked by [flowOn] (supplying a context) and by [buffer] (supplying a capacity) when either is
     * applied to this flow. [Channel.CONFLATED] must never be passed as [capacity]; it has to be
     * desugared beforehand into `capacity = 0, onBufferOverflow = DROP_OLDEST`.
     */
    public fun fuse(
        context: CoroutineContext = EmptyCoroutineContext,
        capacity: Int = Channel.OPTIONAL_CHANNEL,
        onBufferOverflow: BufferOverflow = BufferOverflow.SUSPEND
    ): Flow<T>
}
32 
/**
 * Base class for operators whose "output" is a channel; such operators are always fused with each other.
 *
 * It serves as a skeleton implementation of [FusibleFlow] and also provides cross-cutting helpers,
 * such as the ability to [produceIn] the corresponding flow. This makes it possible to use the backing
 * channel directly when one exists (hence the `ChannelFlow` name).
 *
 * @suppress **This is an internal API and should not be used from general code.**
 */
@InternalCoroutinesApi
public abstract class ChannelFlow<T>(
    // context in which the upstream is collected
    @JvmField public val context: CoroutineContext,
    // capacity of the buffer placed between the upstream and the downstream context
    @JvmField public val capacity: Int,
    // strategy applied when that buffer overflows
    @JvmField public val onBufferOverflow: BufferOverflow
) : FusibleFlow<T> {
    init {
        // Callers are responsible for turning CONFLATED into (0, DROP_OLDEST) before getting here
        assert { capacity != Channel.CONFLATED }
    }

    // The one place where a suspend lambda delegating to collectTo is built
    internal val collectToFun: suspend (ProducerScope<T>) -> Unit
        get() = { producerScope -> collectTo(producerScope) }

    // OPTIONAL_CHANNEL ("no explicit buffer") becomes BUFFERED once a channel actually has to be produced
    internal val produceCapacity: Int
        get() = when (capacity) {
            Channel.OPTIONAL_CHANNEL -> Channel.BUFFERED
            else -> capacity
        }

    /**
     * Returns a non-null flow when this [ChannelFlow] is able to operate without a channel (that is, it
     * supports [Channel.OPTIONAL_CHANNEL]). A caller may then use the returned flow directly, bypassing
     * the extra [flowOn] and [buffer] operators and folding this flow's [context], [capacity] and
     * [onBufferOverflow] into its own implementation.
     */
    public open fun dropChannelOperators(): Flow<T>? = null

    public override fun fuse(context: CoroutineContext, capacity: Int, onBufferOverflow: BufferOverflow): Flow<T> {
        assert { capacity != Channel.CONFLATED } // callers desugar CONFLATED into (0, DROP_OLDEST)
        // The context given earlier (upstream) wins, which is why it sits on the right-hand side of `+`
        val fusedContext = context + this.context
        val addedBufferSuspends = onBufferOverflow == BufferOverflow.SUSPEND
        // A suspending buffer stacks on top of the existing one and keeps the earlier overflow strategy.
        // A non-suspending buffer never applies backpressure, so it replaces the earlier configuration.
        val fusedCapacity = if (addedBufferSuspends) combineCapacities(this.capacity, capacity) else capacity
        val fusedOverflow = if (addedBufferSuspends) this.onBufferOverflow else onBufferOverflow
        val unchanged = fusedContext == this.context &&
            fusedCapacity == this.capacity &&
            fusedOverflow == this.onBufferOverflow
        return if (unchanged) this else create(fusedContext, fusedCapacity, fusedOverflow)
    }

    // Combines the capacities of two suspending buffers.
    // OPTIONAL_CHANNEL, then BUFFERED, defer to the other side; otherwise the capacities are summed.
    private fun combineCapacities(upstream: Int, downstream: Int): Int = when {
        upstream == Channel.OPTIONAL_CHANNEL -> downstream
        downstream == Channel.OPTIONAL_CHANNEL -> upstream
        upstream == Channel.BUFFERED -> downstream
        downstream == Channel.BUFFERED -> upstream
        else -> {
            // sanity checks: only concrete, non-negative sizes should reach this branch
            assert { upstream >= 0 }
            assert { downstream >= 0 }
            val total = upstream + downstream
            // an Int overflow wraps to a negative value; clamp it to UNLIMITED
            if (total < 0) Channel.UNLIMITED else total
        }
    }

    protected abstract fun create(context: CoroutineContext, capacity: Int, onBufferOverflow: BufferOverflow): ChannelFlow<T>

    protected abstract suspend fun collectTo(scope: ProducerScope<T>)

    /**
     * Launches, in [scope], a producer coroutine that collects this flow and returns its channel.
     *
     * The producer is started with ATOMIC on purpose (#1825); note that [produceImpl] backs [flowOn].
     * With a non-atomic start, the pipeline after [flowOn] could complete successfully (including its
     * `onCompletion` handlers) while the pipeline before it never runs, because it was cancelled during
     * its dispatch. Its `onCompletion` and `finally` blocks would then be skipped, which can lead to
     * various kinds of memory leaks.
     */
    public open fun produceImpl(scope: CoroutineScope): ReceiveChannel<T> =
        scope.produce(
            context,
            produceCapacity,
            onBufferOverflow,
            start = CoroutineStart.ATOMIC,
            block = collectToFun
        )

    override suspend fun collect(collector: FlowCollector<T>): Unit = coroutineScope {
        val channel = produceImpl(this)
        collector.emitAll(channel)
    }

    protected open fun additionalToStringProps(): String? = null

    // debug toString: only non-default settings are listed
    override fun toString(): String {
        val props = listOfNotNull(
            additionalToStringProps(),
            if (context !== EmptyCoroutineContext) "context=$context" else null,
            if (capacity != Channel.OPTIONAL_CHANNEL) "capacity=$capacity" else null,
            if (onBufferOverflow != BufferOverflow.SUSPEND) "onBufferOverflow=$onBufferOverflow" else null
        )
        return "$classSimpleName[${props.joinToString(", ")}]"
    }
}
134 
/**
 * [ChannelFlow] implementation that applies an operator to another, upstream [flow].
 */
internal abstract class ChannelFlowOperator<S, T>(
    @JvmField protected val flow: Flow<S>,
    context: CoroutineContext,
    capacity: Int,
    onBufferOverflow: BufferOverflow
) : ChannelFlow<T>(context, capacity, onBufferOverflow) {
    protected abstract suspend fun flowCollect(collector: FlowCollector<T>)

    // Collects upstream in newContext, while values still reach `collector` in the caller's original context
    private suspend fun collectWithContextUndispatched(collector: FlowCollector<T>, newContext: CoroutineContext) {
        val callerContextCollector = collector.withUndispatchedContextCollector(coroutineContext)
        // run flowCollect(callerContextCollector) inside newContext
        withContextUndispatched(newContext, callerContextCollector) { flowCollect(it) }
    }

    // Slow path: an output channel is needed, so every upstream value is sent into the producer scope
    protected override suspend fun collectTo(scope: ProducerScope<T>) {
        flowCollect(SendingCollector(scope))
    }

    // Avoids creating a channel whenever channel creation is optional and no dispatcher change is needed
    override suspend fun collect(collector: FlowCollector<T>) {
        if (capacity != Channel.OPTIONAL_CHANNEL) {
            // An explicit buffer was requested, so the actual channel has to be created
            return super.collect(collector)
        }
        // flowOn/flowWith operators without buffer: try the channel-free fast paths
        val collectContext = coroutineContext
        val newContext = collectContext.newCoroutineContext(context) // resulting collection context
        when {
            // The context did not change at all: collect directly
            newContext == collectContext -> flowCollect(collector)
            // Same dispatcher: switch the context without any channel
            newContext[ContinuationInterceptor] == collectContext[ContinuationInterceptor] ->
                collectWithContextUndispatched(collector, newContext)
            // The dispatcher changes: fall back to a real channel
            else -> super.collect(collector)
        }
    }

    // debug toString
    override fun toString(): String = "$flow -> ${super.toString()}"
}
175 
/**
 * The simplest channel flow operator: a [flowOn], a [buffer], or a fused combination of both.
 * It forwards upstream values unchanged.
 */
internal class ChannelFlowOperatorImpl<T>(
    flow: Flow<T>,
    context: CoroutineContext = EmptyCoroutineContext,
    capacity: Int = Channel.OPTIONAL_CHANNEL,
    onBufferOverflow: BufferOverflow = BufferOverflow.SUSPEND
) : ChannelFlowOperator<T, T>(flow, context, capacity, onBufferOverflow) {
    override fun create(context: CoroutineContext, capacity: Int, onBufferOverflow: BufferOverflow): ChannelFlow<T> =
        ChannelFlowOperatorImpl(
            flow = flow,
            context = context,
            capacity = capacity,
            onBufferOverflow = onBufferOverflow
        )

    // Without its channel-related settings this operator does nothing, so the upstream can be used as is
    override fun dropChannelOperators(): Flow<T> = flow

    override suspend fun flowCollect(collector: FlowCollector<T>) {
        flow.collect(collector)
    }
}
193 
// todo: we might need to generalize this pattern for "thread-safe" operators that can fuse with channels

/**
 * Returns a collector that forwards into this one, emitting in [emitContext] without dispatching.
 *
 * [SendingCollector] and [NopCollector] do not care about the context, so they are returned unwrapped.
 * Any other collector is wrapped into an [UndispatchedContextCollector]. If this collector accepted
 * concurrent emits, then so does the returned one.
 */
private fun <T> FlowCollector<T>.withUndispatchedContextCollector(emitContext: CoroutineContext): FlowCollector<T> {
    val contextAgnostic = this is SendingCollector || this is NopCollector
    return if (contextAgnostic) this else UndispatchedContextCollector(this, emitContext)
}
202 
// Forwards every value to `downstream`, running each emit inside `emitContext` via withContextUndispatched.
private class UndispatchedContextCollector<T>(
    downstream: FlowCollector<T>,
    private val emitContext: CoroutineContext
) : FlowCollector<T> {
    // Thread-context state for emitContext, computed once here instead of on every emit
    private val countOrElement = threadContextElements(emitContext)

    // Suspend function reference allocated once per collector rather than per emitted value
    private val emitRef: suspend (T) -> Unit = { element -> downstream.emit(element) }

    override suspend fun emit(value: T): Unit =
        withContextUndispatched(
            newContext = emitContext,
            value = value,
            countOrElement = countOrElement,
            block = emitRef
        )
}
213 
/**
 * Computes `block(value)` in [newContext] without dispatching.
 *
 * The block is started unintercepted inside `withCoroutineContext(newContext, countOrElement)`, and its
 * continuation links back to the caller through a [StackFrameContinuation] carrying [newContext].
 * [countOrElement] defaults to `threadContextElements(newContext)`; callers may precompute it for speed.
 */
internal suspend fun <T, V> withContextUndispatched(
    newContext: CoroutineContext,
    value: V,
    countOrElement: Any = threadContextElements(newContext), // can be precomputed for speed
    block: suspend (V) -> T
): T = suspendCoroutineUninterceptedOrReturn { callerCont ->
    withCoroutineContext(newContext, countOrElement) {
        val linkedCont = StackFrameContinuation(callerCont, newContext)
        block.startCoroutineUninterceptedOrReturn(value, linkedCont)
    }
}
226 
// Continuation running in `context` that hands every result straight back to the caller's continuation,
// and reports that caller (when it is a CoroutineStackFrame) as its caller frame, so the chain stays walkable.
private class StackFrameContinuation<T>(
    private val caller: Continuation<T>,
    override val context: CoroutineContext
) : Continuation<T>, CoroutineStackFrame {

    override val callerFrame: CoroutineStackFrame?
        get() = caller as? CoroutineStackFrame

    override fun resumeWith(result: Result<T>) = caller.resumeWith(result)

    override fun getStackTraceElement(): StackTraceElement? = null
}
241