1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.role.persistence
18 
19 import android.content.ApexEnvironment
20 import android.content.Context
21 import android.os.Process
22 import android.os.UserHandle
23 import android.platform.test.annotations.RequiresFlagsDisabled
24 import android.platform.test.annotations.RequiresFlagsEnabled
25 import android.platform.test.flag.junit.DeviceFlagsValueProvider
26 import androidx.test.platform.app.InstrumentationRegistry
27 import com.android.dx.mockito.inline.extended.ExtendedMockito.mockitoSession
28 import com.android.permission.flags.Flags
29 import com.google.common.truth.Truth.assertThat
30 import java.io.File
31 import org.junit.After
32 import org.junit.Assume.assumeFalse
33 import org.junit.Before
34 import org.junit.Rule
35 import org.junit.Test
36 import org.junit.runner.RunWith
37 import org.junit.runners.Parameterized
38 import org.mockito.ArgumentMatchers.any
39 import org.mockito.ArgumentMatchers.eq
40 import org.mockito.Mock
41 import org.mockito.Mockito.`when`
42 import org.mockito.MockitoAnnotations.initMocks
43 import org.mockito.MockitoSession
44 import org.mockito.quality.Strictness
45 
/**
 * Persistence tests for [RolesPersistenceImpl], parameterized over every [StateVersion] of
 * [RolesState] built below.
 *
 * [ApexEnvironment] is statically mocked so that the per-user device-protected data directory
 * resolves to a scratch directory owned by the test app, which is emptied before each test.
 */
@RunWith(Parameterized::class)
class RolesPersistenceTest {
    private val context = InstrumentationRegistry.getInstrumentation().context

    private lateinit var mockDataDirectory: File
    private lateinit var mockitoSession: MockitoSession
    @Mock lateinit var apexEnvironment: ApexEnvironment
    @Parameterized.Parameter(0) lateinit var stateVersion: StateVersion
    private lateinit var state: RolesState

    private val persistence = RolesPersistenceImpl {}
    private val defaultRoles = mapOf(ROLE_NAME to setOf(HOLDER_1, HOLDER_2))
    private val activeUserIds = mapOf(ROLE_NAME to USER_ID)
    private val stateVersionUndefined = RolesState(VERSION_UNDEFINED, PACKAGE_HASH, defaultRoles)
    private val stateVersionFallbackMigrated =
        RolesState(VERSION_FALLBACK_MIGRATED, PACKAGE_HASH, defaultRoles, setOf(ROLE_NAME))
    private val stateVersionActiveUserIds =
        RolesState(
            VERSION_ACTIVE_USER_IDS,
            PACKAGE_HASH,
            defaultRoles,
            setOf(ROLE_NAME),
            activeUserIds,
        )
    private val user = Process.myUserHandle()

    @get:Rule val flagsRule = DeviceFlagsValueProvider.createCheckFlagsRule()

    @Before
    fun setUp() {
        createMockDataDirectory()
        mockApexEnvironment()
        // stateVersion is injected by the Parameterized runner after construction, so the state
        // under test can only be selected here rather than in a property initializer.
        state = getState()
    }

    /** Obtains the scratch data directory and deletes everything left in it by earlier runs. */
    private fun createMockDataDirectory() {
        mockDataDirectory = context.getDir("mock_data", Context.MODE_PRIVATE)
        val staleFiles =
            checkNotNull(mockDataDirectory.listFiles()) {
                "Unable to list files in $mockDataDirectory"
            }
        staleFiles.forEach { assertThat(it.deleteRecursively()).isTrue() }
    }

    /**
     * Starts a lenient Mockito session that statically mocks [ApexEnvironment], so that the
     * device-protected data directory for any user becomes a (created on demand) subdirectory of
     * [mockDataDirectory] named after the user handle's string form.
     */
    private fun mockApexEnvironment() {
        initMocks(this)
        mockitoSession =
            mockitoSession()
                .mockStatic(ApexEnvironment::class.java)
                .strictness(Strictness.LENIENT)
                .startMocking()
        `when`(ApexEnvironment.getApexEnvironment(eq(APEX_MODULE_NAME))).thenReturn(apexEnvironment)
        `when`(apexEnvironment.getDeviceProtectedDataDirForUser(any(UserHandle::class.java)))
            .then { invocation ->
                File(mockDataDirectory, invocation.arguments[0].toString()).also { it.mkdirs() }
            }
    }

