xref: /aosp_15_r20/frameworks/base/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/Mux.kt (revision d57664e9bc4670b3ecf6748a746a57c557b6bc9e)
1 /*
2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 @file:Suppress("NOTHING_TO_INLINE")
18 
19 package com.android.systemui.kairos.internal
20 
21 import com.android.systemui.kairos.internal.util.ConcurrentNullableHashMap
22 import com.android.systemui.kairos.internal.util.hashString
23 import com.android.systemui.kairos.util.Just
24 import java.util.concurrent.ConcurrentHashMap
25 import kotlinx.coroutines.coroutineScope
26 import kotlinx.coroutines.sync.Mutex
27 import kotlinx.coroutines.sync.withLock
28 
/**
 * Common base for mux nodes: push nodes whose collection of upstream branches can change while the
 * node is alive.
 *
 * Each upstream branch is identified by a key of type [K] and contributes values of type [V]; the
 * node itself emits [Output] to the schedulables registered in [downstreamSet].
 *
 * @property lifecycle the [MuxLifecycle] this node belongs to; its lock doubles as [mutex].
 */
internal sealed class MuxNode<K : Any, V, Output>(val lifecycle: MuxLifecycle<Output>) :
    PushNode<Output> {

    /**
     * Lock shared with [lifecycle]. The suspending methods of this class take it before touching
     * [downstreamSet], [switchedIn] or [depthTracker]; methods suffixed `Locked` expect the caller
     * to hold it already.
     */
    inline val mutex
        get() = lifecycle.mutex

    // TODO: preserve insertion order?
    /** Values written by [MuxBranchNode.schedule] when a branch has a push event, keyed by branch. */
    val upstreamData = ConcurrentNullableHashMap<K, V>()

    /** Branch nodes currently connected to this mux, keyed by branch. */
    val switchedIn = ConcurrentHashMap<K, MuxBranchNode<K, V>>()

    /** Schedulables to be scheduled when this node emits. */
    val downstreamSet: DownstreamSet = DownstreamSet()

    // TODO: inline DepthTracker? would need to be added to PushNode signature
    final override val depthTracker = DepthTracker()

    final override suspend fun addDownstream(downstream: Schedulable) {
        mutex.withLock { downstreamSet.add(downstream) }
    }

    /**
     * Registers [downstream] so that it is scheduled for evaluation in the same transaction
     * whenever this mux node emits.
     *
     * The caller must already hold [mutex].
     */
    fun addDownstreamLocked(downstream: Schedulable) {
        downstreamSet.add(downstream)
    }

    final override suspend fun removeDownstream(downstream: Schedulable) {
        // TODO: return boolean?
        mutex.withLock { downstreamSet.remove(downstream) }
    }

    final override suspend fun removeDownstreamAndDeactivateIfNeeded(downstream: Schedulable) {
        // Remove and observe emptiness under one lock acquisition so the check reflects this removal.
        val becameEmpty =
            mutex.withLock {
                downstreamSet.remove(downstream)
                downstreamSet.isEmpty()
            }
        if (!becameEmpty) return
        doDeactivate()
    }

    final override suspend fun deactivateIfNeeded() {
        if (downstreamIsEmpty()) doDeactivate()
    }

    /** Called by the scheduler to evaluate this node during push evaluation. */
    abstract suspend fun visit(evalScope: EvalScope)

    /** Tears this node down, propagating deactivation to all of its upstream nodes. */
    protected abstract suspend fun doDeactivate()

    final override suspend fun scheduleDeactivationIfNeeded(evalScope: EvalScope) {
        if (downstreamIsEmpty()) evalScope.scheduleDeactivation(this)
    }

    /** Returns whether [downstreamSet] is empty, checked while holding [mutex]. */
    private suspend fun downstreamIsEmpty(): Boolean = mutex.withLock { downstreamSet.isEmpty() }

    /** Records that a direct upstream moved from [oldDepth] to [newDepth]; reschedules on change. */
    suspend fun adjustDirectUpstream(scheduler: Scheduler, oldDepth: Int, newDepth: Int) {
        mutex.withLock {
            val changed = depthTracker.addDirectUpstream(oldDepth, newDepth)
            if (changed) depthTracker.schedule(scheduler, this)
        }
    }

    /**
     * Converts an indirect upstream (at [oldIndirectDepth], rooted at [oldIndirectRoots]) into a
     * direct upstream at [newDepth]; reschedules if the depth tracker changed.
     */
    suspend fun moveIndirectUpstreamToDirect(
        scheduler: Scheduler,
        oldIndirectDepth: Int,
        oldIndirectRoots: Set<MuxDeferredNode<*, *>>,
        newDepth: Int,
    ) {
        mutex.withLock {
            // All three updates must run, so their results are combined with the non-short-circuit
            // `or` rather than `||`.
            val directChanged = depthTracker.addDirectUpstream(oldDepth = null, newDepth)
            val indirectChanged = depthTracker.removeIndirectUpstream(depth = oldIndirectDepth)
            val rootsChanged = depthTracker.updateIndirectRoots(removals = oldIndirectRoots)
            if (directChanged or indirectChanged or rootsChanged) {
                depthTracker.schedule(scheduler, this)
            }
        }
    }

    /**
     * Moves an indirect upstream from [oldDepth] to [newDepth] and updates its indirect roots by
     * [removals] and [additions]; reschedules if the depth tracker changed.
     */
    suspend fun adjustIndirectUpstream(
        scheduler: Scheduler,
        oldDepth: Int,
        newDepth: Int,
        removals: Set<MuxDeferredNode<*, *>>,
        additions: Set<MuxDeferredNode<*, *>>,
    ) {
        mutex.withLock {
            // Both updates must run; `or` does not short-circuit.
            val depthChanged = depthTracker.addIndirectUpstream(oldDepth, newDepth)
            val rootsChanged =
                depthTracker.updateIndirectRoots(
                    additions = additions,
                    removals = removals,
                    // A deferred mux node is excluded from its own indirect roots.
                    butNot = this as? MuxDeferredNode<*, *>,
                )
            if (depthChanged or rootsChanged) {
                depthTracker.schedule(scheduler, this)
            }
        }
    }

    /**
     * Converts a direct upstream at [oldDepth] into an indirect upstream at [newDepth] rooted at
     * [newIndirectSet]; reschedules if the depth tracker changed.
     */
    suspend fun moveDirectUpstreamToIndirect(
        scheduler: Scheduler,
        oldDepth: Int,
        newDepth: Int,
        newIndirectSet: Set<MuxDeferredNode<*, *>>,
    ) {
        mutex.withLock {
            // All three updates must run; `or` does not short-circuit.
            val indirectChanged = depthTracker.addIndirectUpstream(oldDepth = null, newDepth)
            val directChanged = depthTracker.removeDirectUpstream(oldDepth)
            val rootsChanged =
                depthTracker.updateIndirectRoots(
                    additions = newIndirectSet,
                    butNot = this as? MuxDeferredNode<*, *>,
                )
            if (indirectChanged or directChanged or rootsChanged) {
                depthTracker.schedule(scheduler, this)
            }
        }
    }

    /** Unregisters branch [key] and its direct upstream at [depth]; reschedules on change. */
    suspend fun removeDirectUpstream(scheduler: Scheduler, depth: Int, key: K) {
        mutex.withLock {
            switchedIn.remove(key)
            val changed = depthTracker.removeDirectUpstream(depth)
            if (changed) depthTracker.schedule(scheduler, this)
        }
    }

    /**
     * Unregisters branch [key] and its indirect upstream at [oldDepth] rooted at [indirectSet];
     * reschedules if the depth tracker changed.
     */
    suspend fun removeIndirectUpstream(
        scheduler: Scheduler,
        oldDepth: Int,
        indirectSet: Set<MuxDeferredNode<*, *>>,
        key: K,
    ) {
        mutex.withLock {
            switchedIn.remove(key)
            // Both updates must run; `or` does not short-circuit.
            val depthChanged = depthTracker.removeIndirectUpstream(oldDepth)
            val rootsChanged = depthTracker.updateIndirectRoots(removals = indirectSet)
            if (depthChanged or rootsChanged) {
                depthTracker.schedule(scheduler, this)
            }
        }
    }

    /**
     * Applies any pending depth-tracker changes, propagating them to [downstreamSet].
     *
     * NOTE(review): unlike the other depth-tracker methods this does not take [mutex]; presumably
     * compaction never overlaps with the locked updates — confirm.
     */
    suspend fun visitCompact(scheduler: Scheduler) {
        coroutineScope {
            if (depthTracker.isDirty()) {
                depthTracker.applyChanges(coroutineScope = this, scheduler, downstreamSet, this@MuxNode)
            }
        }
    }

    /** Whether this node has a value in [transactionStore]; the caller must hold [mutex]. */
    abstract fun hasCurrentValueLocked(transactionStore: TransactionStore): Boolean
}
192 
/**
 * A single keyed input of a [MuxNode].
 *
 * When [upstream] has a push event, [schedule] stores the value in [MuxNode.upstreamData] under
 * [key] and schedules the mux node. Every depth-bookkeeping callback is forwarded to the mux node,
 * with [key] attached where the mux node needs to know which branch is being removed.
 */
