1 /*
<lambda>null2 * Copyright (C) 2024 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 package com.android.systemui.kairos.internal
18
19 import com.android.systemui.kairos.internal.util.Key
20 import com.android.systemui.kairos.internal.util.launchImmediate
21 import com.android.systemui.kairos.internal.util.mapParallel
22 import com.android.systemui.kairos.internal.util.mapValuesNotNullParallelTo
23 import com.android.systemui.kairos.util.Just
24 import com.android.systemui.kairos.util.Left
25 import com.android.systemui.kairos.util.Maybe
26 import com.android.systemui.kairos.util.None
27 import com.android.systemui.kairos.util.Right
28 import com.android.systemui.kairos.util.filterJust
29 import com.android.systemui.kairos.util.map
30 import com.android.systemui.kairos.util.partitionEithers
31 import kotlinx.coroutines.CoroutineScope
32 import kotlinx.coroutines.async
33 import kotlinx.coroutines.awaitAll
34 import kotlinx.coroutines.coroutineScope
35 import kotlinx.coroutines.launch
36 import kotlinx.coroutines.sync.withLock
37
/**
 * Moving (stateful) half of a prompt `switch` over a keyed collection of [TFlowImpl]s.
 *
 * Each evaluation produces a pair of:
 * - the values emitted during this transaction by branches that were already switched in, and
 * - the [PullNode]s of branches newly added or replaced by this transaction's patch, or `null` if
 *   the patch switched nothing in.
 *
 * The pair is stored in the transaction store under this node (it is its own [Key]) and read back
 * through [getPushEvent].
 */