    @After
    fun finishMockingApexEnvironment() {
        // setUp() may fail before the session is started (e.g. while cleaning the scratch
        // directory); touching an uninitialized lateinit here would throw and hide that failure.
        if (::mockitoSession.isInitialized) {
            mockitoSession.finishMocking()
        }
    }

    @RequiresFlagsDisabled(Flags.FLAG_CROSS_USER_ROLE_ENABLED)
    @Test
    fun testWriteRead() {
        // The active-user-IDs state is only expected to round-trip with the cross-user role flag
        // enabled; that case is covered by testWriteRead_supportsActiveUser.
        assumeFalse(stateVersion == StateVersion.VERSION_ACTIVE_USER_IDS)
        assertWriteReadRoundTrips()
    }

    @RequiresFlagsEnabled(Flags.FLAG_CROSS_USER_ROLE_ENABLED)
    @Test
    fun testWriteRead_supportsActiveUser() {
        assertWriteReadRoundTrips()
    }

    @RequiresFlagsDisabled(Flags.FLAG_CROSS_USER_ROLE_ENABLED)
    @Test
    fun testWriteCorruptReadFromReserveCopy() {
        // See testWriteRead: the active-user-IDs state is covered by the flag-enabled variant.
        assumeFalse(stateVersion == StateVersion.VERSION_ACTIVE_USER_IDS)
        assertReadRecoversFromReserveCopy()
    }

    @RequiresFlagsEnabled(Flags.FLAG_CROSS_USER_ROLE_ENABLED)
    @Test
    fun testWriteCorruptReadFromReserveCopy_supportsActiveUser() {
        assertReadRecoversFromReserveCopy()
    }

    @Test
    fun testDelete() {
        persistence.writeForUser(state, user)
        persistence.deleteForUser(user)
        val persistedState = persistence.readForUser(user)

        assertThat(persistedState).isNull()
    }

    /** Writes [state] for [user] and asserts that reading it back yields an equal state. */
    private fun assertWriteReadRoundTrips() {
        persistence.writeForUser(state, user)
        val persistedState = persistence.readForUser(user)

        assertThat(persistedState).isEqualTo(state)
    }

    /**
     * Writes [state] for [user], overwrites the primary file returned by
     * [RolesPersistenceImpl.getFile] with malformed XML, and asserts that reading still yields
     * [state] (i.e. the persistence layer recovers it from its reserve copy).
     */
    private fun assertReadRecoversFromReserveCopy() {
        persistence.writeForUser(state, user)
        // Corrupt the primary file.
        RolesPersistenceImpl.getFile(user).writeText(CORRUPT_ROLES_XML)
        val persistedState = persistence.readForUser(user)

        assertThat(persistedState).isEqualTo(state)
    }

    /** Returns the fixture [RolesState] matching the injected [stateVersion]. */
    private fun getState(): RolesState =
        when (stateVersion) {
            StateVersion.VERSION_UNDEFINED -> stateVersionUndefined
            StateVersion.VERSION_FALLBACK_MIGRATED -> stateVersionFallbackMigrated
            StateVersion.VERSION_ACTIVE_USER_IDS -> stateVersionActiveUserIds
        }

    /** The [RolesState] shapes exercised by each parameterized run. */
    enum class StateVersion {
        VERSION_UNDEFINED,
        VERSION_FALLBACK_MIGRATED,
        VERSION_ACTIVE_USER_IDS,
    }

    companion object {
        @Parameterized.Parameters(name = "{0}")
        @JvmStatic
        fun data(): Array<StateVersion> = StateVersion.values()

        private const val VERSION_UNDEFINED = -1
        private const val VERSION_FALLBACK_MIGRATED = 1
        private const val VERSION_ACTIVE_USER_IDS = 2
        private const val APEX_MODULE_NAME = "com.android.permission"
        private const val PACKAGE_HASH = "packagesHash"
        private const val ROLE_NAME = "roleName"
        private const val HOLDER_1 = "holder1"
        private const val HOLDER_2 = "holder2"
        private const val USER_ID = 10
        // Truncated roles XML (unterminated <holder tag) used to corrupt the primary file.
        private const val CORRUPT_ROLES_XML =
            "<roles version=\"-1\"><role name=\"com.foo.bar\"><holder"
    }
}
186