1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.networkstack.tethering;
18 
19 import static android.net.NetworkCapabilities.NET_CAPABILITY_DUN;
20 import static android.net.NetworkCapabilities.NET_CAPABILITY_INTERNET;
21 import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
22 
23 import static com.android.networkstack.apishim.common.ShimUtils.isAtLeastS;
24 
25 import static org.junit.Assert.assertFalse;
26 import static org.junit.Assert.fail;
27 
28 import android.content.Context;
29 import android.content.Intent;
30 import android.net.ConnectivityManager;
31 import android.net.IConnectivityManager;
32 import android.net.LinkProperties;
33 import android.net.Network;
34 import android.net.NetworkCapabilities;
35 import android.net.NetworkInfo;
36 import android.net.NetworkRequest;
37 import android.os.Handler;
38 import android.os.UserHandle;
39 import android.util.ArrayMap;
40 
41 import androidx.annotation.NonNull;
42 import androidx.annotation.Nullable;
43 
44 import java.util.Map;
45 import java.util.Objects;
46 
47 /**
48  * Simulates upstream switching and sending NetworkCallbacks and CONNECTIVITY_ACTION broadcasts.
49  *
50  * Unlike any real networking code, this class is single-threaded and entirely synchronous.
51  * The effects of all method calls (including sending fake broadcasts, sending callbacks, etc.) are
52  * performed immediately on the caller's thread before returning.
53  *
54  * TODO: this duplicates a fair amount of code from ConnectivityManager and ConnectivityService.
55  * Consider using a ConnectivityService object instead, as used in ConnectivityServiceTest.
56  *
57  * Things to consider:
58  * - ConnectivityService uses a real handler for realism, and these test use TestLooper (or even
59  *   invoke callbacks directly inline) for determinism. Using a real ConnectivityService would
60  *   require adding dispatchAll() calls and migrating to handlers.
61  * - ConnectivityService does not provide a way to order CONNECTIVITY_ACTION before or after the
62  *   NetworkCallbacks for the same network change. That ability is useful because the upstream
63  *   selection code in Tethering is vulnerable to race conditions, due to its reliance on multiple
64  *   separate NetworkCallbacks and BroadcastReceivers, each of which trigger different types of
65  *   updates. If/when the upstream selection code is refactored to a more level-triggered model
66  *   (e.g., with an idempotent function that takes into account all state every time any part of
67  *   that state changes), this may become less important or unnecessary.
68  */
69 public class TestConnectivityManager extends ConnectivityManager {
70     public static final boolean BROADCAST_FIRST = false;
71     public static final boolean CALLBACKS_FIRST = true;
72 
73     final Map<NetworkCallback, Handler> mAllCallbacks = new ArrayMap<>();
74     // This contains the callbacks tracking the system default network, whether it's registered
75     // with registerSystemDefaultNetworkCallback (S+) or with a custom request (R-).
76     final Map<NetworkCallback, Handler> mTrackingDefault = new ArrayMap<>();
77     final Map<NetworkCallback, NetworkRequestInfo> mListening = new ArrayMap<>();
78     final Map<NetworkCallback, NetworkRequestInfo> mRequested = new ArrayMap<>();
79     final Map<NetworkCallback, Integer> mLegacyTypeMap = new ArrayMap<>();
80 
81     private final Context mContext;
82 
83     private int mNetworkId = 100;
84     private TestNetworkAgent mDefaultNetwork = null;
85 
86     /**
87      * Constructs a TestConnectivityManager.
88      * @param ctx the context to use. Must be a fake or a mock because otherwise the test will
89      *            attempt to send real broadcasts and resulting in permission denials.
90      * @param svc an IConnectivityManager. Should be a fake or a mock.
91      */
TestConnectivityManager(Context ctx, IConnectivityManager svc)92     public TestConnectivityManager(Context ctx, IConnectivityManager svc) {
93         super(ctx, svc);
94         mContext = ctx;
95     }
96 
97     static class NetworkRequestInfo {
98         public final NetworkRequest request;
99         public final Handler handler;
NetworkRequestInfo(NetworkRequest r, Handler h)100         NetworkRequestInfo(NetworkRequest r, Handler h) {
101             request = r;
102             handler = h;
103         }
104     }
105 
hasNoCallbacks()106     boolean hasNoCallbacks() {
107         return mAllCallbacks.isEmpty()
108                 && mTrackingDefault.isEmpty()
109                 && mListening.isEmpty()
110                 && mRequested.isEmpty()
111                 && mLegacyTypeMap.isEmpty();
112     }
113 
onlyHasDefaultCallbacks()114     boolean onlyHasDefaultCallbacks() {
115         return (mAllCallbacks.size() == 1)
116                 && (mTrackingDefault.size() == 1)
117                 && mListening.isEmpty()
118                 && mRequested.isEmpty()
119                 && mLegacyTypeMap.isEmpty();
120     }
121 
isListeningForAll()122     boolean isListeningForAll() {
123         final NetworkCapabilities empty = new NetworkCapabilities();
124         empty.clearAll();
125 
126         for (NetworkRequestInfo nri : mListening.values()) {
127             if (nri.request.networkCapabilities.equalRequestableCapabilities(empty)) {
128                 return true;
129             }
130         }
131         return false;
132     }
133 
getNetworkId()134     int getNetworkId() {
135         return ++mNetworkId;
136     }
137 
sendDefaultNetworkBroadcasts(TestNetworkAgent formerDefault, TestNetworkAgent defaultNetwork)138     private void sendDefaultNetworkBroadcasts(TestNetworkAgent formerDefault,
139             TestNetworkAgent defaultNetwork) {
140         if (formerDefault != null) {
141             sendConnectivityAction(formerDefault.legacyType, false /* connected */);
142         }
143         if (defaultNetwork != null) {
144             sendConnectivityAction(defaultNetwork.legacyType, true /* connected */);
145         }
146     }
147 
sendDefaultNetworkCallbacks(TestNetworkAgent formerDefault, TestNetworkAgent defaultNetwork)148     private void sendDefaultNetworkCallbacks(TestNetworkAgent formerDefault,
149             TestNetworkAgent defaultNetwork) {
150         for (NetworkCallback cb : mTrackingDefault.keySet()) {
151             final Handler handler = mTrackingDefault.get(cb);
152             if (defaultNetwork != null) {
153                 handler.post(() -> cb.onAvailable(defaultNetwork.networkId));
154                 handler.post(() -> cb.onCapabilitiesChanged(
155                         defaultNetwork.networkId, defaultNetwork.networkCapabilities));
156                 handler.post(() -> cb.onLinkPropertiesChanged(
157                         defaultNetwork.networkId, defaultNetwork.linkProperties));
158             } else if (formerDefault != null) {
159                 handler.post(() -> cb.onLost(formerDefault.networkId));
160             }
161         }
162     }
163 
makeDefaultNetwork(TestNetworkAgent agent, boolean order, @Nullable Runnable inBetween)164     void makeDefaultNetwork(TestNetworkAgent agent, boolean order, @Nullable Runnable inBetween) {
165         if (Objects.equals(mDefaultNetwork, agent)) return;
166 
167         final TestNetworkAgent formerDefault = mDefaultNetwork;
168         mDefaultNetwork = agent;
169 
170         if (order == CALLBACKS_FIRST) {
171             sendDefaultNetworkCallbacks(formerDefault, mDefaultNetwork);
172             if (inBetween != null) inBetween.run();
173             sendDefaultNetworkBroadcasts(formerDefault, mDefaultNetwork);
174         } else {
175             sendDefaultNetworkBroadcasts(formerDefault, mDefaultNetwork);
176             if (inBetween != null) inBetween.run();
177             sendDefaultNetworkCallbacks(formerDefault, mDefaultNetwork);
178         }
179     }
180 
makeDefaultNetwork(TestNetworkAgent agent, boolean order)181     void makeDefaultNetwork(TestNetworkAgent agent, boolean order) {
182         makeDefaultNetwork(agent, order, null /* inBetween */);
183     }
184 
makeDefaultNetwork(TestNetworkAgent agent)185     void makeDefaultNetwork(TestNetworkAgent agent) {
186         makeDefaultNetwork(agent, BROADCAST_FIRST, null /* inBetween */);
187     }
188 
sendLinkProperties(TestNetworkAgent agent, boolean updateDefaultFirst)189     void sendLinkProperties(TestNetworkAgent agent, boolean updateDefaultFirst) {
190         if (!updateDefaultFirst) agent.sendLinkProperties();
191 
192         for (NetworkCallback cb : mTrackingDefault.keySet()) {
193             cb.onLinkPropertiesChanged(agent.networkId, agent.linkProperties);
194         }
195 
196         if (updateDefaultFirst) agent.sendLinkProperties();
197     }
198 
looksLikeDefaultRequest(NetworkRequest req)199     static boolean looksLikeDefaultRequest(NetworkRequest req) {
200         return req.hasCapability(NET_CAPABILITY_INTERNET)
201                 && !req.hasCapability(NET_CAPABILITY_DUN)
202                 && !req.hasTransport(TRANSPORT_CELLULAR);
203     }
204 
205     @Override
requestNetwork(NetworkRequest req, NetworkCallback cb, Handler h)206     public void requestNetwork(NetworkRequest req, NetworkCallback cb, Handler h) {
207         // For R- devices, Tethering will invoke this function in 2 cases, one is to request mobile
208         // network, the other is to track system default network.
209         if (looksLikeDefaultRequest(req)) {
210             assertFalse(isAtLeastS());
211             addTrackDefaultCallback(cb, h);
212         } else {
213             assertFalse(mAllCallbacks.containsKey(cb));
214             mAllCallbacks.put(cb, h);
215             assertFalse(mRequested.containsKey(cb));
216             mRequested.put(cb, new NetworkRequestInfo(req, h));
217         }
218     }
219 
220     @Override
registerSystemDefaultNetworkCallback( @onNull NetworkCallback cb, @NonNull Handler h)221     public void registerSystemDefaultNetworkCallback(
222             @NonNull NetworkCallback cb, @NonNull Handler h) {
223         addTrackDefaultCallback(cb, h);
224     }
225 
addTrackDefaultCallback(@onNull NetworkCallback cb, @NonNull Handler h)226     private void addTrackDefaultCallback(@NonNull NetworkCallback cb, @NonNull Handler h) {
227         assertFalse(mAllCallbacks.containsKey(cb));
228         mAllCallbacks.put(cb, h);
229         assertFalse(mTrackingDefault.containsKey(cb));
230         mTrackingDefault.put(cb, h);
231     }
232 
233     @Override
requestNetwork(NetworkRequest req, NetworkCallback cb)234     public void requestNetwork(NetworkRequest req, NetworkCallback cb) {
235         fail("Should never be called.");
236     }
237 
238     @Override
requestNetwork(NetworkRequest req, int timeoutMs, int legacyType, Handler h, NetworkCallback cb)239     public void requestNetwork(NetworkRequest req,
240             int timeoutMs, int legacyType, Handler h, NetworkCallback cb) {
241         assertFalse(mAllCallbacks.containsKey(cb));
242         NetworkRequest newReq = new NetworkRequest(req.networkCapabilities, legacyType,
243                 -1 /** testId */, req.type);
244         mAllCallbacks.put(cb, h);
245         assertFalse(mRequested.containsKey(cb));
246         mRequested.put(cb, new NetworkRequestInfo(newReq, h));
247         assertFalse(mLegacyTypeMap.containsKey(cb));
248         if (legacyType != ConnectivityManager.TYPE_NONE) {
249             mLegacyTypeMap.put(cb, legacyType);
250         }
251     }
252 
253     @Override
registerNetworkCallback(NetworkRequest req, NetworkCallback cb, Handler h)254     public void registerNetworkCallback(NetworkRequest req, NetworkCallback cb, Handler h) {
255         assertFalse(mAllCallbacks.containsKey(cb));
256         mAllCallbacks.put(cb, h);
257         assertFalse(mListening.containsKey(cb));
258         mListening.put(cb, new NetworkRequestInfo(req, h));
259     }
260 
261     @Override
registerNetworkCallback(NetworkRequest req, NetworkCallback cb)262     public void registerNetworkCallback(NetworkRequest req, NetworkCallback cb) {
263         fail("Should never be called.");
264     }
265 
266     @Override
registerDefaultNetworkCallback(NetworkCallback cb, Handler h)267     public void registerDefaultNetworkCallback(NetworkCallback cb, Handler h) {
268         fail("Should never be called.");
269     }
270 
271     @Override
registerDefaultNetworkCallback(NetworkCallback cb)272     public void registerDefaultNetworkCallback(NetworkCallback cb) {
273         fail("Should never be called.");
274     }
275 
276     @Override
unregisterNetworkCallback(NetworkCallback cb)277     public void unregisterNetworkCallback(NetworkCallback cb) {
278         if (mTrackingDefault.containsKey(cb)) {
279             mTrackingDefault.remove(cb);
280         } else if (mListening.containsKey(cb)) {
281             mListening.remove(cb);
282         } else if (mRequested.containsKey(cb)) {
283             mRequested.remove(cb);
284             mLegacyTypeMap.remove(cb);
285         } else {
286             fail("Unexpected callback removed");
287         }
288         mAllCallbacks.remove(cb);
289 
290         assertFalse(mAllCallbacks.containsKey(cb));
291         assertFalse(mTrackingDefault.containsKey(cb));
292         assertFalse(mListening.containsKey(cb));
293         assertFalse(mRequested.containsKey(cb));
294     }
295 
sendConnectivityAction(int type, boolean connected)296     private void sendConnectivityAction(int type, boolean connected) {
297         NetworkInfo ni = new NetworkInfo(type, 0 /* subtype */,  getNetworkTypeName(type),
298                 "" /* subtypeName */);
299         NetworkInfo.DetailedState state = connected
300                 ? NetworkInfo.DetailedState.CONNECTED
301                 : NetworkInfo.DetailedState.DISCONNECTED;
302         ni.setDetailedState(state, "" /* reason */, "" /* extraInfo */);
303         Intent intent = new Intent(CONNECTIVITY_ACTION);
304         intent.putExtra(EXTRA_NETWORK_INFO, ni);
305         mContext.sendStickyBroadcastAsUser(intent, UserHandle.ALL);
306     }
307 
308     public static class TestNetworkAgent {
309         public final TestConnectivityManager cm;
310         public final Network networkId;
311         public final NetworkCapabilities networkCapabilities;
312         public final LinkProperties linkProperties;
313         // TODO: delete when tethering no longer uses CONNECTIVITY_ACTION.
314         public final int legacyType;
315 
TestNetworkAgent(TestConnectivityManager cm, NetworkCapabilities nc)316         public TestNetworkAgent(TestConnectivityManager cm, NetworkCapabilities nc) {
317             this.cm = cm;
318             this.networkId = new Network(cm.getNetworkId());
319             networkCapabilities = copy(nc);
320             linkProperties = new LinkProperties();
321             legacyType = toLegacyType(nc);
322         }
323 
TestNetworkAgent(TestConnectivityManager cm, UpstreamNetworkState state)324         public TestNetworkAgent(TestConnectivityManager cm, UpstreamNetworkState state) {
325             this.cm = cm;
326             networkId = state.network;
327             networkCapabilities = state.networkCapabilities;
328             linkProperties = state.linkProperties;
329             this.legacyType = toLegacyType(networkCapabilities);
330         }
331 
toLegacyType(NetworkCapabilities nc)332         private static int toLegacyType(NetworkCapabilities nc) {
333             for (int type = 0; type < ConnectivityManager.TYPE_TEST; type++) {
334                 if (matchesLegacyType(nc, type)) return type;
335             }
336             throw new IllegalArgumentException(("Can't determine legacy type for: ") + nc);
337         }
338 
matchesLegacyType(NetworkCapabilities nc, int legacyType)339         private static boolean matchesLegacyType(NetworkCapabilities nc, int legacyType) {
340             final NetworkCapabilities typeNc;
341             try {
342                 typeNc = ConnectivityManager.networkCapabilitiesForType(legacyType);
343             } catch (IllegalArgumentException e) {
344                 // networkCapabilitiesForType does not support all legacy types.
345                 return false;
346             }
347             return typeNc.satisfiedByNetworkCapabilities(nc);
348         }
349 
matchesLegacyType(int legacyType)350         private boolean matchesLegacyType(int legacyType) {
351             return matchesLegacyType(networkCapabilities, legacyType);
352         }
353 
maybeSendConnectivityBroadcast(boolean connected)354         private void maybeSendConnectivityBroadcast(boolean connected) {
355             for (Integer requestedLegacyType : cm.mLegacyTypeMap.values()) {
356                 if (requestedLegacyType.intValue() == legacyType) {
357                     cm.sendConnectivityAction(legacyType, connected /* connected */);
358                     // In practice, a given network can match only one legacy type.
359                     break;
360                 }
361             }
362         }
363 
fakeConnect()364         public void fakeConnect() {
365             fakeConnect(BROADCAST_FIRST, null);
366         }
367 
fakeConnect(boolean order, @Nullable Runnable inBetween)368         public void fakeConnect(boolean order, @Nullable Runnable inBetween) {
369             if (order == BROADCAST_FIRST) {
370                 maybeSendConnectivityBroadcast(true /* connected */);
371                 if (inBetween != null) inBetween.run();
372             }
373 
374             for (NetworkCallback cb : cm.mListening.keySet()) {
375                 final NetworkRequestInfo nri = cm.mListening.get(cb);
376                 nri.handler.post(() -> cb.onAvailable(networkId));
377                 nri.handler.post(() -> cb.onCapabilitiesChanged(
378                         networkId, copy(networkCapabilities)));
379                 nri.handler.post(() -> cb.onLinkPropertiesChanged(networkId, copy(linkProperties)));
380             }
381 
382             if (order == CALLBACKS_FIRST) {
383                 if (inBetween != null) inBetween.run();
384                 maybeSendConnectivityBroadcast(true /* connected */);
385             }
386             // mTrackingDefault will be updated if/when the caller calls makeDefaultNetwork
387         }
388 
fakeDisconnect()389         public void fakeDisconnect() {
390             fakeDisconnect(BROADCAST_FIRST, null);
391         }
392 
fakeDisconnect(boolean order, @Nullable Runnable inBetween)393         public void fakeDisconnect(boolean order, @Nullable Runnable inBetween) {
394             if (order == BROADCAST_FIRST) {
395                 maybeSendConnectivityBroadcast(false /* connected */);
396                 if (inBetween != null) inBetween.run();
397             }
398 
399             for (NetworkCallback cb : cm.mListening.keySet()) {
400                 cb.onLost(networkId);
401             }
402 
403             if (order == CALLBACKS_FIRST) {
404                 if (inBetween != null) inBetween.run();
405                 maybeSendConnectivityBroadcast(false /* connected */);
406             }
407             // mTrackingDefault will be updated if/when the caller calls makeDefaultNetwork
408         }
409 
sendLinkProperties()410         public void sendLinkProperties() {
411             for (NetworkCallback cb : cm.mListening.keySet()) {
412                 cb.onLinkPropertiesChanged(networkId, copy(linkProperties));
413             }
414         }
415 
416         @Override
toString()417         public String toString() {
418             return String.format("TestNetworkAgent: %s %s", networkId, networkCapabilities);
419         }
420     }
421 
copy(NetworkCapabilities nc)422     static NetworkCapabilities copy(NetworkCapabilities nc) {
423         return new NetworkCapabilities(nc);
424     }
425 
copy(LinkProperties lp)426     static LinkProperties copy(LinkProperties lp) {
427         return new LinkProperties(lp);
428     }
429 }
430