1 /*
<lambda>null2 * Copyright (C) 2024 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 package com.android.systemui.kairos.internal
18
19 import com.android.systemui.kairos.internal.util.Key
20 import com.android.systemui.kairos.internal.util.associateByIndexTo
21 import com.android.systemui.kairos.internal.util.hashString
22 import com.android.systemui.kairos.internal.util.mapParallel
23 import com.android.systemui.kairos.internal.util.mapValuesNotNullParallelTo
24 import com.android.systemui.kairos.util.Just
25 import com.android.systemui.kairos.util.Left
26 import com.android.systemui.kairos.util.Maybe
27 import com.android.systemui.kairos.util.None
28 import com.android.systemui.kairos.util.Right
29 import com.android.systemui.kairos.util.These
30 import com.android.systemui.kairos.util.flatMap
31 import com.android.systemui.kairos.util.getMaybe
32 import com.android.systemui.kairos.util.just
33 import com.android.systemui.kairos.util.maybeThat
34 import com.android.systemui.kairos.util.maybeThis
35 import com.android.systemui.kairos.util.merge
36 import com.android.systemui.kairos.util.orElseGet
37 import com.android.systemui.kairos.util.partitionEithers
38 import com.android.systemui.kairos.util.these
39 import java.util.TreeMap
40 import kotlinx.coroutines.coroutineScope
41 import kotlinx.coroutines.launch
42 import kotlinx.coroutines.sync.withLock
43
/**
 * Mux node backing [switchDeferredImpl].
 *
 * Each evaluation emits a `Map<K, V>` built from the values that switched-in upstream branches
 * recorded in `upstreamData` during the current transaction. Changes to the set of switched-in
 * flows arrive through [patches]: [scheduleMover] captures a patch into [patchData], and
 * [performMove] applies it during the MOVE phase.
 */
