xref: /aosp_15_r20/frameworks/base/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/MuxPrompt.kt (revision d57664e9bc4670b3ecf6748a746a57c557b6bc9e)
/*
 * Copyright (C) 2024 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 package com.android.systemui.kairos.internal
18 
19 import com.android.systemui.kairos.internal.util.Key
20 import com.android.systemui.kairos.internal.util.launchImmediate
21 import com.android.systemui.kairos.internal.util.mapParallel
22 import com.android.systemui.kairos.internal.util.mapValuesNotNullParallelTo
23 import com.android.systemui.kairos.util.Just
24 import com.android.systemui.kairos.util.Left
25 import com.android.systemui.kairos.util.Maybe
26 import com.android.systemui.kairos.util.None
27 import com.android.systemui.kairos.util.Right
28 import com.android.systemui.kairos.util.filterJust
29 import com.android.systemui.kairos.util.map
30 import com.android.systemui.kairos.util.partitionEithers
31 import kotlinx.coroutines.CoroutineScope
32 import kotlinx.coroutines.async
33 import kotlinx.coroutines.awaitAll
34 import kotlinx.coroutines.coroutineScope
35 import kotlinx.coroutines.launch
36 import kotlinx.coroutines.sync.withLock
37 
/**
 * Mux node whose set of switched-in branches can be changed by a patch within the same evaluation
 * in which the patch arrives.
 *
 * The value this node stores for a transaction is a pair:
 * - first: the values collected from the already switched-in branches (`upstreamData`) at the
 *   time of the visit;
 * - second: the pull nodes of the branches that were added or replaced by the patch processed in
 *   that visit, or `null` when no branch was added.
 *
 * [MuxPromptEvalNode] turns that pair into the final `Map<K, V>`.
 */
