xref: /aosp_15_r20/frameworks/base/packages/SystemUI/utils/kairos/src/com/android/systemui/kairos/internal/TStateImpl.kt (revision d57664e9bc4670b3ecf6748a746a57c557b6bc9e)
1 /*
<lambda>null2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.systemui.kairos.internal
18 
19 import com.android.systemui.kairos.internal.util.Key
20 import com.android.systemui.kairos.internal.util.associateByIndex
21 import com.android.systemui.kairos.internal.util.hashString
22 import com.android.systemui.kairos.internal.util.mapValuesParallel
23 import com.android.systemui.kairos.util.Just
24 import com.android.systemui.kairos.util.Maybe
25 import com.android.systemui.kairos.util.just
26 import com.android.systemui.kairos.util.none
27 import java.util.concurrent.atomic.AtomicLong
28 import kotlinx.coroutines.CompletableDeferred
29 import kotlinx.coroutines.CoroutineStart
30 import kotlinx.coroutines.Deferred
31 import kotlinx.coroutines.ExperimentalCoroutinesApi
32 
/**
 * Common internal interface for all state nodes.
 *
 * A state pairs a value that can be sampled at any time with a [changes] flow carrying its
 * updates.
 */
internal sealed interface TStateImpl<out A> {
    /** Optional debug name supplied by the caller. */
    val name: String?

    /** Name of the operator that created this state. */
    val operatorName: String

    /** Flow of updates to this state. */
    val changes: TFlowImpl<A>

    /**
     * Samples this state within [evalScope], returning its current value together with the epoch
     * associated with that value.
     */
    suspend fun getCurrentWithEpoch(evalScope: EvalScope): Pair<A, Long>
}
40 
/**
 * Base class for states whose value is computed from other states.
 *
 * The most recently computed value is cached together with the epoch it corresponds to. A
 * subclass can skip recomputation by returning `null` from [recalc], in which case the cached
 * value is used. Within a single transaction, [getCurrentWithEpoch] is memoized in the
 * transaction store, keyed by this instance.
 */
internal sealed class TStateDerived<A>(override val changes: TFlowImpl<A>) :
    TStateImpl<A>, Key<Deferred<Pair<A, Long>>> {

    /**
     * Epoch associated with [cache]. Starts at [Long.MIN_VALUE] (nothing cached) and is only
     * raised, never lowered; see [setCache].
     */
    @Volatile
    var invalidatedEpoch = Long.MIN_VALUE
        private set

    /** Most recently computed value, or the private `EmptyCache` sentinel if none exists yet. */
    @Volatile
    protected var cache: Any? = EmptyCache
        private set

    /**
     * Returns this state's value for the current transaction. The first call in a transaction
     * stores a lazily-started [pull] in the transaction store; subsequent calls in the same
     * transaction await that same stored result.
     */
    override suspend fun getCurrentWithEpoch(evalScope: EvalScope): Pair<A, Long> =
        evalScope.transactionStore
            .getOrPut(this) { evalScope.deferAsync(CoroutineStart.LAZY) { pull(evalScope) } }
            .await()

    /**
     * Recomputes the value via [recalc] and stores it with [setCache]. If [recalc] reports that
     * the cache is still valid (returns `null`), the cached value and its epoch are returned.
     *
     * @throws IllegalStateException if [recalc] returns `null` before any value was ever cached.
     */
    suspend fun pull(evalScope: EvalScope): Pair<A, Long> {
        val fresh = recalc(evalScope)
        if (fresh != null) {
            setCache(fresh.first, fresh.second)
            return fresh
        }
        // Read the volatile field once so the emptiness check and the returned value agree.
        val cached = cache
        // Without this check the sentinel itself would be returned as an `A` via the unchecked
        // cast below, failing later far from the cause.
        check(cached !== EmptyCache) {
            "$this: recalc() reported an up-to-date cache, but no value has been cached yet"
        }
        @Suppress("UNCHECKED_CAST")
        return (cached as A) to invalidatedEpoch
    }

    /**
     * Stores [value] as the cached value for [epoch], unless the cache already holds a value for
     * the same or a later epoch.
     *
     * NOTE(review): the compare-then-write below is not synchronized.
     */
    fun setCache(value: A, epoch: Long) {
        if (epoch > invalidatedEpoch) {
            cache = value
            invalidatedEpoch = epoch
        }
    }

    /**
     * Returns the cached value without evaluating anything, or [none] if nothing has been cached
     * yet. The value may be older than what [recalc] would currently produce.
     */
    fun getCachedUnsafe(): Maybe<A> {
        val cached = cache
        // Identity comparison: detecting the sentinel must not depend on A's equals().
        @Suppress("UNCHECKED_CAST")
        return if (cached === EmptyCache) none else just(cached as A)
    }

    /**
     * Computes the current value and its epoch, or returns `null` when the cached value is still
     * up to date (in which case [pull] falls back to the cache).
     */
    protected abstract suspend fun recalc(evalScope: EvalScope): Pair<A, Long>?

    /** Sentinel marking that [cache] has not been populated yet. */
    private data object EmptyCache
}
79 
/**
 * A state that stores its own value. The value starts as [init] and is replaced in [updateState]
 * whenever [upstreamConnection] carries an event for the transaction.
 *
 * @param init the initial value, possibly not yet available; reads suspend until it completes.
 */
