xref: /aosp_15_r20/frameworks/base/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/MuxDeferred.kt (revision d57664e9bc4670b3ecf6748a746a57c557b6bc9e)
/*
 * Copyright (C) 2024 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 package com.android.systemui.kairos.internal
18 
19 import com.android.systemui.kairos.internal.util.Key
20 import com.android.systemui.kairos.internal.util.associateByIndexTo
21 import com.android.systemui.kairos.internal.util.hashString
22 import com.android.systemui.kairos.internal.util.mapParallel
23 import com.android.systemui.kairos.internal.util.mapValuesNotNullParallelTo
24 import com.android.systemui.kairos.util.Just
25 import com.android.systemui.kairos.util.Left
26 import com.android.systemui.kairos.util.Maybe
27 import com.android.systemui.kairos.util.None
28 import com.android.systemui.kairos.util.Right
29 import com.android.systemui.kairos.util.These
30 import com.android.systemui.kairos.util.flatMap
31 import com.android.systemui.kairos.util.getMaybe
32 import com.android.systemui.kairos.util.just
33 import com.android.systemui.kairos.util.maybeThat
34 import com.android.systemui.kairos.util.maybeThis
35 import com.android.systemui.kairos.util.merge
36 import com.android.systemui.kairos.util.orElseGet
37 import com.android.systemui.kairos.util.partitionEithers
38 import com.android.systemui.kairos.util.these
39 import java.util.TreeMap
40 import kotlinx.coroutines.coroutineScope
41 import kotlinx.coroutines.launch
42 import kotlinx.coroutines.sync.withLock
43 
/**
 * Mux node backing [switchDeferredImpl]: multiplexes a keyed set of switched-in upstream flows into
 * a single `Map<K, V>` emission, where the switched-in set is changed by a `patches` flow.
 *
 * Patches are not applied when they arrive. [scheduleMover] captures the patch into [patchData]
 * and schedules this node as a mux mover; [performMove] later applies it during the MOVE phase.
 */
