/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.role.persistence

import android.content.ApexEnvironment
import android.content.Context
import android.os.Process
import android.os.UserHandle
import android.platform.test.annotations.RequiresFlagsDisabled
import android.platform.test.annotations.RequiresFlagsEnabled
import android.platform.test.flag.junit.DeviceFlagsValueProvider
import androidx.test.platform.app.InstrumentationRegistry
import com.android.dx.mockito.inline.extended.ExtendedMockito.mockitoSession
import com.android.permission.flags.Flags
import com.google.common.truth.Truth.assertThat
import java.io.File
import org.junit.After
import org.junit.Assume.assumeFalse
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import org.mockito.ArgumentMatchers.any
import org.mockito.ArgumentMatchers.eq
import org.mockito.Mock
import org.mockito.Mockito.`when`
import org.mockito.MockitoAnnotations.initMocks
import org.mockito.MockitoSession
import org.mockito.quality.Strictness

/**
 * Write/read/delete tests for [RolesPersistenceImpl], run once for every [StateVersion].
 *
 * [ApexEnvironment] is statically mocked so that the device-protected data directory for a
 * user resolves to a sub-directory of a private scratch directory of the instrumentation
 * context, rather than the real APEX data directory.
 */
@RunWith(Parameterized::class)
class RolesPersistenceTest {
    private val context = InstrumentationRegistry.getInstrumentation().context

    /** Scratch directory standing in for the APEX data directory; emptied before each test. */
    private lateinit var mockDataDirectory: File
    /** Static-mocking session for [ApexEnvironment]; started in [setUp]. */
    private lateinit var mockitoSession: MockitoSession
    @Mock lateinit var apexEnvironment: ApexEnvironment
    /** The state shape under test, injected by the [Parameterized] runner. */
    @Parameterized.Parameter(0) lateinit var stateVersion: StateVersion
    /** The [RolesState] matching [stateVersion]; assigned in [setUp]. */
    private lateinit var state: RolesState

    // Constructed with an empty lambda argument (a no-op, as far as this test is concerned).
    private val persistence = RolesPersistenceImpl {}
    private val defaultRoles = mapOf(ROLE_NAME to setOf(HOLDER_1, HOLDER_2))
    private val activeUserIds = mapOf(ROLE_NAME to USER_ID)
    private val stateVersionUndefined = RolesState(VERSION_UNDEFINED, PACKAGE_HASH, defaultRoles)
    private val stateVersionFallbackMigrated =
        RolesState(VERSION_FALLBACK_MIGRATED, PACKAGE_HASH, defaultRoles, setOf(ROLE_NAME))
    private val stateVersionActiveUserIds =
        RolesState(
            VERSION_ACTIVE_USER_IDS,
            PACKAGE_HASH,
            defaultRoles,
            setOf(ROLE_NAME),
            activeUserIds,
        )
    private val user = Process.myUserHandle()

    @get:Rule val flagsRule = DeviceFlagsValueProvider.createCheckFlagsRule()

    @Before
    fun setUp() {
        createMockDataDirectory()
        mockApexEnvironment()
        state = getState()
    }

    /** Obtains the scratch directory and deletes anything left over from a previous run. */
    private fun createMockDataDirectory() {
        mockDataDirectory = context.getDir("mock_data", Context.MODE_PRIVATE)
        val leftovers =
            checkNotNull(mockDataDirectory.listFiles()) {
                "Unable to list files in $mockDataDirectory"
            }
        leftovers.forEach { assertThat(it.deleteRecursively()).isTrue() }
    }

    /**
     * Initializes the [Mock] fields and starts a lenient static-mocking session for
     * [ApexEnvironment]. The mocked [apexEnvironment] maps each [UserHandle] to a sub-directory
     * of [mockDataDirectory], creating that directory on demand.
     */
    private fun mockApexEnvironment() {
        initMocks(this)
        mockitoSession =
            mockitoSession()
                .mockStatic(ApexEnvironment::class.java)
                .strictness(Strictness.LENIENT)
                .startMocking()
        `when`(ApexEnvironment.getApexEnvironment(eq(APEX_MODULE_NAME))).thenReturn(apexEnvironment)
        `when`(apexEnvironment.getDeviceProtectedDataDirForUser(any(UserHandle::class.java)))
            .then { invocation ->
                File(mockDataDirectory, invocation.arguments[0].toString()).also { it.mkdirs() }
            }
    }