internal class MuxDeferredNode<K : Any, V>(
    lifecycle: MuxLifecycle<Map<K, V>>,
    val spec: MuxActivator<Map<K, V>>,
) : MuxNode<K, V, Map<K, V>>(lifecycle), Key<Map<K, V>> {

    val schedulable = Schedulable.M(this)

    /** Connection to the upstream patch flow; `null` when no patch connection is held. */
    @Volatile var patches: NodeConnection<Map<K, Maybe<TFlowImpl<V>>>>? = null

    /** Patch captured by [scheduleMover] and consumed (then cleared) by [performMove]. */
    @Volatile var patchData: Map<K, Maybe<TFlowImpl<V>>>? = null

    /** True if [transactionStore] holds an entry keyed by this node (written by [visit]). */
    override fun hasCurrentValueLocked(transactionStore: TransactionStore): Boolean =
        transactionStore.contains(this)

    /** [hasCurrentValueLocked], taken under this node's mutex. */
    override suspend fun hasCurrentValue(transactionStore: TransactionStore): Boolean =
        mutex.withLock { hasCurrentValueLocked(transactionStore) }

    /**
     * Drains `upstreamData` into this transaction's result and, under the mutex, applies pending
     * depth changes and/or schedules downstream nodes. If nothing downstream could be scheduled,
     * this node is scheduled for deactivation.
     */
    override suspend fun visit(evalScope: EvalScope) {
        val result = upstreamData.toMap()
        upstreamData.clear()
        val scheduleDownstream = result.isNotEmpty()
        val compactDownstream = depthTracker.isDirty()
        if (scheduleDownstream || compactDownstream) {
            coroutineScope {
                mutex.withLock {
                    if (compactDownstream) {
                        depthTracker.applyChanges(
                            coroutineScope = this,
                            evalScope.scheduler,
                            downstreamSet,
                            muxNode = this@MuxDeferredNode,
                        )
                    }
                    if (scheduleDownstream) {
                        evalScope.setResult(this@MuxDeferredNode, result)
                        if (!scheduleAll(downstreamSet, evalScope)) {
                            evalScope.scheduleDeactivation(this@MuxDeferredNode)
                        }
                    }
                }
            }
        }
    }

    /** Returns the value stored for this node in the current transaction, if any. */
    override suspend fun getPushEvent(evalScope: EvalScope): Maybe<Map<K, V>> =
        evalScope.getCurrentValue(key = this)

    private suspend fun compactIfNeeded(evalScope: EvalScope) {
        depthTracker.propagateChanges(evalScope.compactor, this)
    }

    /**
     * Marks the lifecycle inactive (no-op unless it is currently active), then disconnects every
     * switched-in branch and the patch connection, deactivating those upstreams if needed.
     */
    override suspend fun doDeactivate() {
        // Update lifecycle
        lifecycle.mutex.withLock {
            if (lifecycle.lifecycleState !is MuxLifecycleState.Active) return@doDeactivate
            lifecycle.lifecycleState = MuxLifecycleState.Inactive(spec)
        }
        // Process branch nodes
        coroutineScope {
            switchedIn.values.forEach { branchNode ->
                branchNode.upstream.let {
                    launch { it.removeDownstreamAndDeactivateIfNeeded(branchNode.schedulable) }
                }
            }
        }
        // Process patch node
        patches?.removeDownstreamAndDeactivateIfNeeded(schedulable)
    }

    // MOVE phase
    // - concurrent moves may be occurring, but no more evals. all depth recalculations are
    // deferred to the end of this phase.
    /**
     * Applies the pending [patchData], if any. For each patch entry, a `Just` adds or replaces the
     * flow for that key and a `None` removes it. Severed upstream connections are then given a
     * chance to deactivate, and depth changes are propagated.
     */
    suspend fun performMove(evalScope: EvalScope) {
        val patch = patchData ?: return
        patchData = null

        // TODO: this logic is very similar to what's in MuxPromptMoving, maybe turn into an inline
        // fun?

        // We have a patch, process additions/updates and removals
        val (adds, removes) =
            patch
                .asSequence()
                .map { (k, newUpstream: Maybe<TFlowImpl<V>>) ->
                    when (newUpstream) {
                        is Just -> Left(k to newUpstream.value)
                        None -> Right(k)
                    }
                }
                .partitionEithers()

        val severed = mutableListOf<NodeConnection<*>>()

        coroutineScope {
            // remove and sever
            removes.forEach { k ->
                switchedIn.remove(k)?.let { branchNode: MuxBranchNode<K, V> ->
                    val conn = branchNode.upstream
                    severed.add(conn)
                    launch { conn.removeDownstream(downstream = branchNode.schedulable) }
                    // Undo this branch's depth contribution exactly as it was added: direct and
                    // indirect upstreams are tracked separately (see "add new" below).
                    val upstreamDepth = conn.depthTracker
                    if (upstreamDepth.snapshotIsDirect) {
                        depthTracker.removeDirectUpstream(upstreamDepth.snapshotDirectDepth)
                    } else {
                        depthTracker.removeIndirectUpstream(upstreamDepth.snapshotIndirectDepth)
                        depthTracker.updateIndirectRoots(
                            removals = upstreamDepth.snapshotIndirectRoots
                        )
                    }
                }
            }

            // add or replace
            adds
                .mapParallel { (k, newUpstream: TFlowImpl<V>) ->
                    val branchNode = MuxBranchNode(this@MuxDeferredNode, k)
                    k to
                        newUpstream.activate(evalScope, branchNode.schedulable)?.let { (conn, _) ->
                            branchNode.apply { upstream = conn }
                        }
                }
                .forEach { (k, newBranch: MuxBranchNode<K, V>?) ->
                    // remove old and sever, if present
                    switchedIn.remove(k)?.let { branchNode ->
                        val conn = branchNode.upstream
                        severed.add(conn)
                        launch { conn.removeDownstream(downstream = branchNode.schedulable) }
                        // Same direct/indirect-aware removal as above.
                        val upstreamDepth = conn.depthTracker
                        if (upstreamDepth.snapshotIsDirect) {
                            depthTracker.removeDirectUpstream(upstreamDepth.snapshotDirectDepth)
                        } else {
                            depthTracker.removeIndirectUpstream(
                                upstreamDepth.snapshotIndirectDepth
                            )
                            depthTracker.updateIndirectRoots(
                                removals = upstreamDepth.snapshotIndirectRoots
                            )
                        }
                    }

                    // add new
                    newBranch?.let {
                        switchedIn[k] = newBranch
                        val branchDepthTracker = newBranch.upstream.depthTracker
                        if (branchDepthTracker.snapshotIsDirect) {
                            depthTracker.addDirectUpstream(
                                oldDepth = null,
                                newDepth = branchDepthTracker.snapshotDirectDepth,
                            )
                        } else {
                            depthTracker.addIndirectUpstream(
                                oldDepth = null,
                                newDepth = branchDepthTracker.snapshotIndirectDepth,
                            )
                            depthTracker.updateIndirectRoots(
                                additions = branchDepthTracker.snapshotIndirectRoots,
                                butNot = this@MuxDeferredNode,
                            )
                        }
                    }
                }
        }

        coroutineScope {
            for (severedNode in severed) {
                launch { severedNode.scheduleDeactivationIfNeeded(evalScope) }
            }
        }

        compactIfNeeded(evalScope)
    }

    /** Drops a direct patch connection: this node stops being an indirect root. */
    suspend fun removeDirectPatchNode(scheduler: Scheduler) {
        mutex.withLock {
            if (
                depthTracker.removeIndirectUpstream(depth = 0) or
                    depthTracker.setIsIndirectRoot(false)
            ) {
                depthTracker.schedule(scheduler, this)
            }
            patches = null
        }
    }

    /** Drops an indirect patch connection, removing its forwarded roots and depth. */
    suspend fun removeIndirectPatchNode(
        scheduler: Scheduler,
        depth: Int,
        indirectSet: Set<MuxDeferredNode<*, *>>,
    ) {
        // indirectly connected patches forward the indirectSet
        mutex.withLock {
            if (
                depthTracker.updateIndirectRoots(removals = indirectSet) or
                    depthTracker.removeIndirectUpstream(depth)
            ) {
                depthTracker.schedule(scheduler, this)
            }
            patches = null
        }
    }

    /** Re-classifies the patch connection from indirect to direct. */
    suspend fun moveIndirectPatchNodeToDirect(
        scheduler: Scheduler,
        oldIndirectDepth: Int,
        oldIndirectSet: Set<MuxDeferredNode<*, *>>,
    ) {
        // directly connected patches are stored as an indirect singleton set of the patchNode
        mutex.withLock {
            if (
                depthTracker.updateIndirectRoots(removals = oldIndirectSet) or
                    depthTracker.removeIndirectUpstream(oldIndirectDepth) or
                    depthTracker.setIsIndirectRoot(true)
            ) {
                depthTracker.schedule(scheduler, this)
            }
        }
    }

    /** Re-classifies the patch connection from direct to indirect. */
    suspend fun moveDirectPatchNodeToIndirect(
        scheduler: Scheduler,
        newIndirectDepth: Int,
        newIndirectSet: Set<MuxDeferredNode<*, *>>,
    ) {
        // indirectly connected patches forward the indirectSet
        mutex.withLock {
            if (
                depthTracker.setIsIndirectRoot(false) or
                    depthTracker.updateIndirectRoots(additions = newIndirectSet, butNot = this) or
                    depthTracker.addIndirectUpstream(oldDepth = null, newDepth = newIndirectDepth)
            ) {
                depthTracker.schedule(scheduler, this)
            }
        }
    }

    /** Updates depth and forwarded roots for an indirect patch connection that changed. */
    suspend fun adjustIndirectPatchNode(
        scheduler: Scheduler,
        oldDepth: Int,
        newDepth: Int,
        removals: Set<MuxDeferredNode<*, *>>,
        additions: Set<MuxDeferredNode<*, *>>,
    ) {
        // indirectly connected patches forward the indirectSet
        mutex.withLock {
            if (
                depthTracker.updateIndirectRoots(
                    additions = additions,
                    removals = removals,
                    butNot = this,
                ) or depthTracker.addIndirectUpstream(oldDepth = oldDepth, newDepth = newDepth)
            ) {
                depthTracker.schedule(scheduler, this)
            }
        }
    }

    /**
     * Captures the patch flow's current emission (or `null`) into [patchData] and schedules this
     * node as a mux mover.
     *
     * @throws IllegalStateException if [patches] is not set.
     */
    suspend fun scheduleMover(evalScope: EvalScope) {
        patchData =
            checkNotNull(patches) { "mux mover scheduled with unset patches upstream node" }
                .getPushEvent(evalScope)
                .orElseGet { null }
        evalScope.scheduleMuxMover(this)
    }

    override fun toString(): String = "${this::class.simpleName}@$hashString"
}
291
/**
 * Single-flow specialization of [switchDeferredImpl].
 *
 * The lone flow is stored under the key [Unit]. Every patch emission replaces it (wrapped in
 * `just`). Each output is the value emitted under that key.
 */