internal class MuxDeferredNode<K : Any, V>(
    lifecycle: MuxLifecycle<Map<K, V>>,
    val spec: MuxActivator<Map<K, V>>,
) : MuxNode<K, V, Map<K, V>>(lifecycle), Key<Map<K, V>> {

    /** Handle registered as the downstream of the [patches] connection. */
    val schedulable = Schedulable.M(this)

    /** Connection to the upstream patches flow, or `null` when not connected. */
    @Volatile var patches: NodeConnection<Map<K, Maybe<TFlowImpl<V>>>>? = null

    /**
     * Patch captured in the current transaction and consumed (then cleared) by [performMove]. Per
     * key, [Just] adds or replaces the switched-in upstream and [None] removes it.
     */
    @Volatile var patchData: Map<K, Maybe<TFlowImpl<V>>>? = null

    /**
     * Whether a result for this node is recorded in [transactionStore]. Presumably expects [mutex]
     * to be held by the caller (see [hasCurrentValue]).
     */
    override fun hasCurrentValueLocked(transactionStore: TransactionStore): Boolean =
        transactionStore.contains(this)

    /** Locked variant of [hasCurrentValueLocked]. */
    override suspend fun hasCurrentValue(transactionStore: TransactionStore): Boolean =
        mutex.withLock { hasCurrentValueLocked(transactionStore) }

    /**
     * Drains the values collected from branches this transaction and publishes them as this node's
     * result, scheduling downstream nodes; if [scheduleAll] reports `false`, this node is scheduled
     * for deactivation instead. Pending depth changes are applied first when the tracker is dirty.
     */
    override suspend fun visit(evalScope: EvalScope) {
        val result = upstreamData.toMap()
        upstreamData.clear()
        val scheduleDownstream = result.isNotEmpty()
        val compactDownstream = depthTracker.isDirty()
        if (scheduleDownstream || compactDownstream) {
            coroutineScope {
                mutex.withLock {
                    if (compactDownstream) {
                        depthTracker.applyChanges(
                            coroutineScope = this,
                            evalScope.scheduler,
                            downstreamSet,
                            muxNode = this@MuxDeferredNode,
                        )
                    }
                    if (scheduleDownstream) {
                        evalScope.setResult(this@MuxDeferredNode, result)
                        if (!scheduleAll(downstreamSet, evalScope)) {
                            evalScope.scheduleDeactivation(this@MuxDeferredNode)
                        }
                    }
                }
            }
        }
    }

    /** Returns the result recorded for this node in the current transaction, if any. */
    override suspend fun getPushEvent(evalScope: EvalScope): Maybe<Map<K, V>> =
        evalScope.getCurrentValue(key = this)

    /** Propagates accumulated depth changes through the evaluation's compactor. */
    private suspend fun compactIfNeeded(evalScope: EvalScope) {
        depthTracker.propagateChanges(evalScope.compactor, this)
    }

    /**
     * Marks the lifecycle inactive (a no-op unless it is currently Active), then disconnects every
     * switched-in branch and the patches connection, deactivating those upstreams if needed.
     */
    override suspend fun doDeactivate() {
        // Update lifecycle
        lifecycle.mutex.withLock {
            if (lifecycle.lifecycleState !is MuxLifecycleState.Active) return@doDeactivate
            lifecycle.lifecycleState = MuxLifecycleState.Inactive(spec)
        }
        // Process branch nodes
        coroutineScope {
            switchedIn.values.forEach { branchNode ->
                branchNode.upstream.let {
                    launch { it.removeDownstreamAndDeactivateIfNeeded(branchNode.schedulable) }
                }
            }
        }
        // Process patch node
        patches?.removeDownstreamAndDeactivateIfNeeded(schedulable)
    }

    /**
     * MOVE phase: applies the pending [patchData], if any.
     *
     * Concurrent moves may be occurring, but no more evaluations. All depth recalculations are
     * deferred to the end of this phase.
     *
     * Removed and replaced branches are disconnected and their depth contribution is withdrawn.
     * New branches are activated in parallel and switched in. Severed upstream connections are then
     * given a chance to deactivate, and depth changes are propagated.
     */
    suspend fun performMove(evalScope: EvalScope) {
        val patch = patchData ?: return
        patchData = null

        // TODO: this logic is very similar to what's in MuxPromptMoving, maybe turn into an inline
        //  fun?

        // We have a patch, process additions/updates and removals
        val (adds, removes) =
            patch
                .asSequence()
                .map { (k, newUpstream: Maybe<TFlowImpl<V>>) ->
                    when (newUpstream) {
                        is Just -> Left(k to newUpstream.value)
                        None -> Right(k)
                    }
                }
                .partitionEithers()

        val severed = mutableListOf<NodeConnection<*>>()

        coroutineScope {
            // remove and sever
            removes.forEach { k ->
                switchedIn.remove(k)?.let { branchNode: MuxBranchNode<K, V> ->
                    val conn = branchNode.upstream
                    severed.add(conn)
                    launch { conn.removeDownstream(downstream = branchNode.schedulable) }
                    removeUpstreamDepth(conn)
                }
            }

            // add or replace
            adds
                .mapParallel { (k, newUpstream: TFlowImpl<V>) ->
                    val branchNode = MuxBranchNode(this@MuxDeferredNode, k)
                    k to
                        newUpstream.activate(evalScope, branchNode.schedulable)?.let { (conn, _) ->
                            branchNode.apply { upstream = conn }
                        }
                }
                .forEach { (k, newBranch: MuxBranchNode<K, V>?) ->
                    // remove old and sever, if present
                    switchedIn.remove(k)?.let { branchNode ->
                        val conn = branchNode.upstream
                        severed.add(conn)
                        launch { conn.removeDownstream(downstream = branchNode.schedulable) }
                        removeUpstreamDepth(conn)
                    }

                    // add new; a null branch means the upstream could not be activated
                    newBranch?.let {
                        switchedIn[k] = newBranch
                        val branchDepthTracker = newBranch.upstream.depthTracker
                        if (branchDepthTracker.snapshotIsDirect) {
                            depthTracker.addDirectUpstream(
                                oldDepth = null,
                                newDepth = branchDepthTracker.snapshotDirectDepth,
                            )
                        } else {
                            depthTracker.addIndirectUpstream(
                                oldDepth = null,
                                newDepth = branchDepthTracker.snapshotIndirectDepth,
                            )
                            depthTracker.updateIndirectRoots(
                                additions = branchDepthTracker.snapshotIndirectRoots,
                                butNot = this@MuxDeferredNode,
                            )
                        }
                    }
                }
        }

        coroutineScope {
            for (severedNode in severed) {
                launch { severedNode.scheduleDeactivationIfNeeded(evalScope) }
            }
        }

        compactIfNeeded(evalScope)
    }

    /**
     * Withdraws the depth contribution of a severed branch connection. It mirrors how the branch
     * was added: a direct upstream is removed by its direct depth, and an indirect upstream by its
     * indirect depth together with its indirect roots.
     */
    private suspend fun removeUpstreamDepth(conn: NodeConnection<*>) {
        val upstreamDepth = conn.depthTracker
        if (upstreamDepth.snapshotIsDirect) {
            depthTracker.removeDirectUpstream(upstreamDepth.snapshotDirectDepth)
        } else {
            depthTracker.removeIndirectUpstream(upstreamDepth.snapshotIndirectDepth)
            depthTracker.updateIndirectRoots(removals = upstreamDepth.snapshotIndirectRoots)
        }
    }

    /**
     * Handles removal of a directly-connected patches upstream: clears indirect-root status and the
     * depth-0 indirect upstream, scheduling depth propagation if anything changed, then clears
     * [patches].
     *
     * NOTE(review): activation marks a direct patches connection only via
     * `setIsIndirectRoot(true)`; confirm a depth-0 indirect upstream is registered elsewhere.
     */
    suspend fun removeDirectPatchNode(scheduler: Scheduler) {
        mutex.withLock {
            if (
                depthTracker.removeIndirectUpstream(depth = 0) or
                    depthTracker.setIsIndirectRoot(false)
            ) {
                depthTracker.schedule(scheduler, this)
            }
            patches = null
        }
    }

    /**
     * Handles removal of an indirectly-connected patches upstream at [depth] whose indirect roots
     * were [indirectSet], scheduling depth propagation if anything changed, then clears [patches].
     */
    suspend fun removeIndirectPatchNode(
        scheduler: Scheduler,
        depth: Int,
        indirectSet: Set<MuxDeferredNode<*, *>>,
    ) {
        // indirectly connected patches forward the indirectSet
        mutex.withLock {
            if (
                depthTracker.updateIndirectRoots(removals = indirectSet) or
                    depthTracker.removeIndirectUpstream(depth)
            ) {
                depthTracker.schedule(scheduler, this)
            }
            patches = null
        }
    }

    /**
     * Patches upstream changed from indirect to direct. This drops the old indirect depth and
     * roots and makes this node an indirect root.
     */
    suspend fun moveIndirectPatchNodeToDirect(
        scheduler: Scheduler,
        oldIndirectDepth: Int,
        oldIndirectSet: Set<MuxDeferredNode<*, *>>,
    ) {
        // directly connected patches are stored as an indirect singleton set of the patchNode
        mutex.withLock {
            if (
                depthTracker.updateIndirectRoots(removals = oldIndirectSet) or
                    depthTracker.removeIndirectUpstream(oldIndirectDepth) or
                    depthTracker.setIsIndirectRoot(true)
            ) {
                depthTracker.schedule(scheduler, this)
            }
        }
    }

    /**
     * Patches upstream changed from direct to indirect. This node stops being an indirect root and
     * adopts the new indirect depth and roots, excluding itself.
     */
    suspend fun moveDirectPatchNodeToIndirect(
        scheduler: Scheduler,
        newIndirectDepth: Int,
        newIndirectSet: Set<MuxDeferredNode<*, *>>,
    ) {
        // indirectly connected patches forward the indirectSet
        mutex.withLock {
            if (
                depthTracker.setIsIndirectRoot(false) or
                    depthTracker.updateIndirectRoots(additions = newIndirectSet, butNot = this) or
                    depthTracker.addIndirectUpstream(oldDepth = null, newDepth = newIndirectDepth)
            ) {
                depthTracker.schedule(scheduler, this)
            }
        }
    }

    /**
     * An indirect patches upstream changed depth and/or roots. This applies the delta and schedules
     * depth propagation if anything changed.
     */
    suspend fun adjustIndirectPatchNode(
        scheduler: Scheduler,
        oldDepth: Int,
        newDepth: Int,
        removals: Set<MuxDeferredNode<*, *>>,
        additions: Set<MuxDeferredNode<*, *>>,
    ) {
        // indirectly connected patches forward the indirectSet
        mutex.withLock {
            if (
                depthTracker.updateIndirectRoots(
                    additions = additions,
                    removals = removals,
                    butNot = this,
                ) or depthTracker.addIndirectUpstream(oldDepth = oldDepth, newDepth = newDepth)
            ) {
                depthTracker.schedule(scheduler, this)
            }
        }
    }

    /**
     * Captures the current patches emission (or `null` if there is none) into [patchData] and
     * schedules this node as a mux mover.
     *
     * @throws IllegalStateException if [patches] is not connected.
     */
    suspend fun scheduleMover(evalScope: EvalScope) {
        patchData =
            checkNotNull(patches) { "mux mover scheduled with unset patches upstream node" }
                .getPushEvent(evalScope)
                .orElseGet { null }
        evalScope.scheduleMuxMover(this)
    }

    override fun toString(): String = "${this::class.simpleName}@$hashString"
}
291 
/**
 * Single-flow variant of [switchDeferredImpl]. The initial flow comes from [getStorage], and each
 * emission of [getPatches] swaps in a new flow.
 *
 * Internally the flow is stored under the constant key [Unit]. Every patch is therefore an
 * add/replace of that one key and never a removal.
 */