internal class TStateSource<A>(
    override val name: String?,
    override val operatorName: String,
    init: Deferred<A>,
    override val changes: TFlowImpl<A>,
) : TStateImpl<A> {

    /** Convenience constructor for an initial value that is already known. */
    constructor(
        name: String?,
        operatorName: String,
        init: A,
        changes: TFlowImpl<A>,
    ) : this(name, operatorName, CompletableDeferred(init), changes)

    /** Connection [updateState] reads events from; must be assigned before [updateState] runs. */
    lateinit var upstreamConnection: NodeConnection<A>

    // No synchronization needed: writes only happen at the end of a network step (in
    // updateState), after every read for that step has already taken place, so reads and writes
    // never interleave.
    @Volatile private var _current: Deferred<A> = init

    /** Epoch of the most recent write by [updateState]; 0 before the first write. */
    @Volatile
    var writeEpoch = 0L
        private set

    override suspend fun getCurrentWithEpoch(evalScope: EvalScope): Pair<A, Long> {
        val value = _current.await()
        return Pair(value, writeEpoch)
    }

    /** Called by the network once the eval phase has completed; commits this step's event, if any. */
    suspend fun updateState(evalScope: EvalScope) {
        val event = upstreamConnection.getPushEvent(evalScope)
        if (event !is Just) return
        _current = CompletableDeferred(event.value)
        writeEpoch = evalScope.epoch
    }

    override fun toString(): String = "TStateImpl(changes=$changes, current=$_current)"

    /** Returns the stored value if it is already available, without suspending; otherwise [none]. */
    @OptIn(ExperimentalCoroutinesApi::class)
    fun getStorageUnsafe(): Maybe<A> {
        val deferred = _current
        return if (!deferred.isCompleted) none else just(deferred.getCompleted())
    }
}
122 
/** Creates a state that holds [init] and whose changes flow is [neverImpl]. */
internal fun <A> constS(name: String?, operatorName: String, init: A): TStateImpl<A> =
    TStateSource(
        name = name,
        operatorName = operatorName,
        init = init,
        changes = neverImpl,
    )
125 
/**
 * Creates a [TStateSource] that starts from [init] and is updated with each value from the flow
 * returned by [getChanges] that differs from the state's current value.
 *
 * Activation of that filtered flow is deferred via [EvalScope.scheduleOutput]. When it runs, the
 * resulting connection becomes [TStateSource.upstreamConnection], and the state is scheduled if
 * activation reports that it needs evaluation.
 */
internal inline fun <A> mkState(
    name: String?,
    operatorName: String,
    evalScope: EvalScope,
    crossinline getChanges: suspend EvalScope.() -> TFlowImpl<A>,
    init: Deferred<A>,
): TStateImpl<A> {
    lateinit var state: TStateSource<A>
    // Drop values equal to the current one so that `changes` only carries real differences.
    val calm: TFlowImpl<A> =
        filterNode(getChanges) { candidate ->
                candidate != state.getCurrentWithEpoch(evalScope = this).first
            }
            .cached()
    state = TStateSource(name, operatorName, init, calm)
    evalScope.scheduleOutput(
        OneShot {
            val activation = calm.activate(evalScope = this, downstream = Schedulable.S(state))
            if (activation != null) {
                val (connection, needsEval) = activation
                state.upstreamConnection = connection
                if (needsEval) schedule(state)
            }
        }
    )
    return state
}
152 
/**
 * Filters this flow down to values that differ from the current value of the state returned by
 * [getState]. Each value that passes is also written into that state's cache for the current
 * epoch.
 */