internal inline fun <A> switchDeferredImplSingle(
    crossinline getStorage: suspend EvalScope.() -> TFlowImpl<A>,
    crossinline getPatches: suspend EvalScope.() -> TFlowImpl<TFlowImpl<A>>,
): TFlowImpl<A> {
    return mapImpl(
        {
            // Built inside the lambda so the keyed switch is only created when this is evaluated.
            switchDeferredImpl(
                getStorage = { mapOf(Unit to getStorage()) },
                getPatches = {
                    mapImpl(getPatches) { replacement -> mapOf(Unit to just(replacement)) }
                },
            )
        }
    ) { singleton ->
        singleton.getValue(Unit)
    }
}
304
/**
 * Builds a flow that merges a dynamically changing, keyed collection of flows.
 *
 * On activation, [getStorage] supplies the initial key-to-flow mapping. Each flow is connected
 * through its own [MuxBranchNode].
 *
 * [getPatches] supplies a flow of patches. A `Just` entry adds or replaces the flow for a key, and a
 * `None` entry removes it; see [MuxDeferredNode.performMove].
 *
 * The patch connection is established inside [EvalScope.deferAction]. This permits a recursive
 * setup where the mux node is downstream of itself via patches.
 *
 * Each emission maps every switched-in key whose flow emitted in that transaction to its value.
 */
