/*
 * Copyright (C) 2024 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17 package com.android.systemui.kairos.internal
18
19 import com.android.systemui.kairos.internal.util.Key
20 import com.android.systemui.kairos.internal.util.associateByIndex
21 import com.android.systemui.kairos.internal.util.hashString
22 import com.android.systemui.kairos.internal.util.mapValuesParallel
23 import com.android.systemui.kairos.util.Just
24 import com.android.systemui.kairos.util.Maybe
25 import com.android.systemui.kairos.util.just
26 import com.android.systemui.kairos.util.none
27 import java.util.concurrent.atomic.AtomicLong
28 import kotlinx.coroutines.CompletableDeferred
29 import kotlinx.coroutines.CoroutineStart
30 import kotlinx.coroutines.Deferred
31 import kotlinx.coroutines.ExperimentalCoroutinesApi
32
/**
 * Internal representation of a time-varying value ("state").
 *
 * Updates are exposed through [changes]; the current value can be sampled together with the epoch
 * at which it was last written via [getCurrentWithEpoch].
 */
internal sealed interface TStateImpl<out A> {
    /** Optional name of this state; may be `null`. */
    val name: String?

    /** Name of the operator that produced this state. */
    val operatorName: String

    /** Flow of updates to this state's value. */
    val changes: TFlowImpl<A>

    /**
     * Samples this state within [evalScope], returning its current value paired with the epoch
     * associated with that value.
     */
    suspend fun getCurrentWithEpoch(evalScope: EvalScope): Pair<A, Long>
}
40
/**
 * Base class for states whose value is derived from other states.
 *
 * The most recently computed value is held in [cache], tagged with [invalidatedEpoch]. Reads go
 * through [getCurrentWithEpoch], which memoizes a lazily-started [pull] in the eval scope's
 * transaction store (keyed by this object), so repeated reads through the same store share one
 * computation.
 */
internal sealed class TStateDerived<A>(override val changes: TFlowImpl<A>) :
    TStateImpl<A>, Key<Deferred<Pair<A, Long>>> {

    /** Epoch associated with the value in [cache]; [Long.MIN_VALUE] until a value is cached. */
    @Volatile
    var invalidatedEpoch = Long.MIN_VALUE
        private set

    /** Most recently cached value, or the private `EmptyCache` sentinel if none yet. */
    @Volatile
    protected var cache: Any? = EmptyCache
        private set

    override suspend fun getCurrentWithEpoch(evalScope: EvalScope): Pair<A, Long> =
        evalScope.transactionStore
            // The deferred is created with LAZY start, so pull() runs when it is first awaited;
            // later lookups of this key in the same store reuse the same deferred.
            .getOrPut(this) { evalScope.deferAsync(CoroutineStart.LAZY) { pull(evalScope) } }
            .await()

    /**
     * Recomputes the value via [recalc] and offers it to the cache. If [recalc] returns `null`
     * (the cached value is still current), the cached value and its epoch are returned instead.
     */
    suspend fun pull(evalScope: EvalScope): Pair<A, Long> {
        @Suppress("UNCHECKED_CAST")
        return recalc(evalScope)?.also { (a, epoch) -> setCache(a, epoch) }
            ?: ((cache as A) to invalidatedEpoch)
    }

    /** Caches [value] only if [epoch] is strictly newer than the currently cached epoch. */
    fun setCache(value: A, epoch: Long) {
        if (epoch > invalidatedEpoch) {
            cache = value
            invalidatedEpoch = epoch
        }
    }

    /**
     * Returns the cached value without triggering a recomputation (so it may be stale), or `none`
     * if nothing has been cached yet.
     */
    fun getCachedUnsafe(): Maybe<A> {
        // Read the volatile field once so the emptiness check and the returned value agree, and
        // compare by identity so a value with a permissive equals() is never mistaken for the
        // sentinel.
        val snapshot = cache
        @Suppress("UNCHECKED_CAST")
        return if (snapshot === EmptyCache) none else just(snapshot as A)
    }

    /**
     * Computes the current value and the epoch it corresponds to, or returns `null` when the
     * cached value can be reused as-is.
     */
    protected abstract suspend fun recalc(evalScope: EvalScope): Pair<A, Long>?

    /** Sentinel stored in [cache] before the first value is cached. */
    private data object EmptyCache
}
79
/**
 * A state whose storage is written directly by the network via [updateState].
 *
 * The value is held as a [Deferred], so the initial value may be supplied before it has completed.
 * Every later write replaces it with an already-completed deferred holding the new value.
 */