internal class MuxBranchNode<K : Any, V>(private val muxNode: MuxNode<K, V, *>, val key: K) :
    SchedulableNode {

    /** [Schedulable] wrapper around this branch. */
    val schedulable = Schedulable.N(this)

    /** Connection to the node feeding this branch; must be assigned before [schedule] runs. */
    @Volatile lateinit var upstream: NodeConnection<V>

    override suspend fun schedule(evalScope: EvalScope) {
        val pushed = upstream.getPushEvent(evalScope)
        if (pushed !is Just) return
        muxNode.upstreamData[key] = pushed.value
        evalScope.schedule(muxNode)
    }

    override suspend fun adjustDirectUpstream(scheduler: Scheduler, oldDepth: Int, newDepth: Int) =
        muxNode.adjustDirectUpstream(scheduler = scheduler, oldDepth = oldDepth, newDepth = newDepth)

    override suspend fun moveIndirectUpstreamToDirect(
        scheduler: Scheduler,
        oldIndirectDepth: Int,
        oldIndirectSet: Set<MuxDeferredNode<*, *>>,
        newDirectDepth: Int,
    ) =
        muxNode.moveIndirectUpstreamToDirect(
            scheduler = scheduler,
            oldIndirectDepth = oldIndirectDepth,
            oldIndirectRoots = oldIndirectSet,
            newDepth = newDirectDepth,
        )

    override suspend fun adjustIndirectUpstream(
        scheduler: Scheduler,
        oldDepth: Int,
        newDepth: Int,
        removals: Set<MuxDeferredNode<*, *>>,
        additions: Set<MuxDeferredNode<*, *>>,
    ) =
        muxNode.adjustIndirectUpstream(
            scheduler = scheduler,
            oldDepth = oldDepth,
            newDepth = newDepth,
            removals = removals,
            additions = additions,
        )

    override suspend fun moveDirectUpstreamToIndirect(
        scheduler: Scheduler,
        oldDirectDepth: Int,
        newIndirectDepth: Int,
        newIndirectSet: Set<MuxDeferredNode<*, *>>,
    ) =
        muxNode.moveDirectUpstreamToIndirect(
            scheduler = scheduler,
            oldDepth = oldDirectDepth,
            newDepth = newIndirectDepth,
            newIndirectSet = newIndirectSet,
        )

    override suspend fun removeDirectUpstream(scheduler: Scheduler, depth: Int) =
        muxNode.removeDirectUpstream(scheduler = scheduler, depth = depth, key = key)

    override suspend fun removeIndirectUpstream(
        scheduler: Scheduler,
        depth: Int,
        indirectSet: Set<MuxDeferredNode<*, *>>,
    ) =
        muxNode.removeIndirectUpstream(
            scheduler = scheduler,
            oldDepth = depth,
            indirectSet = indirectSet,
            key = key,
        )

    override fun toString(): String = "MuxBranchNode(key=$key, mux=$muxNode)"
}
265 
/**
 * Mutable holder for a mux node's [MuxLifecycleState]; this object is the [TFlowImpl] exposed to
 * consumers.
 *
 * [activate] runs entirely while holding [mutex], including the call into the
 * [MuxActivator] when the lifecycle is still inactive.
 */