internal inline fun <A> switchDeferredImplSingle(
    crossinline getStorage: suspend EvalScope.() -> TFlowImpl<A>,
    crossinline getPatches: suspend EvalScope.() -> TFlowImpl<TFlowImpl<A>>,
): TFlowImpl<A> =
    mapImpl(
        {
            switchDeferredImpl(
                getStorage = { mapOf(Unit to getStorage()) },
                getPatches = {
                    mapImpl(getPatches) { replacement -> mapOf(Unit to just(replacement)) }
                },
            )
        }
    ) {
        // The keyed result only ever carries the single `Unit` entry.
        it.getValue(Unit)
    }
304 
/**
 * Builds a deferred-switching mux over a keyed set of flows.
 *
 * On activation, every flow returned by [getStorage] is connected through a [MuxBranchNode]. The
 * resulting flow emits a `Map<K, A>` holding the values of the branches that emitted in the
 * current transaction.
 *
 * Each emission of the [getPatches] flow maps keys to [Just] (add/replace that branch) or [None]
 * (remove it). Patches are applied by [MuxDeferredNode.performMove] at the end of the transaction
 * rather than immediately.
 *
 * @param getStorage supplies the initial keyed set of upstream flows.
 * @param getPatches supplies the flow of patches to the switched-in set; it is connected lazily
 *   (deferred) so that the mux may be downstream of itself via patches.
 */