private inline fun <A> TFlowImpl<A>.calm(
    crossinline getState: () -> TStateDerived<A>
): TFlowImpl<A> =
    filterNode({ this@calm }) { new ->
            val derived = getState()
            val changed = new != derived.getCurrentWithEpoch(evalScope = this).first
            if (changed) derived.setCache(new, epoch)
            changed
        }
        .cached()
167 
/**
 * Maps this state with [transform] without caching. [transform] is re-applied on every sample,
 * and the resulting changes flow emits on every upstream change (no equality filtering).
 */
internal fun <A, B> TStateImpl<A>.mapCheap(
    name: String?,
    operatorName: String,
    transform: suspend EvalScope.(A) -> B,
): TStateImpl<B> {
    val mappedChanges: TFlowImpl<B> = mapImpl({ changes }) { transform(it) }
    return DerivedMapCheap(
        name = name,
        operatorName = operatorName,
        upstream = this,
        changes = mappedChanges,
        transform = transform,
    )
}
174 
/**
 * State produced by [mapCheap]. Holds no cache: every sample re-applies [transform] to the
 * upstream value and reports the upstream epoch unchanged.
 */
internal class DerivedMapCheap<A, B>(
    override val name: String?,
    override val operatorName: String,
    val upstream: TStateImpl<A>,
    override val changes: TFlowImpl<B>,
    private val transform: suspend EvalScope.(A) -> B,
) : TStateImpl<B> {

    override suspend fun getCurrentWithEpoch(evalScope: EvalScope): Pair<B, Long> {
        val sampled = upstream.getCurrentWithEpoch(evalScope)
        return Pair(evalScope.transform(sampled.first), sampled.second)
    }

    override fun toString(): String = "${this::class.simpleName}@$hashString"
}
190 
/**
 * Maps this state with [transform], caching the mapped value.
 *
 * Unlike [mapCheap], the resulting changes flow only emits values that differ from the current
 * mapped value, and each emitted value is written into the derived state's cache.
 */
internal fun <A, B> TStateImpl<A>.map(
    name: String?,
    operatorName: String,
    transform: suspend EvalScope.(A) -> B,
): TStateImpl<B> {
    lateinit var derived: TStateDerived<B>
    val calmedChanges =
        mapImpl({ changes }) { transform(it) }
            .cached()
            .calm { derived }
    return DerivedMap(
            name = name,
            operatorName = operatorName,
            transform = transform,
            upstream = this,
            changes = calmedChanges,
        )
        .also { derived = it }
}
201 
/**
 * State produced by [map]. Caches `transform(upstream)` and only recomputes it when the upstream
 * epoch is newer than the cached epoch.
 */
internal class DerivedMap<A, B>(
    override val name: String?,
    override val operatorName: String,
    private val transform: suspend EvalScope.(A) -> B,
    val upstream: TStateImpl<A>,
    changes: TFlowImpl<B>,
) : TStateDerived<B>(changes) {

    override suspend fun recalc(evalScope: EvalScope): Pair<B, Long>? {
        val (upstreamValue, upstreamEpoch) = upstream.getCurrentWithEpoch(evalScope)
        // The cache is at least as new as upstream: signal the caller to reuse it.
        if (upstreamEpoch <= invalidatedEpoch) return null
        return Pair(evalScope.transform(upstreamValue), upstreamEpoch)
    }

    override fun toString(): String = "${this::class.simpleName}@$hashString"
}
220 
/**
 * Flattens a state of states into a state that follows whichever inner state is current.
 *
 * @param name debug name for the resulting state.
 * @param operator operator name recorded on the resulting state.
 */