internal class MuxPromptMovingNode<K : Any, V>(
    lifecycle: MuxLifecycle<Pair<Map<K, V>, Map<K, PullNode<V>>?>>,
    private val spec: MuxActivator<Pair<Map<K, V>, Map<K, PullNode<V>>?>>,
) :
    MuxNode<K, V, Pair<Map<K, V>, Map<K, PullNode<V>>?>>(lifecycle),
    Key<Pair<Map<K, V>, Map<K, PullNode<V>>?>> {

    /** Patch received from [patches] that has not been consumed yet; cleared by [visit]. */
    @Volatile var patchData: Map<K, Maybe<TFlowImpl<V>>>? = null

    /** Downstream endpoint of the patches connection, or `null` if none / it was removed. */
    @Volatile var patches: MuxPromptPatchNode<K, V>? = null

    /**
     * Result computed by a previous [visit] that had to reschedule itself (because of newly
     * switched-in branches or a depth increase). The next visit emits this instead of
     * re-evaluating.
     */
    @Volatile private var reEval: Pair<Map<K, V>, Map<K, PullNode<V>>?>? = null

    override fun hasCurrentValueLocked(transactionStore: TransactionStore): Boolean =
        transactionStore.contains(this)

    override suspend fun hasCurrentValue(transactionStore: TransactionStore): Boolean =
        mutex.withLock { hasCurrentValueLocked(transactionStore) }

    /**
     * Consumes the pending upstream values and patch, evaluates them, and either reschedules
     * itself or publishes the result and schedules downstream nodes.
     */
    override suspend fun visit(evalScope: EvalScope) {
        // Drain the inputs gathered since the last visit.
        val preSwitchResults: Map<K, V> = upstreamData.toMap()
        upstreamData.clear()

        val patch: Map<K, Maybe<TFlowImpl<V>>>? = patchData
        patchData = null

        // A result stashed by a previous (rescheduled) visit takes precedence over re-evaluation.
        // NOTE(review): inputs drained above are dropped in that case; presumably none can arrive
        // between the two visits. Confirm.
        val (reschedule, evalResult) =
            reEval?.let { false to it }
                ?: if (preSwitchResults.isNotEmpty() || patch?.isNotEmpty() == true) {
                    doEval(preSwitchResults, patch, evalScope)
                } else {
                    false to null
                }
        reEval = null

        if (reschedule || depthTracker.dirty_depthIncreased()) {
            reEval = evalResult
            // Can't schedule downstream yet, need to compact first
            if (depthTracker.dirty_depthIncreased()) {
                depthTracker.schedule(evalScope.compactor, node = this)
            }
            evalScope.schedule(this)
        } else {
            val compactDownstream = depthTracker.isDirty()
            if (evalResult != null || compactDownstream) {
                coroutineScope {
                    mutex.withLock {
                        if (compactDownstream) {
                            adjustDownstreamDepths(evalScope, coroutineScope = this)
                        }
                        if (evalResult != null) {
                            evalScope.setResult(this@MuxPromptMovingNode, evalResult)
                            // Nobody left downstream: schedule this node's deactivation.
                            if (!scheduleAll(downstreamSet, evalScope)) {
                                evalScope.scheduleDeactivation(this@MuxPromptMovingNode)
                            }
                        }
                    }
                }
            }
        }
    }

    /**
     * Applies [patch] (if any) to [switchedIn] and builds this transaction's result.
     *
     * @return a pair of (reschedule, result). `reschedule` is true when the patch switched in new
     *   branches. `result` is `null` when there is nothing to emit.
     */
    private suspend fun doEval(
        preSwitchResults: Map<K, V>,
        patch: Map<K, Maybe<TFlowImpl<V>>>?,
        evalScope: EvalScope,
    ): Pair<Boolean, Pair<Map<K, V>, Map<K, PullNode<V>>?>?> {
        val newlySwitchedIn: Map<K, PullNode<V>>? =
            patch?.let {
                // Split the patch into additions/updates (Just) and removals (None).
                val (adds, removes) =
                    patch
                        .asSequence()
                        .map { (k, newUpstream: Maybe<TFlowImpl<V>>) ->
                            when (newUpstream) {
                                is Just -> Left(k to newUpstream.value)
                                None -> Right(k)
                            }
                        }
                        .partitionEithers()

                val additionsAndUpdates = mutableMapOf<K, PullNode<V>>()
                val severed = mutableListOf<NodeConnection<*>>()

                coroutineScope {
                    // remove and sever
                    removes.forEach { k ->
                        switchedIn.remove(k)?.let { branchNode: MuxBranchNode<K, V> ->
                            val conn: NodeConnection<V> = branchNode.upstream
                            severed.add(conn)
                            launchImmediate {
                                conn.removeDownstream(downstream = branchNode.schedulable)
                            }
                            removeUpstreamDepth(conn)
                        }
                    }

                    // add or replace: activate all new upstreams in parallel, then install them
                    adds
                        .mapParallel { (k, newUpstream: TFlowImpl<V>) ->
                            val branchNode = MuxBranchNode(this@MuxPromptMovingNode, k)
                            k to
                                newUpstream.activate(evalScope, branchNode.schedulable)?.let {
                                    (conn, _) ->
                                    branchNode.apply { upstream = conn }
                                }
                        }
                        .forEach { (k, newBranch: MuxBranchNode<K, V>?) ->
                            // remove old and sever, if present
                            switchedIn.remove(k)?.let { oldBranch: MuxBranchNode<K, V> ->
                                val conn: NodeConnection<V> = oldBranch.upstream
                                severed.add(conn)
                                launchImmediate {
                                    conn.removeDownstream(downstream = oldBranch.schedulable)
                                }
                                removeUpstreamDepth(conn)
                            }

                            // add new (only if its upstream activated)
                            newBranch?.let {
                                switchedIn[k] = newBranch
                                additionsAndUpdates[k] = newBranch.upstream.directUpstream
                                addUpstreamDepth(newBranch.upstream)
                            }
                        }
                }

                // Let each severed connection schedule its deactivation if needed, concurrently.
                coroutineScope {
                    for (severedNode in severed) {
                        launch { severedNode.scheduleDeactivationIfNeeded(evalScope) }
                    }
                }

                additionsAndUpdates.takeIf { it.isNotEmpty() }
            }

        return if (preSwitchResults.isNotEmpty() || newlySwitchedIn != null) {
            (newlySwitchedIn != null) to (preSwitchResults to newlySwitchedIn)
        } else {
            false to null
        }
    }

    /**
     * Registers [conn]'s current depth with [depthTracker]. The connection is registered as a
     * direct or an indirect upstream, matching its depth snapshot.
     */
    private suspend fun addUpstreamDepth(conn: NodeConnection<V>) {
        val upstreamDepth = conn.depthTracker
        if (upstreamDepth.snapshotIsDirect) {
            depthTracker.addDirectUpstream(
                oldDepth = null,
                newDepth = upstreamDepth.snapshotDirectDepth,
            )
        } else {
            depthTracker.addIndirectUpstream(
                oldDepth = null,
                newDepth = upstreamDepth.snapshotIndirectDepth,
            )
            depthTracker.updateIndirectRoots(
                additions = upstreamDepth.snapshotIndirectRoots,
                butNot = null,
            )
        }
    }

    /**
     * Undoes [addUpstreamDepth] for a severed [conn].
     *
     * Indirect upstreams must be removed from the indirect depth and roots. Previously every
     * severed connection was removed as a direct upstream, which corrupted the depth
     * bookkeeping for indirect ones.
     */
    private suspend fun removeUpstreamDepth(conn: NodeConnection<V>) {
        val upstreamDepth = conn.depthTracker
        if (upstreamDepth.snapshotIsDirect) {
            depthTracker.removeDirectUpstream(upstreamDepth.snapshotDirectDepth)
        } else {
            depthTracker.removeIndirectUpstream(upstreamDepth.snapshotIndirectDepth)
            depthTracker.updateIndirectRoots(removals = upstreamDepth.snapshotIndirectRoots)
        }
    }

    private suspend fun adjustDownstreamDepths(
        evalScope: EvalScope,
        coroutineScope: CoroutineScope,
    ) {
        if (depthTracker.dirty_depthIncreased()) {
            // schedule downstream nodes on the compaction scheduler; this scheduler is drained at
            // the end of this eval depth, so that all depth increases are applied before we advance
            // the eval step
            depthTracker.schedule(evalScope.compactor, node = this@MuxPromptMovingNode)
        } else if (depthTracker.isDirty()) {
            // schedule downstream nodes on the eval scheduler; this is more efficient and is only
            // safe if the depth hasn't increased
            depthTracker.applyChanges(
                coroutineScope,
                evalScope.scheduler,
                downstreamSet,
                muxNode = this@MuxPromptMovingNode,
            )
        }
    }

    /** Returns the result stored for this node in the current transaction, if any. */
    override suspend fun getPushEvent(
        evalScope: EvalScope
    ): Maybe<Pair<Map<K, V>, Map<K, PullNode<V>>?>> = evalScope.getCurrentValue(key = this)

    /**
     * Marks the lifecycle inactive (if it is currently active). Then detaches this node from every
     * switched-in upstream and from the patches upstream.
     */
    override suspend fun doDeactivate() {
        // Update lifecycle
        lifecycle.mutex.withLock {
            if (lifecycle.lifecycleState !is MuxLifecycleState.Active) return@doDeactivate
            lifecycle.lifecycleState = MuxLifecycleState.Inactive(spec)
        }
        // Process branch nodes
        switchedIn.values.forEach { branchNode ->
            branchNode.upstream.removeDownstreamAndDeactivateIfNeeded(
                downstream = branchNode.schedulable
            )
        }
        // Process patch node
        patches?.let { patches ->
            patches.upstream.removeDownstreamAndDeactivateIfNeeded(downstream = patches.schedulable)
        }
    }

    /**
     * Drops the patches connection, which was an indirect upstream at [oldDepth] with roots
     * [indirectSet], and reschedules depth propagation if the depth changed.
     */
    suspend fun removeIndirectPatchNode(
        scheduler: Scheduler,
        oldDepth: Int,
        indirectSet: Set<MuxDeferredNode<*, *>>,
    ) {
        mutex.withLock {
            patches = null
            // Non-short-circuiting `or`: both updates must be applied.
            if (
                depthTracker.removeIndirectUpstream(oldDepth) or
                    depthTracker.updateIndirectRoots(removals = indirectSet)
            ) {
                depthTracker.schedule(scheduler, this)
            }
        }
    }

    /**
     * Drops the patches connection, which was a direct upstream at [depth], and reschedules depth
     * propagation if the depth changed.
     */
    suspend fun removeDirectPatchNode(scheduler: Scheduler, depth: Int) {
        mutex.withLock {
            patches = null
            if (depthTracker.removeDirectUpstream(depth)) {
                depthTracker.schedule(scheduler, this)
            }
        }
    }
}
265
/**
 * Pull node that turns the moving node's `(preSwitchResults, newlySwitchedIn)` pair into the map
 * that is finally emitted.
 *
 * If nothing was newly switched in, the pre-switch results are emitted unchanged. Otherwise every
 * newly switched-in node is pulled concurrently, and the values it produced are laid over the
 * pre-switch results. On a key collision, the newly switched-in value wins.
 */