internal class TStateSource<A>(
    override val name: String?,
    override val operatorName: String,
    init: Deferred<A>,
    override val changes: TFlowImpl<A>,
) : TStateImpl<A> {

    /** Convenience constructor for an initial value that is already available. */
    constructor(
        name: String?,
        operatorName: String,
        init: A,
        changes: TFlowImpl<A>,
    ) : this(name, operatorName, CompletableDeferred(init), changes)

    /** Connection whose push event is committed by [updateState]; must be set before that call. */
    lateinit var upstreamConnection: NodeConnection<A>

    // Synchronization is unnecessary: reads and writes are never interleaved, because every write
    // happens at the end of a network step, after all reads for that step have taken place.

    @Volatile private var _current: Deferred<A> = init

    /** Epoch of the most recent write; 0 before any write. */
    @Volatile
    var writeEpoch = 0L
        private set

    override suspend fun getCurrentWithEpoch(evalScope: EvalScope): Pair<A, Long> {
        val value = _current.await()
        return Pair(value, writeEpoch)
    }

    /** Called by the network once the eval phase has completed; commits any pending event. */
    suspend fun updateState(evalScope: EvalScope) {
        val pending = upstreamConnection.getPushEvent(evalScope)
        if (pending !is Just) return
        _current = CompletableDeferred(pending.value)
        writeEpoch = evalScope.epoch
    }

    override fun toString(): String = "TStateImpl(changes=$changes, current=$_current)"

    /** Returns the stored value if it is already available, without suspending; `none` otherwise. */
    @OptIn(ExperimentalCoroutinesApi::class)
    fun getStorageUnsafe(): Maybe<A> =
        when {
            _current.isCompleted -> just(_current.getCompleted())
            else -> none
        }
}
122
/** Creates a state that always holds [init]; its change flow never emits. */
internal fun <A> constS(name: String?, operatorName: String, init: A): TStateImpl<A> {
    return TStateSource(name = name, operatorName = operatorName, init = init, changes = neverImpl)
}
125
/**
 * Creates a [TStateSource] starting at [init] and driven by the flow produced by [getChanges].
 *
 * Upstream emissions equal to the state's current value are filtered out. The connection to the
 * filtered flow is established in a one-shot output scheduled on [evalScope]; if activation
 * reports that evaluation is needed, the new state is scheduled.
 */
internal inline fun <A> mkState(
    name: String?,
    operatorName: String,
    evalScope: EvalScope,
    crossinline getChanges: suspend EvalScope.() -> TFlowImpl<A>,
    init: Deferred<A>,
): TStateImpl<A> {
    // Assigned immediately after the source is constructed; the filter below reads it lazily.
    lateinit var source: TStateSource<A>
    val distinctChanges: TFlowImpl<A> =
        filterNode(getChanges) { candidate ->
                candidate != source.getCurrentWithEpoch(evalScope = this).first
            }
            .cached()
    source = TStateSource(name, operatorName, init, distinctChanges)
    evalScope.scheduleOutput(
        OneShot {
            val activation =
                distinctChanges.activate(evalScope = this, downstream = Schedulable.S(source))
            if (activation != null) {
                val (connection, needsEval) = activation
                source.upstreamConnection = connection
                if (needsEval) {
                    schedule(source)
                }
            }
        }
    )
    return source
}
152
/**
 * Keeps only the emissions of this flow that differ from the current value of the state returned
 * by [getState]. Each accepted value is written into that state's cache at the current epoch.
 */
private inline fun <A> TFlowImpl<A>.calm(
    crossinline getState: () -> TStateDerived<A>
): TFlowImpl<A> =
    filterNode({ this@calm }) { candidate ->
            val target = getState()
            val isDistinct = candidate != target.getCurrentWithEpoch(evalScope = this).first
            if (isDistinct) {
                // Record the accepted value, tagged with the epoch of the current evaluation.
                target.setCache(candidate, epoch)
            }
            isDistinct
        }
        .cached()