internal fun <A> TStateImpl<TStateImpl<A>>.flatten(name: String?, operator: String): TStateImpl<A> {
    // Whenever the outer state emits a new inner state, emit that inner state's current value.
    val switchEvents = mapImpl({ changes }) { newInner -> newInner.getCurrentWithEpoch(this).first }

    // For each new inner state: a flow that emits the inner state's own changes and, when the
    // inner state is being switched to, falls back to its current value (from switchEvents) if
    // it has no change of its own. When both fire, the inner state's change wins.
    val innerChanges =
        mapImpl({ changes }) { newInner ->
            mergeNodes({ switchEvents }, { newInner.changes }) { _, newValue -> newValue }
        }

    // Follow the current inner state's changes, switching (via switchPromptImpl) to the
    // corresponding innerChanges flow whenever the outer state emits. The single switched slot is
    // keyed by Unit.
    val switchedChanges: TFlowImpl<A> =
        mapImpl({
            switchPromptImpl(
                getStorage = {
                    mapOf(Unit to this@flatten.getCurrentWithEpoch(evalScope = this).first.changes)
                },
                getPatches = { mapImpl({ innerChanges }) { newFlow -> mapOf(Unit to just(newFlow)) } },
            )
        }) { byKey ->
            byKey.getValue(Unit)
        }

    lateinit var flattened: DerivedFlatten<A>
    return DerivedFlatten(name, operator, this, switchedChanges.calm { flattened })
        .also { flattened = it }
}
246 
/**
 * State produced by [flatten]. Samples the current inner state of [upstream].
 */
internal class DerivedFlatten<A>(
    override val name: String?,
    override val operatorName: String,
    val upstream: TStateImpl<TStateImpl<A>>,
    changes: TFlowImpl<A>,
) : TStateDerived<A>(changes) {

    override fun toString(): String = "${this::class.simpleName}@$hashString"

    /**
     * Never returns `null`, so the value is always re-sampled. The reported epoch is the later of
     * the outer and inner epochs.
     */
    override suspend fun recalc(evalScope: EvalScope): Pair<A, Long> {
        val (innerState, outerEpoch) = upstream.getCurrentWithEpoch(evalScope)
        val (value, innerEpoch) = innerState.getCurrentWithEpoch(evalScope)
        return Pair(value, maxOf(outerEpoch, innerEpoch))
    }
}
261 
/** Maps each value of this state to a state using [transform], then flattens the result. */
@Suppress("NOTHING_TO_INLINE")
internal inline fun <A, B> TStateImpl<A>.flatMap(
    name: String?,
    operatorName: String,
    noinline transform: suspend EvalScope.(A) -> TStateImpl<B>,
): TStateImpl<B> {
    // The intermediate state of states is unnamed; only the flattened result receives [name].
    val nested: TStateImpl<TStateImpl<B>> = map(null, operatorName, transform)
    return nested.flatten(name, operatorName)
}
268 
/**
 * Combines two states into one whose value is `transform(l1, l2)`. The inputs are zipped as a
 * map keyed by position and the result is derived with [map].
 */
@Suppress("UNCHECKED_CAST")
internal fun <A, B, Z> zipStates(
    name: String?,
    operatorName: String,
    l1: TStateImpl<A>,
    l2: TStateImpl<B>,
    transform: suspend EvalScope.(A, B) -> Z,
): TStateImpl<Z> {
    val byPosition = zipStates(null, operatorName, mapOf(0 to l1, 1 to l2))
    return byPosition.map(name, operatorName) { values ->
        transform(values.getValue(0) as A, values.getValue(1) as B)
    }
}
281 
/**
 * Combines three states into one whose value is `transform(l1, l2, l3)`. The inputs are zipped
 * as a map keyed by position and the result is derived with [map].
 */
@Suppress("UNCHECKED_CAST")
internal fun <A, B, C, Z> zipStates(
    name: String?,
    operatorName: String,
    l1: TStateImpl<A>,
    l2: TStateImpl<B>,
    l3: TStateImpl<C>,
    transform: suspend EvalScope.(A, B, C) -> Z,
): TStateImpl<Z> {
    val byPosition = zipStates(null, operatorName, mapOf(0 to l1, 1 to l2, 2 to l3))
    return byPosition.map(name, operatorName) { values ->
        transform(
            values.getValue(0) as A,
            values.getValue(1) as B,
            values.getValue(2) as C,
        )
    }
}
296 
/**
 * Combines four states into one whose value is `transform(l1, l2, l3, l4)`. The inputs are
 * zipped as a map keyed by position and the result is derived with [map].
 */
@Suppress("UNCHECKED_CAST")
internal fun <A, B, C, D, Z> zipStates(
    name: String?,
    operatorName: String,
    l1: TStateImpl<A>,
    l2: TStateImpl<B>,
    l3: TStateImpl<C>,
    l4: TStateImpl<D>,
    transform: suspend EvalScope.(A, B, C, D) -> Z,
): TStateImpl<Z> {
    val byPosition = zipStates(null, operatorName, mapOf(0 to l1, 1 to l2, 2 to l3, 3 to l4))
    return byPosition.map(name, operatorName) { values ->
        transform(
            values.getValue(0) as A,
            values.getValue(1) as B,
            values.getValue(2) as C,
            values.getValue(3) as D,
        )
    }
}
316 
/**
 * Combines five states into one whose value is `transform(l1, l2, l3, l4, l5)`. The inputs are
 * zipped as a map keyed by position and the result is derived with [map].
 */