internal fun <K : Any, A> switchDeferredImpl(
    getStorage: suspend EvalScope.() -> Map<K, TFlowImpl<A>>,
    getPatches: suspend EvalScope.() -> TFlowImpl<Map<K, Maybe<TFlowImpl<A>>>>,
): TFlowImpl<Map<K, A>> =
    MuxLifecycle(
        object : MuxActivator<Map<K, A>> {
            override suspend fun activate(
                evalScope: EvalScope,
                lifecycle: MuxLifecycle<Map<K, A>>,
            ): MuxNode<*, *, Map<K, A>>? {
                val initialFlows: Map<K, TFlowImpl<A>> = getStorage(evalScope)
                val node = MuxDeferredNode(lifecycle, this)

                // Connect each initial flow through its own branch node. A null activation yields
                // a null branch, which mapValuesNotNullParallelTo presumably leaves out of
                // switchedIn. Flows that already emitted this transaction record their value now.
                initialFlows.mapValuesNotNullParallelTo(node.switchedIn) { (key, flow) ->
                    val branch = MuxBranchNode(node, key)
                    flow.activate(evalScope, branch.schedulable)?.let { (conn, needsEval) ->
                        branch.upstream = conn
                        if (needsEval) {
                            val pending = conn.getPushEvent(evalScope)
                            if (pending is Just) {
                                node.upstreamData[key] = pending.value
                            }
                        }
                        branch
                    }
                }

                // Seed this node's depth from every initially switched-in upstream.
                for (branch in node.switchedIn.values) {
                    val upstreamDepth = branch.upstream.depthTracker
                    if (upstreamDepth.snapshotIsDirect) {
                        node.depthTracker.addDirectUpstream(
                            oldDepth = null,
                            newDepth = upstreamDepth.snapshotDirectDepth,
                        )
                    } else {
                        node.depthTracker.addIndirectUpstream(
                            oldDepth = null,
                            newDepth = upstreamDepth.snapshotIndirectDepth,
                        )
                        node.depthTracker.updateIndirectRoots(
                            additions = upstreamDepth.snapshotIndirectRoots,
                            butNot = node,
                        )
                    }
                }

                // The patch connection does not exist yet, so provisionally act as if it were
                // direct (this node is an indirect root). The deferred action below corrects and
                // re-schedules depth if that turns out to be false.
                node.depthTracker.setIsIndirectRoot(true)
                node.depthTracker.reset()

                var isIndirect = true
                evalScope.deferAction {
                    val activation =
                        getPatches(evalScope).activate(evalScope, downstream = node.schedulable)
                    if (activation == null) {
                        // Patches could not be connected: we are not a root after all.
                        isIndirect = false
                        node.mutex.withLock {
                            if (node.depthTracker.setIsIndirectRoot(false)) {
                                node.depthTracker.schedule(evalScope.scheduler, node)
                            }
                        }
                        return@deferAction
                    }
                    val (patchesConn, needsEval) = activation
                    node.patches = patchesConn

                    if (!patchesConn.schedulerUpstream.depthTracker.snapshotIsDirect) {
                        // Patches is indirect, so this node is not a root. `or` (not `||`) so
                        // every tracker update is applied.
                        node.mutex.withLock {
                            if (
                                node.depthTracker.setIsIndirectRoot(false) or
                                    node.depthTracker.addIndirectUpstream(
                                        oldDepth = null,
                                        newDepth = patchesConn.depthTracker.snapshotIndirectDepth,
                                    ) or
                                    node.depthTracker.updateIndirectRoots(
                                        additions = patchesConn.depthTracker.snapshotIndirectRoots
                                    )
                            ) {
                                node.depthTracker.schedule(evalScope.scheduler, node)
                            }
                        }
                    }

                    // A patch already emitted in this transaction: hand it to the mover.
                    if (needsEval) {
                        val pendingPatch = patchesConn.getPushEvent(evalScope)
                        if (pendingPatch is Just) {
                            node.patchData = pendingPatch.value
                            evalScope.scheduleMuxMover(node)
                        }
                    }
                }

                // Some initial branch already emitted within this transaction.
                if (node.upstreamData.isNotEmpty()) {
                    evalScope.schedule(node)
                }
                // NOTE(review): if deferAction runs its block only after activate returns (as the
                // name suggests), isIndirect is still true here and the node is always returned.
                // Confirm this is intended.
                return if (node.switchedIn.isEmpty() && !isIndirect) null else node
            }
        }
    )