internal class MuxPromptMovingNode<K : Any, V>(
    lifecycle: MuxLifecycle<Pair<Map<K, V>, Map<K, PullNode<V>>?>>,
    private val spec: MuxActivator<Pair<Map<K, V>, Map<K, PullNode<V>>?>>,
) :
    MuxNode<K, V, Pair<Map<K, V>, Map<K, PullNode<V>>?>>(lifecycle),
    Key<Pair<Map<K, V>, Map<K, PullNode<V>>?>> {

    /**
     * Pending patch: `Just(flow)` switches a key to `flow`, `None` removes the key. Written by
     * [MuxPromptPatchNode] and during activation; taken and cleared by [visit].
     */
    @Volatile var patchData: Map<K, Maybe<TFlowImpl<V>>>? = null

    /** The node attached to the patches upstream, or `null` once it has been removed. */
    @Volatile var patches: MuxPromptPatchNode<K, V>? = null

    /**
     * Result computed by a previous [visit] that could not be published yet because the node had
     * to be rescheduled; the next [visit] publishes it instead of evaluating again.
     */
    @Volatile private var reEval: Pair<Map<K, V>, Map<K, PullNode<V>>?>? = null

    /** Whether [transactionStore] already holds a value under this node's key. */
    override fun hasCurrentValueLocked(transactionStore: TransactionStore): Boolean =
        transactionStore.contains(this)

    /** Same as [hasCurrentValueLocked], but acquires `mutex` first. */
    override suspend fun hasCurrentValue(transactionStore: TransactionStore): Boolean =
        mutex.withLock { hasCurrentValueLocked(transactionStore) }

    /**
     * Evaluates this node: drains `upstreamData` and [patchData], applies the patch (if any), and
     * either publishes the result downstream or reschedules itself when that is not yet possible.
     */
    override suspend fun visit(evalScope: EvalScope) {
        // Take everything that accumulated since the previous visit.
        val preSwitchResults: Map<K, V> = upstreamData.toMap()
        upstreamData.clear()

        val patch: Map<K, Maybe<TFlowImpl<V>>>? = patchData
        patchData = null

        // A result parked by an earlier visit wins; otherwise evaluate only if there is input.
        val (reschedule, evalResult) =
            reEval?.let { false to it }
                ?: if (preSwitchResults.isNotEmpty() || patch?.isNotEmpty() == true) {
                    doEval(preSwitchResults, patch, evalScope)
                } else {
                    false to null
                }
        reEval = null

        if (reschedule || depthTracker.dirty_depthIncreased()) {
            // Park the result; it is published by the next visit.
            reEval = evalResult
            // Can't schedule downstream yet, need to compact first
            if (depthTracker.dirty_depthIncreased()) {
                depthTracker.schedule(evalScope.compactor, node = this)
            }
            evalScope.schedule(this)
        } else {
            val compactDownstream = depthTracker.isDirty()
            if (evalResult != null || compactDownstream) {
                coroutineScope {
                    mutex.withLock {
                        if (compactDownstream) {
                            adjustDownstreamDepths(evalScope, coroutineScope = this)
                        }
                        if (evalResult != null) {
                            evalScope.setResult(this@MuxPromptMovingNode, evalResult)
                            // Nothing downstream was scheduled: request deactivation.
                            if (!scheduleAll(downstreamSet, evalScope)) {
                                evalScope.scheduleDeactivation(this@MuxPromptMovingNode)
                            }
                        }
                    }
                }
            }
        }
    }

    /**
     * Applies [patch] to `switchedIn` and builds the evaluation result.
     *
     * @param preSwitchResults values already emitted by the branches that were switched in
     *   before this patch.
     * @param patch additions/replacements (`Just`) and removals (`None`), or `null`.
     * @return a pair of (whether the node must be rescheduled, the result or `null`). Rescheduling
     *   is requested exactly when the patch switched in at least one new branch.
     */
    private suspend fun doEval(
        preSwitchResults: Map<K, V>,
        patch: Map<K, Maybe<TFlowImpl<V>>>?,
        evalScope: EvalScope,
    ): Pair<Boolean, Pair<Map<K, V>, Map<K, PullNode<V>>?>?> {
        val newlySwitchedIn: Map<K, PullNode<V>>? =
            patch?.let {
                // We have a patch, process additions/updates and removals
                val (adds, removes) =
                    patch
                        .asSequence()
                        .map { (k, newUpstream: Maybe<TFlowImpl<V>>) ->
                            when (newUpstream) {
                                is Just -> Left(k to newUpstream.value)
                                None -> Right(k)
                            }
                        }
                        .partitionEithers()

                val additionsAndUpdates = mutableMapOf<K, PullNode<V>>()
                val severed = mutableListOf<NodeConnection<*>>()

                coroutineScope {
                    // remove and sever
                    removes.forEach { k ->
                        switchedIn.remove(k)?.let { branchNode: MuxBranchNode<K, V> ->
                            severBranch(branchNode, severed)
                        }
                    }

                    // add or replace
                    adds
                        .mapParallel { (k, newUpstream: TFlowImpl<V>) ->
                            val branchNode = MuxBranchNode(this@MuxPromptMovingNode, k)
                            k to
                                newUpstream.activate(evalScope, branchNode.schedulable)?.let {
                                    (conn, _) ->
                                    branchNode.apply { upstream = conn }
                                }
                        }
                        .forEach { (k, newBranch: MuxBranchNode<K, V>?) ->
                            // remove old and sever, if present
                            switchedIn.remove(k)?.let { oldBranch: MuxBranchNode<K, V> ->
                                severBranch(oldBranch, severed)
                            }

                            // add new
                            newBranch?.let {
                                switchedIn[k] = newBranch
                                additionsAndUpdates[k] = newBranch.upstream.directUpstream
                                val branchDepthTracker = newBranch.upstream.depthTracker
                                if (branchDepthTracker.snapshotIsDirect) {
                                    depthTracker.addDirectUpstream(
                                        oldDepth = null,
                                        newDepth = branchDepthTracker.snapshotDirectDepth,
                                    )
                                } else {
                                    depthTracker.addIndirectUpstream(
                                        oldDepth = null,
                                        newDepth = branchDepthTracker.snapshotIndirectDepth,
                                    )
                                    depthTracker.updateIndirectRoots(
                                        additions = branchDepthTracker.snapshotIndirectRoots,
                                        butNot = null,
                                    )
                                }
                            }
                        }
                }

                // All removals have completed; let the severed upstreams deactivate if needed.
                coroutineScope {
                    for (severedNode in severed) {
                        launch { severedNode.scheduleDeactivationIfNeeded(evalScope) }
                    }
                }

                additionsAndUpdates.takeIf { it.isNotEmpty() }
            }

        return if (preSwitchResults.isNotEmpty() || newlySwitchedIn != null) {
            (newlySwitchedIn != null) to (preSwitchResults to newlySwitchedIn)
        } else {
            false to null
        }
    }

    /**
     * Disconnects [branchNode] (already taken out of `switchedIn`) from its upstream: records the
     * connection in [severed], starts removal of the branch from the upstream's downstream set,
     * and withdraws the upstream's contribution from this node's depth tracker.
     *
     * The depth bookkeeping mirrors how branches are registered: a direct upstream is removed by
     * its direct depth, an indirect one by its indirect depth plus its indirect roots.
     */
    private suspend fun CoroutineScope.severBranch(
        branchNode: MuxBranchNode<K, V>,
        severed: MutableList<NodeConnection<*>>,
    ) {
        val conn: NodeConnection<V> = branchNode.upstream
        severed.add(conn)
        launchImmediate { conn.removeDownstream(downstream = branchNode.schedulable) }
        val upstreamTracker = conn.depthTracker
        if (upstreamTracker.snapshotIsDirect) {
            depthTracker.removeDirectUpstream(upstreamTracker.snapshotDirectDepth)
        } else {
            depthTracker.removeIndirectUpstream(upstreamTracker.snapshotIndirectDepth)
            depthTracker.updateIndirectRoots(removals = upstreamTracker.snapshotIndirectRoots)
        }
    }

    /**
     * Propagates pending depth changes of this node to its downstream nodes, choosing the
     * scheduler according to whether the depth increased.
     */
    private suspend fun adjustDownstreamDepths(
        evalScope: EvalScope,
        coroutineScope: CoroutineScope,
    ) {
        if (depthTracker.dirty_depthIncreased()) {
            // schedule downstream nodes on the compaction scheduler; this scheduler is drained at
            // the end of this eval depth, so that all depth increases are applied before we advance
            // the eval step
            depthTracker.schedule(evalScope.compactor, node = this@MuxPromptMovingNode)
        } else if (depthTracker.isDirty()) {
            // schedule downstream nodes on the eval scheduler; this is more efficient and is only
            // safe if the depth hasn't increased
            depthTracker.applyChanges(
                coroutineScope,
                evalScope.scheduler,
                downstreamSet,
                muxNode = this@MuxPromptMovingNode,
            )
        }
    }

    /** Looks up the value stored for this node in [evalScope], using this node as the key. */
    override suspend fun getPushEvent(
        evalScope: EvalScope
    ): Maybe<Pair<Map<K, V>, Map<K, PullNode<V>>?>> = evalScope.getCurrentValue(key = this)

    /**
     * Marks the lifecycle inactive (a no-op if it is not currently active) and detaches this node
     * from every switched-in branch upstream and from the patches upstream.
     */
    override suspend fun doDeactivate() {
        // Update lifecycle
        lifecycle.mutex.withLock {
            if (lifecycle.lifecycleState !is MuxLifecycleState.Active) return@doDeactivate
            lifecycle.lifecycleState = MuxLifecycleState.Inactive(spec)
        }
        // Process branch nodes
        switchedIn.values.forEach { branchNode ->
            branchNode.upstream.removeDownstreamAndDeactivateIfNeeded(
                downstream = branchNode.schedulable
            )
        }
        // Process patch node
        patches?.let { patches ->
            patches.upstream.removeDownstreamAndDeactivateIfNeeded(downstream = patches.schedulable)
        }
    }

    /**
     * Drops the patch node when its upstream was an indirect one, removing [oldDepth] and
     * [indirectSet] from the depth tracker and scheduling this node on [scheduler] if either
     * removal reports a change.
     */
    suspend fun removeIndirectPatchNode(
        scheduler: Scheduler,
        oldDepth: Int,
        indirectSet: Set<MuxDeferredNode<*, *>>,
    ) {
        mutex.withLock {
            patches = null
            // Non-short-circuiting `or`: both tracker updates must always run.
            if (
                depthTracker.removeIndirectUpstream(oldDepth) or
                    depthTracker.updateIndirectRoots(removals = indirectSet)
            ) {
                depthTracker.schedule(scheduler, this)
            }
        }
    }

    /**
     * Drops the patch node when its upstream was a direct one, removing [depth] from the depth
     * tracker and scheduling this node on [scheduler] if the removal reports a change.
     */
    suspend fun removeDirectPatchNode(scheduler: Scheduler, depth: Int) {
        mutex.withLock {
            patches = null
            if (depthTracker.removeDirectUpstream(depth)) {
                depthTracker.schedule(scheduler, this)
            }
        }
    }
}
265 
/**
 * Pull node that flattens the pair produced by the moving node into a single map.
 *
 * The first component (values emitted before the switch) is used as the base. When the second
 * component is present, every newly switched-in node is queried concurrently for a push event in
 * the same [EvalScope], and each `Just` result is written on top of a copy of the base map.
 */