167
/**
 * Maps this state through [transform] without caching or de-duplication. The transform is applied
 * to every upstream change and again on every sample of the result (presumably intended, per the
 * name, for transforms that are cheap to evaluate).
 */
internal fun <A, B> TStateImpl<A>.mapCheap(
    name: String?,
    operatorName: String,
    transform: suspend EvalScope.(A) -> B,
): TStateImpl<B> {
    val mappedChanges: TFlowImpl<B> = mapImpl({ changes }) { transform(it) }
    return DerivedMapCheap(
        name = name,
        operatorName = operatorName,
        upstream = this,
        changes = mappedChanges,
        transform = transform,
    )
}
174
/**
 * An uncached derived state: every sample re-applies [transform] to [upstream]'s current value and
 * reports the upstream's epoch unchanged.
 */
internal class DerivedMapCheap<A, B>(
    override val name: String?,
    override val operatorName: String,
    val upstream: TStateImpl<A>,
    override val changes: TFlowImpl<B>,
    private val transform: suspend EvalScope.(A) -> B,
) : TStateImpl<B> {

    override suspend fun getCurrentWithEpoch(evalScope: EvalScope): Pair<B, Long> {
        val upstreamSample = upstream.getCurrentWithEpoch(evalScope)
        val mapped = evalScope.transform(upstreamSample.first)
        return Pair(mapped, upstreamSample.second)
    }

    override fun toString(): String = "${this::class.simpleName}@$hashString"
}
190
/**
 * Maps this state through [transform] into a cached [DerivedMap]. Mapped upstream changes are
 * filtered so that only values differing from the current mapped value are emitted.
 */
internal fun <A, B> TStateImpl<A>.map(
    name: String?,
    operatorName: String,
    transform: suspend EvalScope.(A) -> B,
): TStateImpl<B> {
    // The change filter needs the derived state, which in turn needs the filtered changes.
    lateinit var derived: TStateDerived<B>
    val upstreamMapped = mapImpl({ changes }) { transform(it) }.cached()
    val distinctChanges = upstreamMapped.calm { derived }
    derived =
        DerivedMap(
            name = name,
            operatorName = operatorName,
            transform = transform,
            upstream = this,
            changes = distinctChanges,
        )
    return derived
}
201
/**
 * A cached derived state holding [transform] applied to [upstream]'s value. The value is recomputed
 * only when the upstream epoch is newer than the epoch of the cached value.
 */
internal class DerivedMap<A, B>(
    override val name: String?,
    override val operatorName: String,
    private val transform: suspend EvalScope.(A) -> B,
    val upstream: TStateImpl<A>,
    changes: TFlowImpl<B>,
) : TStateDerived<B>(changes) {

    override suspend fun recalc(evalScope: EvalScope): Pair<B, Long>? {
        val (upstreamValue, upstreamEpoch) = upstream.getCurrentWithEpoch(evalScope)
        // Upstream has not advanced past the cached value: signal that the cache can be reused.
        if (upstreamEpoch <= invalidatedEpoch) return null
        return evalScope.transform(upstreamValue) to upstreamEpoch
    }

    override fun toString(): String = "${this::class.simpleName}@$hashString"
}
220
/**
 * Flattens a state of states into a state that follows whichever inner state this outer state
 * currently holds.
 */
internal fun <A> TStateImpl<TStateImpl<A>>.flatten(name: String?, operator: String): TStateImpl<A> {
    // When the outer state switches to a new inner state, emits that inner state's current value.
    val currentOfNewInner =
        mapImpl({ changes }) { inner -> inner.getCurrentWithEpoch(this).first }
    // When the outer state switches, emits a flow for the new inner state. That flow yields the
    // inner state's own change if it is also changing in this transaction; otherwise it falls
    // back to the inner state's current value.
    val newInnerFlows =
        mapImpl({ changes }) { inner ->
            mergeNodes({ currentOfNewInner }, { inner.changes }) { _, changed -> changed }
        }
    // Starts out following the changes of the currently held inner state, and swaps in the flow
    // from newInnerFlows each time the outer state switches.
    val followedChanges: TFlowImpl<A> =
        mapImpl({
            switchPromptImpl(
                getStorage = {
                    mapOf(Unit to this@flatten.getCurrentWithEpoch(evalScope = this).first.changes)
                },
                getPatches = { mapImpl({ newInnerFlows }) { flow -> mapOf(Unit to just(flow)) } },
            )
        }) { switched ->
            switched.getValue(Unit)
        }
    lateinit var derived: DerivedFlatten<A>
    derived = DerivedFlatten(name, operator, this, followedChanges.calm { derived })
    return derived
}
246
/**
 * A derived state holding the current value of the inner state currently held by [upstream]. Its
 * epoch is the newer of the outer and inner epochs.
 */