internal fun <K : Any, A> switchDeferredImpl(
    getStorage: suspend EvalScope.() -> Map<K, TFlowImpl<A>>,
    getPatches: suspend EvalScope.() -> TFlowImpl<Map<K, Maybe<TFlowImpl<A>>>>,
): TFlowImpl<Map<K, A>> =
    MuxLifecycle(
        object : MuxActivator<Map<K, A>> {
            override suspend fun activate(
                evalScope: EvalScope,
                lifecycle: MuxLifecycle<Map<K, A>>,
            ): MuxNode<*, *, Map<K, A>>? {
                val storage: Map<K, TFlowImpl<A>> = getStorage(evalScope)
                // Initialize mux node and switched-in connections. Flows whose activation returns
                // null are presumably skipped by mapValuesNotNullParallelTo.
                val muxNode =
                    MuxDeferredNode(lifecycle, this).apply {
                        storage.mapValuesNotNullParallelTo(switchedIn) { (key, flow) ->
                            val branchNode = MuxBranchNode(this@apply, key)
                            flow.activate(evalScope, branchNode.schedulable)?.let {
                                (conn, needsEval) ->
                                branchNode
                                    .apply { upstream = conn }
                                    .also {
                                        // The upstream may already have emitted this transaction;
                                        // capture its value so this node is scheduled below.
                                        if (needsEval) {
                                            val result = conn.getPushEvent(evalScope)
                                            if (result is Just) {
                                                upstreamData[key] = result.value
                                            }
                                        }
                                    }
                            }
                        }
                    }
                // Update depth based on all initial switched-in nodes.
                muxNode.switchedIn.values.forEach { branch ->
                    val conn = branch.upstream
                    if (conn.depthTracker.snapshotIsDirect) {
                        muxNode.depthTracker.addDirectUpstream(
                            oldDepth = null,
                            newDepth = conn.depthTracker.snapshotDirectDepth,
                        )
                    } else {
                        muxNode.depthTracker.addIndirectUpstream(
                            oldDepth = null,
                            newDepth = conn.depthTracker.snapshotIndirectDepth,
                        )
                        muxNode.depthTracker.updateIndirectRoots(
                            additions = conn.depthTracker.snapshotIndirectRoots,
                            butNot = muxNode,
                        )
                    }
                }
                // We don't have our patches connection established yet, so for now pretend we have
                // a direct connection to patches. We will update downstream nodes later if this
                // turns out to be a lie.
                muxNode.depthTracker.setIsIndirectRoot(true)
                muxNode.depthTracker.reset()

                // Setup patches connection; deferring allows for a recursive connection, where
                // muxNode is downstream of itself via patches.
                // NOTE(review): this flag is only cleared inside the deferred action; if that action
                // runs after `activate` returns, the check at the end always sees `true`.
                var isIndirect = true
                evalScope.deferAction {
                    val (patchesConn, needsEval) =
                        getPatches(evalScope).activate(evalScope, downstream = muxNode.schedulable)
                            ?: run {
                                isIndirect = false
                                // Turns out we can't connect to patches, so update our depth and
                                // propagate.
                                muxNode.mutex.withLock {
                                    if (muxNode.depthTracker.setIsIndirectRoot(false)) {
                                        muxNode.depthTracker.schedule(evalScope.scheduler, muxNode)
                                    }
                                }
                                return@deferAction
                            }
                    muxNode.patches = patchesConn

                    if (!patchesConn.schedulerUpstream.depthTracker.snapshotIsDirect) {
                        // Turns out patches is indirect, so we are not a root. Update depth and
                        // propagate. As in MuxDeferredNode.moveDirectPatchNodeToIndirect, exclude
                        // muxNode itself from the adopted roots: with a recursive connection the
                        // patches' roots may contain muxNode.
                        muxNode.mutex.withLock {
                            if (
                                muxNode.depthTracker.setIsIndirectRoot(false) or
                                    muxNode.depthTracker.addIndirectUpstream(
                                        oldDepth = null,
                                        newDepth = patchesConn.depthTracker.snapshotIndirectDepth,
                                    ) or
                                    muxNode.depthTracker.updateIndirectRoots(
                                        additions = patchesConn.depthTracker.snapshotIndirectRoots,
                                        butNot = muxNode,
                                    )
                            ) {
                                muxNode.depthTracker.schedule(evalScope.scheduler, muxNode)
                            }
                        }
                    }
                    // Schedule mover to process patch emission at the end of this transaction, if
                    // needed.
                    if (needsEval) {
                        val result = patchesConn.getPushEvent(evalScope)
                        if (result is Just) {
                            muxNode.patchData = result.value
                            evalScope.scheduleMuxMover(muxNode)
                        }
                    }
                }

                // Schedule for evaluation if any switched-in nodes have already emitted within
                // this transaction.
                if (muxNode.upstreamData.isNotEmpty()) {
                    evalScope.schedule(muxNode)
                }
                return muxNode.takeUnless { muxNode.switchedIn.isEmpty() && !isIndirect }
            }
        }
    )