internal class MuxPromptEvalNode<K, V>(
    private val movingNode: PullNode<Pair<Map<K, V>, Map<K, PullNode<V>>?>>
) : PullNode<Map<K, V>> {
    override suspend fun getPushEvent(evalScope: EvalScope): Maybe<Map<K, V>> {
        val movingResult = movingNode.getPushEvent(evalScope)
        return movingResult.map { (preSwitchResults, newlySwitchedIn) ->
            coroutineScope {
                if (newlySwitchedIn == null) {
                    preSwitchResults
                } else {
                    val pendingPulls =
                        newlySwitchedIn.map { (key, node) ->
                            async { node.getPushEvent(evalScope).map { value -> key to value } }
                        }
                    val merged = preSwitchResults.toMutableMap()
                    pendingPulls.awaitAll().asSequence().filterJust().toMap(merged)
                }
            }
        }
    }
}
281
// TODO: inner class?
/**
 * Downstream endpoint of a [MuxPromptMovingNode]'s patches connection.
 *
 * When the patches upstream emits, the patch is stored in [MuxPromptMovingNode.patchData] and the
 * mux node is scheduled. Depth notifications from the patches upstream are forwarded to the mux
 * node. Removals go through the mux node's patch-specific entry points.
 */
internal class MuxPromptPatchNode<K : Any, V>(private val muxNode: MuxPromptMovingNode<K, V>) :
    SchedulableNode {

    /** Handle that is registered as the downstream of [upstream]. */
    val schedulable = Schedulable.N(this)

    /** Connection to the patches flow; assigned once that flow has been activated. */
    lateinit var upstream: NodeConnection<Map<K, Maybe<TFlowImpl<V>>>>

    override suspend fun schedule(evalScope: EvalScope) {
        val emitted = upstream.getPushEvent(evalScope)
        if (emitted !is Just) return
        muxNode.patchData = emitted.value
        evalScope.schedule(muxNode)
    }

    override suspend fun adjustDirectUpstream(scheduler: Scheduler, oldDepth: Int, newDepth: Int) {
        muxNode.adjustDirectUpstream(scheduler, oldDepth, newDepth)
    }

    override suspend fun moveIndirectUpstreamToDirect(
        scheduler: Scheduler,
        oldIndirectDepth: Int,
        oldIndirectSet: Set<MuxDeferredNode<*, *>>,
        newDirectDepth: Int,
    ) {
        muxNode.moveIndirectUpstreamToDirect(scheduler, oldIndirectDepth, oldIndirectSet, newDirectDepth)
    }

    override suspend fun adjustIndirectUpstream(
        scheduler: Scheduler,
        oldDepth: Int,
        newDepth: Int,
        removals: Set<MuxDeferredNode<*, *>>,
        additions: Set<MuxDeferredNode<*, *>>,
    ) {
        muxNode.adjustIndirectUpstream(scheduler, oldDepth, newDepth, removals, additions)
    }

    override suspend fun moveDirectUpstreamToIndirect(
        scheduler: Scheduler,
        oldDirectDepth: Int,
        newIndirectDepth: Int,
        newIndirectSet: Set<MuxDeferredNode<*, *>>,
    ) {
        muxNode.moveDirectUpstreamToIndirect(scheduler, oldDirectDepth, newIndirectDepth, newIndirectSet)
    }

    /** The patches upstream (a direct upstream at [depth]) is going away. */
    override suspend fun removeDirectUpstream(scheduler: Scheduler, depth: Int) {
        muxNode.removeDirectPatchNode(scheduler, depth)
    }

    /** The patches upstream (an indirect upstream at [depth], with roots [indirectSet]) is going away. */
    override suspend fun removeIndirectUpstream(
        scheduler: Scheduler,
        depth: Int,
        indirectSet: Set<MuxDeferredNode<*, *>>,
    ) {
        muxNode.removeIndirectPatchNode(scheduler, depth, indirectSet)
    }
}
352
/**
 * Builds a prompt `switch` over a keyed collection of flows.
 *
 * @param getStorage produces the initial key-to-flow mapping; invoked each time the mux is
 *   activated.
 * @param getPatches produces the flow of patches. In a patch, a [Just] entry switches its key to
 *   a new flow (added or replaced), and a [None] entry removes the key.
 * @return a cached flow emitting, per transaction, the merged values of all switched-in flows,
 *   including flows switched in by that same transaction's patch.
 */