internal class DerivedFlatten<A>(
    override val name: String?,
    override val operatorName: String,
    val upstream: TStateImpl<TStateImpl<A>>,
    changes: TFlowImpl<A>,
) : TStateDerived<A>(changes) {

    override suspend fun recalc(evalScope: EvalScope): Pair<A, Long> {
        val outerSample = upstream.getCurrentWithEpoch(evalScope)
        val innerSample = outerSample.first.getCurrentWithEpoch(evalScope)
        return Pair(innerSample.first, maxOf(outerSample.second, innerSample.second))
    }

    override fun toString(): String = "${this::class.simpleName}@$hashString"
}
261
/** Maps each value of this state to an inner state via [transform], then flattens the result. */
@Suppress("NOTHING_TO_INLINE")
internal inline fun <A, B> TStateImpl<A>.flatMap(
    name: String?,
    operatorName: String,
    noinline transform: suspend EvalScope.(A) -> TStateImpl<B>,
): TStateImpl<B> {
    // The intermediate state of states is unnamed; only the flattened result carries `name`.
    val nested: TStateImpl<TStateImpl<B>> =
        map(name = null, operatorName = operatorName, transform = transform)
    return nested.flatten(name, operatorName)
}
268
/**
 * Zips two states into one whose value is [transform] applied to their current values. The inputs
 * are combined through the map-based overload, keyed by position.
 */
@Suppress("UNCHECKED_CAST")
internal fun <A, B, Z> zipStates(
    name: String?,
    operatorName: String,
    l1: TStateImpl<A>,
    l2: TStateImpl<B>,
    transform: suspend EvalScope.(A, B) -> Z,
): TStateImpl<Z> {
    val zipped = zipStates(null, operatorName, mapOf(0 to l1, 1 to l2))
    return zipped.map(name, operatorName) { values ->
        transform(values.getValue(0) as A, values.getValue(1) as B)
    }
}
281
/**
 * Zips three states into one whose value is [transform] applied to their current values. The
 * inputs are combined through the map-based overload, keyed by position.
 */
@Suppress("UNCHECKED_CAST")
internal fun <A, B, C, Z> zipStates(
    name: String?,
    operatorName: String,
    l1: TStateImpl<A>,
    l2: TStateImpl<B>,
    l3: TStateImpl<C>,
    transform: suspend EvalScope.(A, B, C) -> Z,
): TStateImpl<Z> {
    val zipped = zipStates(null, operatorName, mapOf(0 to l1, 1 to l2, 2 to l3))
    return zipped.map(name, operatorName) { values ->
        transform(
            values.getValue(0) as A,
            values.getValue(1) as B,
            values.getValue(2) as C,
        )
    }
}
296
/**
 * Zips four states into one whose value is [transform] applied to their current values. The
 * inputs are combined through the map-based overload, keyed by position.
 */
@Suppress("UNCHECKED_CAST")
internal fun <A, B, C, D, Z> zipStates(
    name: String?,
    operatorName: String,
    l1: TStateImpl<A>,
    l2: TStateImpl<B>,
    l3: TStateImpl<C>,
    l4: TStateImpl<D>,
    transform: suspend EvalScope.(A, B, C, D) -> Z,
): TStateImpl<Z> {
    val zipped = zipStates(null, operatorName, mapOf(0 to l1, 1 to l2, 2 to l3, 3 to l4))
    return zipped.map(name, operatorName) { values ->
        transform(
            values.getValue(0) as A,
            values.getValue(1) as B,
            values.getValue(2) as C,
            values.getValue(3) as D,
        )
    }
}
316
/**
 * Zips five states into one whose value is [transform] applied to their current values. The
 * inputs are combined through the map-based overload, keyed by position.
 */