internal class MuxLifecycle<A>(@Volatile var lifecycleState: MuxLifecycleState<A>) : TFlowImpl<A> {
    /** Lock guarding [lifecycleState] transitions; [MuxNode.mutex] reads this same instance. */
    val mutex = Mutex()

    // NOTE(review): the label reads "TFlowLifecycle" although the class is MuxLifecycle; kept as-is.
    override fun toString(): String = "TFlowLifecycle[$hashString][$lifecycleState][$mutex]"

    /**
     * Attaches [downstream] to the mux node, creating the node first if the lifecycle is inactive.
     *
     * @return `null` if the lifecycle is (or becomes) [MuxLifecycleState.Dead]; otherwise a
     *   connection to the node. `needsEval` reflects [MuxNode.hasCurrentValueLocked] for an
     *   already-active node, and is `false` for a freshly created one.
     */
    override suspend fun activate(
        evalScope: EvalScope,
        downstream: Schedulable,
    ): ActivationResult<A>? =
        mutex.withLock {
            when (val current = lifecycleState) {
                is MuxLifecycleState.Dead -> null
                is MuxLifecycleState.Active -> {
                    val activeNode = current.node
                    activeNode.addDownstreamLocked(downstream)
                    ActivationResult(
                        connection = NodeConnection(activeNode, activeNode),
                        needsEval = activeNode.hasCurrentValueLocked(evalScope.transactionStore),
                    )
                }
                is MuxLifecycleState.Inactive -> {
                    // NOTE(review): the activator runs while [mutex] is held and kotlinx Mutex is not
                    // reentrant, so it presumably must not lock this mutex itself — confirm.
                    val createdNode = current.spec.activate(evalScope, this@MuxLifecycle)
                    if (createdNode == null) {
                        lifecycleState = MuxLifecycleState.Dead
                        null
                    } else {
                        lifecycleState = MuxLifecycleState.Active(createdNode)
                        createdNode.addDownstreamLocked(downstream)
                        ActivationResult(
                            connection = NodeConnection(createdNode, createdNode),
                            needsEval = false,
                        )
                    }
                }
            }
        }
}
308 
/**
 * State of a [MuxLifecycle]. The factory in this file starts in [Inactive]; the first activation
 * moves it to [Active] or, if no node is produced, to [Dead].
 */