internal fun <K : Any, A> switchPromptImpl(
    getStorage: suspend EvalScope.() -> Map<K, TFlowImpl<A>>,
    getPatches: suspend EvalScope.() -> TFlowImpl<Map<K, Maybe<TFlowImpl<A>>>>,
): TFlowImpl<Map<K, A>> {
    val moving =
        MuxLifecycle(
            object : MuxActivator<Pair<Map<K, A>, Map<K, PullNode<A>>?>> {
                override suspend fun activate(
                    evalScope: EvalScope,
                    lifecycle: MuxLifecycle<Pair<Map<K, A>, Map<K, PullNode<A>>?>>,
                ): MuxNode<*, *, Pair<Map<K, A>, Map<K, PullNode<A>>?>>? {
                    val storage: Map<K, TFlowImpl<A>> = getStorage(evalScope)
                    val movingNode = MuxPromptMovingNode(lifecycle, this)

                    // Connect the initial branches (in the background) and the patches flow.
                    coroutineScope {
                        launch {
                            storage.mapValuesNotNullParallelTo(movingNode.switchedIn) {
                                (key, flow) ->
                                val branchNode = MuxBranchNode(movingNode, key)
                                val activation =
                                    flow.activate(
                                        evalScope = evalScope,
                                        downstream = branchNode.schedulable,
                                    )
                                activation?.let { (conn, needsEval) ->
                                    branchNode.upstream = conn
                                    if (needsEval) {
                                        // Capture a value this branch already emitted in this
                                        // transaction.
                                        val emitted = conn.getPushEvent(evalScope)
                                        if (emitted is Just) {
                                            movingNode.upstreamData[key] = emitted.value
                                        }
                                    }
                                    branchNode
                                }
                            }
                        }

                        val patchNode = MuxPromptPatchNode(movingNode)
                        val patchActivation =
                            getPatches(evalScope)
                                .activate(
                                    evalScope = evalScope,
                                    downstream = patchNode.schedulable,
                                )
                        patchActivation?.let { (conn, needsEval) ->
                            patchNode.upstream = conn
                            movingNode.patches = patchNode
                            if (needsEval) {
                                val emitted = conn.getPushEvent(evalScope)
                                if (emitted is Just) {
                                    movingNode.patchData = emitted.value
                                }
                            }
                        }
                    }

                    // Registers one upstream connection's depth with the moving node, as a direct
                    // or an indirect upstream depending on its depth snapshot.
                    suspend fun trackUpstreamDepth(conn: NodeConnection<*>) {
                        val upstreamDepth = conn.depthTracker
                        if (upstreamDepth.snapshotIsDirect) {
                            movingNode.depthTracker.addDirectUpstream(
                                oldDepth = null,
                                newDepth = upstreamDepth.snapshotDirectDepth,
                            )
                        } else {
                            movingNode.depthTracker.addIndirectUpstream(
                                oldDepth = null,
                                newDepth = upstreamDepth.snapshotIndirectDepth,
                            )
                            movingNode.depthTracker.updateIndirectRoots(
                                additions = upstreamDepth.snapshotIndirectRoots,
                                butNot = null,
                            )
                        }
                    }

                    // Depth: all initial branches first, then the patches connection.
                    for (branch in movingNode.switchedIn.values) {
                        trackUpstreamDepth(branch.upstream)
                    }
                    movingNode.patches?.upstream?.let { conn -> trackUpstreamDepth(conn) }
                    movingNode.depthTracker.reset()

                    // Something already emitted within this transaction: evaluate it.
                    val hasPendingInput =
                        movingNode.patchData != null || movingNode.upstreamData.isNotEmpty()
                    if (hasPendingInput) {
                        evalScope.schedule(movingNode)
                    }

                    val hasNoUpstream =
                        movingNode.patches == null && movingNode.switchedIn.isEmpty()
                    return if (hasNoUpstream) null else movingNode
                }
            }
        )

    return TFlowCheap { downstream ->
            moving.activate(evalScope = this, downstream)?.let { (connection, needsEval) ->
                ActivationResult(
                    connection =
                        NodeConnection(
                            MuxPromptEvalNode(connection.directUpstream),
                            connection.schedulerUpstream,
                        ),
                    needsEval = needsEval,
                )
            }
        }
        .cached()
}
473