internal class MuxPromptEvalNode<K, V>(
    private val movingNode: PullNode<Pair<Map<K, V>, Map<K, PullNode<V>>?>>
) : PullNode<Map<K, V>> {
    override suspend fun getPushEvent(evalScope: EvalScope): Maybe<Map<K, V>> =
        movingNode.getPushEvent(evalScope).map { (emittedBeforeSwitch, switchedInNow) ->
            coroutineScope {
                if (switchedInNow == null) {
                    // Nothing was switched in: the pre-switch values are the whole result.
                    emittedBeforeSwitch
                } else {
                    val pending =
                        switchedInNow.map { (key, node) ->
                            async { node.getPushEvent(evalScope).map { value -> key to value } }
                        }
                    pending
                        .awaitAll()
                        .asSequence()
                        .filterJust()
                        .toMap(emittedBeforeSwitch.toMutableMap())
                }
            }
        }
}
281 
282 // TODO: inner class?
/**
 * Schedulable node that sits between the patches upstream and a [MuxPromptMovingNode].
 *
 * When scheduled, it hands the upstream's patch (if one was emitted) to the mux node and
 * schedules the mux node. Depth adjustments are forwarded to the mux node unchanged; upstream
 * removals are routed to the mux node's patch-node-specific removal functions.
 */