internal sealed interface MuxLifecycleState<out A> {
    /** Not yet activated; [spec] creates the backing node on first activation. */
    class Inactive<A>(val spec: MuxActivator<A>) : MuxLifecycleState<A> {
        override fun toString() = "Inactive"
    }

    /** Activated and backed by [node]. */
    class Active<A>(val node: MuxNode<*, *, A>) : MuxLifecycleState<A> {
        override fun toString() = "Active(node=$node)"
    }

    /** Activation was attempted but the activator returned no node. */
    data object Dead : MuxLifecycleState<Nothing>
}
320 
/** Creates the [MuxNode] backing a [MuxLifecycle] the first time that lifecycle is activated. */
internal interface MuxActivator<A> {
    /**
     * Builds the active node for [lifecycle].
     *
     * Called from [MuxLifecycle.activate] while [MuxLifecycle.mutex] is held.
     *
     * @return the new node, or `null`, in which case the lifecycle is marked
     *   [MuxLifecycleState.Dead].
     */
    suspend fun activate(
        evalScope: EvalScope,
        lifecycle: MuxLifecycle<A>,
    ): MuxNode<*, *, A>?
}
324 
/**
 * Creates a [MuxLifecycle] in the [MuxLifecycleState.Inactive] state. Node creation is deferred to
 * [onSubscribe] until the first activation.
 */
internal inline fun <A> MuxLifecycle(onSubscribe: MuxActivator<A>): TFlowImpl<A> {
    val initialState = MuxLifecycleState.Inactive(onSubscribe)
    return MuxLifecycle(lifecycleState = initialState)
}
327