@Suppress("UNCHECKED_CAST")
internal fun <A, B, C, D, E, Z> zipStates(
    name: String?,
    operatorName: String,
    l1: TStateImpl<A>,
    l2: TStateImpl<B>,
    l3: TStateImpl<C>,
    l4: TStateImpl<D>,
    l5: TStateImpl<E>,
    transform: suspend EvalScope.(A, B, C, D, E) -> Z,
): TStateImpl<Z> {
    val byPosition =
        zipStates(null, operatorName, mapOf(0 to l1, 1 to l2, 2 to l3, 3 to l4, 4 to l5))
    return byPosition.map(name, operatorName) { values ->
        transform(
            values.getValue(0) as A,
            values.getValue(1) as B,
            values.getValue(2) as C,
            values.getValue(3) as D,
            values.getValue(4) as E,
        )
    }
}
338 
/**
 * Combines a map of states into a single state whose value maps each key of [states] to that
 * state's current value. An empty [states] yields a constant empty-map state.
 */
internal fun <K : Any, A> zipStates(
    name: String?,
    operatorName: String,
    states: Map<K, TStateImpl<A>>,
): TStateImpl<Map<K, A>> {
    if (states.isEmpty()) return constS(name, operatorName, emptyMap())
    val upstreamChanges: Map<K, TFlowImpl<A>> =
        states.mapValues { (_, upstream) -> upstream.changes }
    lateinit var zipped: DerivedZipped<K, A>
    // No calm needed: by invariant, each upstream changes flow only emits when there is a
    // difference.
    val changes: TFlowImpl<Map<K, A>> =
        mapImpl({
            switchDeferredImpl(getStorage = { upstreamChanges }, getPatches = { neverImpl })
        }) { patch ->
            // Keys present in the patch changed this transaction; the rest keep their current value.
            val combined =
                states.mapValues { (key, upstream) ->
                    when (key) {
                        in patch -> patch.getValue(key)
                        else -> upstream.getCurrentWithEpoch(evalScope = this).first
                    }
                }
            zipped.setCache(combined, epoch)
            combined
        }
    return DerivedZipped(name, operatorName, states, changes).also { zipped = it }
}
365 
/**
 * State produced by the map-based zipStates. Re-samples every upstream state (via
 * `mapValuesParallel`) and reports the latest upstream epoch.
 */
internal class DerivedZipped<K : Any, A>(
    override val name: String?,
    override val operatorName: String,
    val upstream: Map<K, TStateImpl<A>>,
    changes: TFlowImpl<Map<K, A>>,
) : TStateDerived<Map<K, A>>(changes) {

    override suspend fun recalc(evalScope: EvalScope): Pair<Map<K, A>, Long> {
        // Atomic because the per-entry samples are taken through mapValuesParallel.
        val latestEpoch = AtomicLong()
        val values =
            upstream.mapValuesParallel { entry ->
                val (value, epoch) = entry.value.getCurrentWithEpoch(evalScope)
                latestEpoch.accumulateAndGet(epoch) { acc, next -> maxOf(acc, next) }
                value
            }
        return Pair(values, latestEpoch.get())
    }

    override fun toString(): String = "${this::class.simpleName}@$hashString"
}
383 
/**
 * Combines a list of states into a single state of lists. An empty [states] yields a constant
 * empty-list state.
 */
@Suppress("NOTHING_TO_INLINE")
internal inline fun <A> zipStates(
    name: String?,
    operatorName: String,
    states: List<TStateImpl<A>>,
): TStateImpl<List<A>> {
    if (states.isEmpty()) return constS(name, operatorName, emptyList())
    // Zip as a map keyed by position, then read the values back out in the map's iteration order.
    // NOTE(review): list order relies on that iteration order matching index order — confirm.
    val byIndex = states.asIterable().associateByIndex()
    return zipStates(null, operatorName, byIndex).mapCheap(name, operatorName) { zipped ->
        zipped.values.toList()
    }
}
400