@Suppress("UNCHECKED_CAST")
internal fun <A, B, C, D, E, Z> zipStates(
    name: String?,
    operatorName: String,
    l1: TStateImpl<A>,
    l2: TStateImpl<B>,
    l3: TStateImpl<C>,
    l4: TStateImpl<D>,
    l5: TStateImpl<E>,
    transform: suspend EvalScope.(A, B, C, D, E) -> Z,
): TStateImpl<Z> {
    val zipped =
        zipStates(null, operatorName, mapOf(0 to l1, 1 to l2, 2 to l3, 3 to l4, 4 to l5))
    return zipped.map(name, operatorName) { values ->
        transform(
            values.getValue(0) as A,
            values.getValue(1) as B,
            values.getValue(2) as C,
            values.getValue(3) as D,
            values.getValue(4) as E,
        )
    }
}
338
/**
 * Zips a map of states into a single state holding a map of their current values under the same
 * keys. An empty input yields a constant empty map.
 */
internal fun <K : Any, A> zipStates(
    name: String?,
    operatorName: String,
    states: Map<K, TStateImpl<A>>,
): TStateImpl<Map<K, A>> {
    if (states.isEmpty()) {
        return constS(name, operatorName, emptyMap())
    }
    val changesByKey: Map<K, TFlowImpl<A>> = states.mapValues { (_, upstream) -> upstream.changes }
    lateinit var zipped: DerivedZipped<K, A>
    // Calm is not needed: by invariant, each upstream `changes` only emits when that state's
    // value actually differs, so the zipped map differs whenever this flow emits.
    val zippedChanges: TFlowImpl<Map<K, A>> =
        mapImpl({
            switchDeferredImpl(getStorage = { changesByKey }, getPatches = { neverImpl })
        }) { updated ->
            val snapshot =
                states.mapValues { (key, upstream) ->
                    when {
                        key in updated -> updated.getValue(key)
                        else -> upstream.getCurrentWithEpoch(evalScope = this).first
                    }
                }
            zipped.setCache(snapshot, epoch)
            snapshot
        }
    zipped = DerivedZipped(name, operatorName, states, zippedChanges)
    return zipped
}
365
/**
 * A derived state holding the current values of every state in [upstream], keyed the same way. Its
 * epoch is the maximum of the upstream epochs (starting from 0).
 */
internal class DerivedZipped<K : Any, A>(
    override val name: String?,
    override val operatorName: String,
    val upstream: Map<K, TStateImpl<A>>,
    changes: TFlowImpl<Map<K, A>>,
) : TStateDerived<Map<K, A>>(changes) {

    override suspend fun recalc(evalScope: EvalScope): Pair<Map<K, A>, Long> {
        // mapValuesParallel may evaluate entries concurrently, so the running max is atomic.
        val latestEpoch = AtomicLong()
        val values =
            upstream.mapValuesParallel { entry ->
                val (value, epoch) = entry.value.getCurrentWithEpoch(evalScope)
                latestEpoch.accumulateAndGet(epoch) { x, y -> maxOf(x, y) }
                value
            }
        return values to latestEpoch.get()
    }

    override fun toString(): String = "${this::class.simpleName}@$hashString"
}
383
/**
 * Zips a list of states into a single state holding a list of their current values. An empty input
 * yields a constant empty list.
 */
@Suppress("NOTHING_TO_INLINE")
internal inline fun <A> zipStates(
    name: String?,
    operatorName: String,
    states: List<TStateImpl<A>>,
): TStateImpl<List<A>> {
    if (states.isEmpty()) return constS(name, operatorName, emptyList())
    val byIndex = states.asIterable().associateByIndex()
    // NOTE(review): the resulting list order relies on the zipped map iterating in index order —
    // confirm associateByIndex / mapValuesParallel preserve insertion order.
    return zipStates(null, operatorName, byIndex).mapCheap(name, operatorName) { zipped ->
        zipped.values.toList()
    }
}
400