418
/**
 * Merges two same-typed flows into one.
 *
 * When only one side emits in a transaction, its value passes through. When both emit, [f]
 * combines them as `f(pulseValue, otherValue)`. The result is cached.
 */
internal inline fun <A> mergeNodes(
    crossinline getPulse: suspend EvalScope.() -> TFlowImpl<A>,
    crossinline getOther: suspend EvalScope.() -> TFlowImpl<A>,
    crossinline f: suspend EvalScope.(A, A) -> A,
): TFlowImpl<A> =
    mapImpl({ mergeNodes(getPulse, getOther) }) { both ->
        both.merge { left, right -> f(left, right) }
    }.cached()
430
/**
 * Merges two flows into a single flow of [These].
 *
 * - `This` when only [getPulse]'s flow emits.
 * - `That` when only [getOther]'s flow emits.
 * - `Both` when both emit in the same transaction.
 *
 * Internally uses a [switchDeferredImpl] keyed 0 (pulse) and 1 (other) whose patch flow never
 * emits. The result is cached.
 */
internal inline fun <A, B> mergeNodes(
    crossinline getPulse: suspend EvalScope.() -> TFlowImpl<A>,
    crossinline getOther: suspend EvalScope.() -> TFlowImpl<B>,
): TFlowImpl<These<A, B>> {
    val pulseSide: TFlowImpl<These<A, B>> = mapImpl(getPulse) { These.thiz<A, B>(it) }
    val otherSide: TFlowImpl<These<A, B>> = mapImpl(getOther) { These.that(it) }
    val inputs = mapOf(0 to pulseSide, 1 to otherSide)
    val switchNode = switchDeferredImpl(getStorage = { inputs }, getPatches = { neverImpl })
    return mapImpl({ switchNode }) { byIndex ->
        val fromPulse = byIndex.getMaybe(0).flatMap { it.maybeThis() }
        val fromOther = byIndex.getMaybe(1).flatMap { it.maybeThat() }
        these(fromPulse, fromOther).orElseGet { error("unexpected missing merge result") }
    }.cached()
}
449
/**
 * Merges any number of flows into a flow of lists.
 *
 * Each emission contains the values of every input flow that emitted in that transaction. The
 * list is ordered by the flow's position in [getPulses]. The result is cached.
 */
internal inline fun <A> mergeNodes(
    crossinline getPulses: suspend EvalScope.() -> Iterable<TFlowImpl<A>>
): TFlowImpl<List<A>> {
    val switchNode =
        switchDeferredImpl(
            getStorage = { getPulses().associateByIndexTo(TreeMap()) },
            getPatches = { neverImpl },
        )
    val merged =
        mapImpl({ switchNode }) { mergeResults: Map<Int, A> ->
            // The emitted map is built from the mux node's upstream data, whose iteration order
            // is not guaranteed to follow the index. Sort explicitly so output is in input order.
            mergeResults.entries.sortedBy { it.key }.map { it.value }
        }
    return merged.cached()
}
461
/**
 * Merges any number of flows, preferring the leftmost.
 *
 * When several input flows emit in the same transaction, the value from the flow that appears
 * earliest in [getPulses] wins. The result is cached.
 */
internal inline fun <A> mergeNodesLeft(
    crossinline getPulses: suspend EvalScope.() -> Iterable<TFlowImpl<A>>
): TFlowImpl<A> {
    val switchNode =
        switchDeferredImpl(
            getStorage = { getPulses().associateByIndexTo(TreeMap()) },
            getPatches = { neverImpl },
        )
    val merged =
        mapImpl({ switchNode }) { mergeResults: Map<Int, A> ->
            // Pick the lowest index explicitly. The emitted map comes from the mux node's
            // upstream data, whose iteration order is not guaranteed to follow the index, so
            // `values.first()` could pick a non-leftmost flow.
            mergeResults.entries.minBy { it.key }.value
        }
    return merged.cached()
}
474