/*
 * Copyright (C) 2024 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.testutils

import android.net.ConnectivityManager
import android.net.ConnectivityManager.NetworkCallback
import android.net.Network
import android.net.NetworkCapabilities
import android.net.NetworkRequest
import android.os.Handler
import androidx.test.platform.app.InstrumentationRegistry
import com.android.testutils.RecorderCallback.CallbackEntry
import java.util.Collections
import kotlin.test.fail
import org.junit.rules.TestRule
import org.junit.runner.Description
import org.junit.runners.model.Statement

/**
 * A rule to file [NetworkCallback]s to request or watch networks.
 *
 * The callbacks filed in test methods are automatically unregistered when the method completes.
 */
class AutoReleaseNetworkCallbackRule : NetworkCallbackHelper(), TestRule {
    override fun apply(base: Statement, description: Description): Statement =
        AutoReleaseStatement(base)

    /**
     * Evaluates the wrapped test [base], with [unregisterAll] registered as its cleanup step
     * through the tryTest/cleanup helper.
     */
    private inner class AutoReleaseStatement(private val base: Statement) : Statement() {
        override fun evaluate() {
            tryTest {
                base.evaluate()
            } cleanup {
                unregisterAll()
            }
        }
    }
}

/**
 * Helps file [NetworkCallback]s to request or watch networks, keeping track of them for cleanup.
 *
 * Every callback filed through this class stays tracked until it is released with
 * [unregisterNetworkCallback], [unrequestCell] or [unregisterAll].
 */
open class NetworkCallbackHelper {
    private val cm by lazy {
        InstrumentationRegistry.getInstrumentation().context
            .getSystemService(ConnectivityManager::class.java)
            ?: fail("ConnectivityManager not found")
    }

    // Callbacks registered through this helper that have not been unregistered yet. Individual
    // operations on a synchronizedSet are locked, but iterating it requires holding its monitor
    // explicitly (done in unregisterAll).
    private val cbToCleanup = Collections.synchronizedSet(mutableSetOf<NetworkCallback>())

    // Callback of the outstanding request filed by requestCell, or null if there is none.
    private var cellRequestCb: TestableNetworkCallback? = null

    /**
     * Convenience method to request a cell network, similarly to [requestNetwork].
     *
     * This helper keeps track of a single cell network request, which can be unrequested manually
     * using [unrequestCell].
     *
     * Fails the test (throws) if a cell network was already requested, or if the expected
     * [CallbackEntry.Available] callback is not received on the new request.
     *
     * @return the [Network] carried by the [CallbackEntry.Available] callback.
     */
    fun requestCell(): Network {
        if (cellRequestCb != null) {
            fail("Cell network was already requested")
        }
        val cb = requestNetwork(
            NetworkRequest.Builder()
                .addTransportType(NetworkCapabilities.TRANSPORT_CELLULAR)
                .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
                .build()
        )
        // Record the request before waiting on it, so it is still released by unrequestCell or
        // unregisterAll if the expectation below fails.
        cellRequestCb = cb
        return cb.expect<CallbackEntry.Available>(
            errorMsg = "Cell network not available. " +
                "Please ensure the device has working mobile data."
        ).network
    }

    /**
     * Unrequest a cell network requested through [requestCell].
     *
     * Fails the test (throws) if there is no outstanding cell request.
     */
    fun unrequestCell() {
        val cb = cellRequestCb ?: fail("Cell network was not requested")
        // unregisterNetworkCallback also resets cellRequestCb.
        unregisterNetworkCallback(cb)
    }

    /**
     * Registers [cb] using [registrar] and, only if registration did not throw, tracks it for
     * cleanup.
     *
     * @return [cb], so registration methods can return the callback directly.
     */
    private fun <T : NetworkCallback> addCallback(
        cb: T,
        registrar: (NetworkCallback) -> Unit
    ): T {
        registrar(cb)
        cbToCleanup.add(cb)
        return cb
    }

    /**
     * File a request for a network with [ConnectivityManager.requestNetwork].
     *
     * When [handler] is non-null, it is passed to the [Handler] overload of
     * [ConnectivityManager.requestNetwork].
     *
     * Tests may call [unregisterNetworkCallback] with the returned callback once they are done
     * using the network. Otherwise it is released by [unregisterAll], which
     * [AutoReleaseNetworkCallbackRule] calls after each test.
     *
     * @return [cb], now registered and tracked for cleanup.
     */
    @JvmOverloads
    fun requestNetwork(
        request: NetworkRequest,
        cb: TestableNetworkCallback = TestableNetworkCallback(),
        handler: Handler? = null
    ): TestableNetworkCallback = addCallback(cb) { callback ->
        if (handler == null) {
            cm.requestNetwork(request, callback)
        } else {
            cm.requestNetwork(request, callback, handler)
        }
    }