    @After
    fun finishMockingApexEnvironment() {
        // If setUp() failed before the session was started, touching the lateinit property would
        // throw UninitializedPropertyAccessException and hide the original failure.
        if (::mockitoSession.isInitialized) {
            mockitoSession.finishMocking()
        }
    }

    @RequiresFlagsDisabled(Flags.FLAG_CROSS_USER_ROLE_ENABLED)
    @Test
    fun testWriteRead() {
        assumeStateHasNoActiveUserIds()
        assertWriteThenReadReturnsState()
    }

    @RequiresFlagsEnabled(Flags.FLAG_CROSS_USER_ROLE_ENABLED)
    @Test
    fun testWriteRead_supportsActiveUser() {
        assertWriteThenReadReturnsState()
    }

    @RequiresFlagsDisabled(Flags.FLAG_CROSS_USER_ROLE_ENABLED)
    @Test
    fun testWriteCorruptReadFromReserveCopy() {
        assumeStateHasNoActiveUserIds()
        assertCorruptPrimaryFileStillReadsState()
    }

    @RequiresFlagsEnabled(Flags.FLAG_CROSS_USER_ROLE_ENABLED)
    @Test
    fun testWriteCorruptReadFromReserveCopy_supportsActiveUser() {
        assertCorruptPrimaryFileStillReadsState()
    }

    @Test
    fun testDelete() {
        persistence.writeForUser(state, user)
        persistence.deleteForUser(user)
        val persistedState = persistence.readForUser(user)

        assertThat(persistedState).isNull()
    }

    /**
     * Skips the current parameter when it is [StateVersion.VERSION_ACTIVE_USER_IDS].
     *
     * Used by the variants that run with the cross-user-role flag disabled. NOTE(review):
     * presumably active user IDs are not persisted without the flag — confirm against
     * [RolesPersistenceImpl].
     */
    private fun assumeStateHasNoActiveUserIds() {
        assumeFalse(stateVersion == StateVersion.VERSION_ACTIVE_USER_IDS)
    }

    /** Writes [state] for [user] and asserts that reading it back yields an equal state. */
    private fun assertWriteThenReadReturnsState() {
        persistence.writeForUser(state, user)
        val persistedState = persistence.readForUser(user)

        assertThat(persistedState).isEqualTo(state)
    }

    /**
     * Writes [state] for [user], overwrites the primary file with truncated XML, and asserts
     * that reading still yields [state] (expected to come from the reserve copy).
     */
    private fun assertCorruptPrimaryFileStillReadsState() {
        persistence.writeForUser(state, user)
        // Corrupt the primary file.
        RolesPersistenceImpl.getFile(user)
            .writeText("<roles version=\"-1\"><role name=\"com.foo.bar\"><holder")
        val persistedState = persistence.readForUser(user)

        assertThat(persistedState).isEqualTo(state)
    }

    /** Returns the fixture [RolesState] matching the current [stateVersion] parameter. */
    private fun getState(): RolesState =
        when (stateVersion) {
            StateVersion.VERSION_UNDEFINED -> stateVersionUndefined
            StateVersion.VERSION_FALLBACK_MIGRATED -> stateVersionFallbackMigrated
            StateVersion.VERSION_ACTIVE_USER_IDS -> stateVersionActiveUserIds
        }

    /** The [RolesState] shapes exercised by this test, one per test parameter. */
    enum class StateVersion {
        VERSION_UNDEFINED,
        VERSION_FALLBACK_MIGRATED,
        VERSION_ACTIVE_USER_IDS,
    }

    companion object {
        /** Runs every test once per [StateVersion], named after the enum constant. */
        @Parameterized.Parameters(name = "{0}")
        @JvmStatic
        fun data(): Array<StateVersion> = StateVersion.values()

        // Version numbers passed to RolesState for each StateVersion fixture.
        private const val VERSION_UNDEFINED = -1
        private const val VERSION_FALLBACK_MIGRATED = 1
        private const val VERSION_ACTIVE_USER_IDS = 2
        private const val APEX_MODULE_NAME = "com.android.permission"
        private const val PACKAGE_HASH = "packagesHash"
        private const val ROLE_NAME = "roleName"
        private const val HOLDER_1 = "holder1"
        private const val HOLDER_2 = "holder2"
        private const val USER_ID = 10
    }
}