internal class MuxPromptPatchNode<K : Any, V>(private val muxNode: MuxPromptMovingNode<K, V>) :
    SchedulableNode {

    val schedulable = Schedulable.N(this)

    /** Connection to the patches upstream; assigned by whoever activates that upstream. */
    lateinit var upstream: NodeConnection<Map<K, Maybe<TFlowImpl<V>>>>

    override suspend fun schedule(evalScope: EvalScope) {
        when (val emitted = upstream.getPushEvent(evalScope)) {
            is Just -> {
                muxNode.patchData = emitted.value
                evalScope.schedule(muxNode)
            }
            None -> Unit
        }
    }

    override suspend fun adjustDirectUpstream(scheduler: Scheduler, oldDepth: Int, newDepth: Int) {
        muxNode.adjustDirectUpstream(scheduler, oldDepth, newDepth)
    }

    override suspend fun adjustIndirectUpstream(
        scheduler: Scheduler,
        oldDepth: Int,
        newDepth: Int,
        removals: Set<MuxDeferredNode<*, *>>,
        additions: Set<MuxDeferredNode<*, *>>,
    ) {
        muxNode.adjustIndirectUpstream(scheduler, oldDepth, newDepth, removals, additions)
    }

    override suspend fun moveIndirectUpstreamToDirect(
        scheduler: Scheduler,
        oldIndirectDepth: Int,
        oldIndirectSet: Set<MuxDeferredNode<*, *>>,
        newDirectDepth: Int,
    ) {
        muxNode.moveIndirectUpstreamToDirect(
            scheduler, oldIndirectDepth, oldIndirectSet, newDirectDepth)
    }

    override suspend fun moveDirectUpstreamToIndirect(
        scheduler: Scheduler,
        oldDirectDepth: Int,
        newIndirectDepth: Int,
        newIndirectSet: Set<MuxDeferredNode<*, *>>,
    ) {
        muxNode.moveDirectUpstreamToIndirect(
            scheduler, oldDirectDepth, newIndirectDepth, newIndirectSet)
    }

    // Removal of the patches upstream also clears the mux node's reference to this node.
    override suspend fun removeDirectUpstream(scheduler: Scheduler, depth: Int) {
        muxNode.removeDirectPatchNode(scheduler, depth)
    }

    override suspend fun removeIndirectUpstream(
        scheduler: Scheduler,
        depth: Int,
        indirectSet: Set<MuxDeferredNode<*, *>>,
    ) {
        muxNode.removeIndirectPatchNode(scheduler, depth, indirectSet)
    }
}
352 
/**
 * Builds a flow of `Map<K, A>` backed by a [MuxPromptMovingNode].
 *
 * On activation the node is connected to every flow returned by [getStorage] and to the patches
 * flow returned by [getPatches]. Values or patches that those upstreams have already produced in
 * the current transaction are recorded on the node, and the node is scheduled for evaluation if
 * anything was recorded. Activation yields `null` when neither a branch nor the patches flow
 * could be connected.
 *
 * @param getStorage supplies the initial key-to-flow mapping.
 * @param getPatches supplies the flow of patches (`Just` = add/replace, `None` = remove).
 * @return the cached flow that evaluates the node's result through [MuxPromptEvalNode].
 */