    /**
     * Overload of [requestNetwork] that passes [timeoutMs] to [ConnectivityManager.requestNetwork].
     *
     * [cb] has a default value but precedes [timeoutMs]. Kotlin callers relying on that default
     * must therefore pass `timeoutMs` as a named argument.
     *
     * @return [cb], now registered and tracked for cleanup.
     */
    @JvmOverloads
    fun requestNetwork(
        request: NetworkRequest,
        cb: TestableNetworkCallback = TestableNetworkCallback(),
        timeoutMs: Int,
    ): TestableNetworkCallback = addCallback(cb) { cm.requestNetwork(request, it, timeoutMs) }

    /**
     * File a new [TestableNetworkCallback] for a NetworkRequest with
     * [ConnectivityManager.registerNetworkCallback].
     *
     * Tests may call [unregisterNetworkCallback] with the returned callback once they are done
     * with it. Otherwise it is released by [unregisterAll].
     */
    fun registerNetworkCallback(
        request: NetworkRequest
    ): TestableNetworkCallback = registerNetworkCallback(request, TestableNetworkCallback())

    /**
     * File [cb] for a NetworkRequest with [ConnectivityManager.registerNetworkCallback].
     *
     * Tests may call [unregisterNetworkCallback] with the returned callback once they are done
     * with it. Otherwise it is released by [unregisterAll].
     *
     * @return [cb], now registered and tracked for cleanup.
     */
    fun <T : NetworkCallback> registerNetworkCallback(
        request: NetworkRequest,
        cb: T
    ): T = addCallback(cb) { cm.registerNetworkCallback(request, it) }

    /**
     * @see ConnectivityManager.registerDefaultNetworkCallback
     */
    @JvmOverloads
    fun registerDefaultNetworkCallback(
        cb: TestableNetworkCallback = TestableNetworkCallback(),
        handler: Handler? = null
    ): TestableNetworkCallback = addCallback(cb) { callback ->
        if (handler == null) {
            cm.registerDefaultNetworkCallback(callback)
        } else {
            cm.registerDefaultNetworkCallback(callback, handler)
        }
    }

    /**
     * @see ConnectivityManager.registerSystemDefaultNetworkCallback
     */
    @JvmOverloads
    fun registerSystemDefaultNetworkCallback(
        cb: TestableNetworkCallback = TestableNetworkCallback(),
        handler: Handler
    ): TestableNetworkCallback =
        addCallback(cb) { cm.registerSystemDefaultNetworkCallback(it, handler) }

    /**
     * @see ConnectivityManager.registerDefaultNetworkCallbackForUid
     */
    @JvmOverloads
    fun registerDefaultNetworkCallbackForUid(
        uid: Int,
        cb: TestableNetworkCallback = TestableNetworkCallback(),
        handler: Handler
    ): TestableNetworkCallback =
        addCallback(cb) { cm.registerDefaultNetworkCallbackForUid(uid, it, handler) }

    /**
     * @see ConnectivityManager.registerBestMatchingNetworkCallback
     */
    @JvmOverloads
    fun registerBestMatchingNetworkCallback(
        request: NetworkRequest,
        cb: TestableNetworkCallback = TestableNetworkCallback(),
        handler: Handler
    ): TestableNetworkCallback =
        addCallback(cb) { cm.registerBestMatchingNetworkCallback(request, it, handler) }

    /**
     * @see ConnectivityManager.requestBackgroundNetwork
     */
    @JvmOverloads
    fun requestBackgroundNetwork(
        request: NetworkRequest,
        cb: TestableNetworkCallback = TestableNetworkCallback(),
        handler: Handler
    ): TestableNetworkCallback =
        addCallback(cb) { cm.requestBackgroundNetwork(request, it, handler) }

    /**
     * Unregister a callback filed using registration methods in this class.
     *
     * If [cb] is the callback of the request filed by [requestCell], the cell request is
     * considered released as well, so [requestCell] may be called again.
     */
    fun unregisterNetworkCallback(cb: NetworkCallback) {
        cm.unregisterNetworkCallback(cb)
        cbToCleanup.remove(cb)
        if (cb === cellRequestCb) {
            cellRequestCb = null
        }
    }

    /**
     * Unregister all callbacks that were filed using registration methods in this class.
     *
     * Tracking state (including the cell request) is reset before any callback is unregistered.
     * An attempt is made to unregister every tracked callback, even if some unregistrations throw.
     *
     * @throws Exception the first exception thrown by [ConnectivityManager.unregisterNetworkCallback],
     *   with any later ones attached as suppressed exceptions.
     */
    fun unregisterAll() {
        // Iterating a synchronizedSet must hold its monitor. Snapshot and clear under the lock,
        // then call into ConnectivityManager outside of it.
        val toUnregister = synchronized(cbToCleanup) {
            cbToCleanup.toList().also { cbToCleanup.clear() }
        }
        cellRequestCb = null

        var firstFailure: Exception? = null
        for (cb in toUnregister) {
            try {
                cm.unregisterNetworkCallback(cb)
            } catch (e: Exception) {
                // Keep going so one bad callback does not leak all the others.
                val previous = firstFailure
                if (previous == null) firstFailure = e else previous.addSuppressed(e)
            }
        }
        firstFailure?.let { throw it }
    }
}