1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.testutils
18 
19 import android.net.NetworkStats
20 import com.android.net.module.util.ArrayTrackRecord
21 import kotlin.test.assertEquals
22 import kotlin.test.assertTrue
23 import kotlin.test.fail
24 
// How long the expect* helpers of TestableNetworkStatsProviderCbBinder wait in history.poll()
// for the next callback before treating it as missing. The _MS suffix denotes milliseconds.
private const val DEFAULT_TIMEOUT_MS = 3000L
26 
/**
 * A test double for [NetworkStatsProviderCbStubCompat].
 *
 * Every callback invoked on this binder is appended, as a [CallbackType], to an
 * [ArrayTrackRecord]. The `expect*` helpers poll that record, waiting up to
 * [DEFAULT_TIMEOUT_MS] for the next callback, and fail the test when the callback is missing
 * or is not the expected one.
 */
open class TestableNetworkStatsProviderCbBinder : NetworkStatsProviderCbStubCompat() {
    /** One entry per callback received by this binder. */
    sealed class CallbackType {
        data class NotifyStatsUpdated(
            val token: Int,
            val ifaceStats: NetworkStats,
            val uidStats: NetworkStats
        ) : CallbackType()
        object NotifyWarningReached : CallbackType()
        object NotifyLimitReached : CallbackType()
        object NotifyWarningOrLimitReached : CallbackType()
        object NotifyAlertReached : CallbackType()
        object Unregister : CallbackType()
    }

    // Callbacks are added through this read head, and the expect*/assertNoCallback helpers
    // poll it.
    private val history = ArrayTrackRecord<CallbackType>().ReadHead()

    override fun notifyStatsUpdated(token: Int, ifaceStats: NetworkStats, uidStats: NetworkStats) {
        history.add(CallbackType.NotifyStatsUpdated(token, ifaceStats, uidStats))
    }

    override fun notifyWarningReached() {
        history.add(CallbackType.NotifyWarningReached)
    }

    override fun notifyLimitReached() {
        history.add(CallbackType.NotifyLimitReached)
    }

    override fun notifyWarningOrLimitReached() {
        // Older callback is split into notifyLimitReached and notifyWarningReached in T.
        history.add(CallbackType.NotifyWarningOrLimitReached)
    }

    override fun notifyAlertReached() {
        history.add(CallbackType.NotifyAlertReached)
    }

    override fun unregister() {
        history.add(CallbackType.Unregister)
    }

    /**
     * Polls the next recorded callback, waiting up to [DEFAULT_TIMEOUT_MS].
     *
     * Fails the test with a message naming [expected] when no callback arrives in time,
     * instead of returning null (which previously surfaced as a bare NullPointerException
     * or an uninformative assertion failure).
     */
    private fun pollNextCallback(expected: String): CallbackType =
            history.poll(DEFAULT_TIMEOUT_MS)
                    ?: fail("Expected $expected callback within ${DEFAULT_TIMEOUT_MS}ms, but got none")

    /** Asserts that the next callback is [CallbackType.NotifyStatsUpdated], ignoring its contents. */
    fun expectNotifyStatsUpdated() {
        val event = pollNextCallback("NotifyStatsUpdated")
        assertTrue(event is CallbackType.NotifyStatsUpdated,
                "Expected NotifyStatsUpdated callback, but got ${event::class}")
    }

    /**
     * Asserts that the next callback is [CallbackType.NotifyStatsUpdated] and that its stats
     * match [ifaceStats] and [uidStats] according to assertNetworkStatsEquals.
     */
    fun expectNotifyStatsUpdated(ifaceStats: NetworkStats, uidStats: NetworkStats) {
        val event = pollNextCallback("NotifyStatsUpdated")
        if (event !is CallbackType.NotifyStatsUpdated) {
            // fail() returns Nothing, so event is smart-cast to NotifyStatsUpdated below.
            fail("Expected NotifyStatsUpdated callback, but got ${event::class}")
        }
        // TODO: verify token.
        assertNetworkStatsEquals(ifaceStats, event.ifaceStats)
        assertNetworkStatsEquals(uidStats, event.uidStats)
    }

    /** Asserts that the next callback, within [DEFAULT_TIMEOUT_MS], is NotifyWarningReached. */
    fun expectNotifyWarningReached() =
            assertEquals(CallbackType.NotifyWarningReached, history.poll(DEFAULT_TIMEOUT_MS))

    /** Asserts that the next callback, within [DEFAULT_TIMEOUT_MS], is NotifyLimitReached. */
    fun expectNotifyLimitReached() =
            assertEquals(CallbackType.NotifyLimitReached, history.poll(DEFAULT_TIMEOUT_MS))

    /** Asserts that the next callback, within [DEFAULT_TIMEOUT_MS], is NotifyWarningOrLimitReached. */
    fun expectNotifyWarningOrLimitReached() =
            assertEquals(CallbackType.NotifyWarningOrLimitReached, history.poll(DEFAULT_TIMEOUT_MS))

    /** Asserts that the next callback, within [DEFAULT_TIMEOUT_MS], is NotifyAlertReached. */
    fun expectNotifyAlertReached() =
            assertEquals(CallbackType.NotifyAlertReached, history.poll(DEFAULT_TIMEOUT_MS))

    /** Asserts that no callback is currently pending; polls with a zero timeout. */
    fun assertNoCallback() {
        val cb = history.poll(0)
        cb?.let { fail("Expected no callback but got $cb") }
    }
}
101