internal fun <K : Any, A> switchPromptImpl(
    getStorage: suspend EvalScope.() -> Map<K, TFlowImpl<A>>,
    getPatches: suspend EvalScope.() -> TFlowImpl<Map<K, Maybe<TFlowImpl<A>>>>,
): TFlowImpl<Map<K, A>> {
    val movingLifecycle =
        MuxLifecycle(
            object : MuxActivator<Pair<Map<K, A>, Map<K, PullNode<A>>?>> {
                override suspend fun activate(
                    evalScope: EvalScope,
                    lifecycle: MuxLifecycle<Pair<Map<K, A>, Map<K, PullNode<A>>?>>,
                ): MuxNode<*, *, Pair<Map<K, A>, Map<K, PullNode<A>>?>>? {
                    val initialFlows: Map<K, TFlowImpl<A>> = getStorage(evalScope)
                    val muxNode = MuxPromptMovingNode(lifecycle, this)

                    coroutineScope {
                        // Connect the initial branches concurrently with the patches setup below.
                        launch {
                            initialFlows.mapValuesNotNullParallelTo(muxNode.switchedIn) {
                                (key, flow) ->
                                val branchNode = MuxBranchNode(muxNode, key)
                                val activation =
                                    flow.activate(
                                        evalScope = evalScope,
                                        downstream = branchNode.schedulable,
                                    )
                                if (activation == null) {
                                    null
                                } else {
                                    val (conn, needsEval) = activation
                                    branchNode.upstream = conn
                                    if (needsEval) {
                                        val result = conn.getPushEvent(evalScope)
                                        if (result is Just) {
                                            muxNode.upstreamData[key] = result.value
                                        }
                                    }
                                    branchNode
                                }
                            }
                        }

                        // Connect the patches flow.
                        val patchNode = MuxPromptPatchNode(muxNode)
                        val patchActivation =
                            getPatches(evalScope)
                                .activate(
                                    evalScope = evalScope,
                                    downstream = patchNode.schedulable,
                                )
                        if (patchActivation != null) {
                            val (conn, needsEval) = patchActivation
                            patchNode.upstream = conn
                            muxNode.patches = patchNode

                            if (needsEval) {
                                val result = conn.getPushEvent(evalScope)
                                if (result is Just) {
                                    muxNode.patchData = result.value
                                }
                            }
                        }
                    }

                    // Registers one upstream connection with the mux node's depth tracker, as a
                    // direct or an indirect upstream according to the connection's own snapshot.
                    suspend fun registerDepthOf(conn: NodeConnection<*>) {
                        val upstreamTracker = conn.depthTracker
                        val muxTracker = muxNode.depthTracker
                        if (upstreamTracker.snapshotIsDirect) {
                            muxTracker.addDirectUpstream(
                                oldDepth = null,
                                newDepth = upstreamTracker.snapshotDirectDepth,
                            )
                        } else {
                            muxTracker.addIndirectUpstream(
                                oldDepth = null,
                                newDepth = upstreamTracker.snapshotIndirectDepth,
                            )
                            muxTracker.updateIndirectRoots(
                                additions = upstreamTracker.snapshotIndirectRoots,
                                butNot = null,
                            )
                        }
                    }

                    // Depth contributions: first every initial branch, then the patches node.
                    for (branch in muxNode.switchedIn.values) {
                        registerDepthOf(branch.upstream)
                    }
                    val attachedPatchNode = muxNode.patches
                    if (attachedPatchNode != null) {
                        registerDepthOf(attachedPatchNode.upstream)
                    }
                    muxNode.depthTracker.reset()

                    // Evaluate in this transaction if any upstream has already emitted.
                    if (muxNode.patchData != null || muxNode.upstreamData.isNotEmpty()) {
                        evalScope.schedule(muxNode)
                    }

                    // With nothing connected there is no node to hand out.
                    return if (muxNode.patches == null && muxNode.switchedIn.isEmpty()) {
                        null
                    } else {
                        muxNode
                    }
                }
            }
        )

    val evalFlow = TFlowCheap { downstream ->
        val activation = movingLifecycle.activate(evalScope = this, downstream)
        activation?.let { (connection, needsEval) ->
            ActivationResult(
                connection =
                    NodeConnection(
                        MuxPromptEvalNode(connection.directUpstream),
                        connection.schedulerUpstream,
                    ),
                needsEval = needsEval,
            )
        }
    }
    return evalFlow.cached()
}
473