418 
/**
 * Merges two flows of the same type into one that emits whenever either input emits.
 *
 * When both inputs emit in the same transaction, their values are combined with [f]. The value
 * from [getPulse] is passed first and the value from [getOther] second. A lone value is presumably
 * passed through unchanged by `These.merge`. The result is cached.
 */
internal inline fun <A> mergeNodes(
    crossinline getPulse: suspend EvalScope.() -> TFlowImpl<A>,
    crossinline getOther: suspend EvalScope.() -> TFlowImpl<A>,
    crossinline f: suspend EvalScope.(A, A) -> A,
): TFlowImpl<A> =
    mapImpl({ mergeNodes(getPulse, getOther) }) { both ->
            both.merge { fromPulse, fromOther -> f(fromPulse, fromOther) }
        }
        .cached()
430 
/**
 * Merges two flows into a flow of [These]. Each emission carries the "this" side when [getPulse]
 * emitted, the "that" side when [getOther] emitted, or both. The result is cached.
 */
internal inline fun <A, B> mergeNodes(
    crossinline getPulse: suspend EvalScope.() -> TFlowImpl<A>,
    crossinline getOther: suspend EvalScope.() -> TFlowImpl<B>,
): TFlowImpl<These<A, B>> {
    // Key 0 holds the "this" side and key 1 the "that" side.
    val branches =
        mapOf(
            0 to mapImpl(getPulse) { These.thiz<A, B>(it) },
            1 to mapImpl(getOther) { These.that(it) },
        )
    // `neverImpl` patches: the branch set presumably never changes after activation.
    val mux = switchDeferredImpl(getStorage = { branches }, getPatches = { neverImpl })
    return mapImpl({ mux }) { emitted ->
            val left = emitted.getMaybe(0).flatMap { it.maybeThis() }
            val right = emitted.getMaybe(1).flatMap { it.maybeThat() }
            these(left, right).orElseGet { error("unexpected missing merge result") }
        }
        .cached()
}
449 
/**
 * Merges any number of flows into a flow of lists. Each emission holds the values of the inputs
 * that emitted in the transaction. The result is cached.
 *
 * Inputs are keyed by their position in [getPulses] using a [TreeMap]. NOTE(review): the list order
 * follows the iteration order of the mux result map, which presumably matches index order; confirm
 * against `MuxNode.upstreamData`.
 */
internal inline fun <A> mergeNodes(
    crossinline getPulses: suspend EvalScope.() -> Iterable<TFlowImpl<A>>
): TFlowImpl<List<A>> {
    val mux =
        switchDeferredImpl(
            getStorage = { getPulses().associateByIndexTo(TreeMap()) },
            getPatches = { neverImpl },
        )
    return mapImpl({ mux }) { emitted -> emitted.values.toList() }.cached()
}
461 
/**
 * Merges any number of flows, keeping only one value per transaction: the first entry of the mux
 * result map. The result is cached.
 *
 * Inputs are keyed by position via a [TreeMap]. NOTE(review): "left-most wins" relies on the mux
 * result map iterating in index order; confirm against `MuxNode.upstreamData`.
 */
internal inline fun <A> mergeNodesLeft(
    crossinline getPulses: suspend EvalScope.() -> Iterable<TFlowImpl<A>>
): TFlowImpl<A> {
    val indexedMux =
        switchDeferredImpl(
            getStorage = { getPulses().associateByIndexTo(TreeMap()) },
            getPatches = { neverImpl },
        )
    return mapImpl({ indexedMux }) { emitted: Map<Int, A> -> emitted.values.first() }.cached()
}
474