1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.adservices.common; 18 19 import static org.junit.Assert.assertArrayEquals; 20 import static org.junit.Assert.assertEquals; 21 import static org.junit.Assert.assertFalse; 22 23 import android.content.Context; 24 import android.database.Cursor; 25 import android.database.sqlite.SQLiteDatabase; 26 import android.util.Log; 27 28 import androidx.test.core.app.ApplicationProvider; 29 30 import com.android.adservices.data.DbHelper; 31 import com.android.adservices.data.measurement.MeasurementDbHelper; 32 import com.android.adservices.data.shared.SharedDbHelper; 33 34 import com.google.common.collect.ImmutableSet; 35 36 import java.util.ArrayList; 37 import java.util.Collections; 38 import java.util.List; 39 import java.util.Set; 40 41 public final class DbTestUtil { 42 private static final Context sContext = ApplicationProvider.getApplicationContext(); 43 private static final String DATABASE_NAME_FOR_TEST = "adservices_test.db"; 44 private static final String MSMT_DATABASE_NAME_FOR_TEST = "adservices_msmt_test.db"; 45 private static final String SHARED_DATABASE_NAME_FOR_TEST = "adservices_shared_test.db"; 46 47 private static DbHelper sSingleton; 48 private static MeasurementDbHelper sMsmtSingleton; 49 private static SharedDbHelper sSharedSingleton; 50 51 /** Erases all data from the table rows */ deleteTable(String tableName)52 public static void deleteTable(String tableName) { 53 try (SQLiteDatabase db = getDbHelperForTest().safeGetWritableDatabase()) { 54 if (db == null) { 55 return; 56 } 57 58 db.delete(tableName, /* whereClause= */ null, /* whereArgs= */ null); 59 } 60 } 61 62 /** 63 * Create an instance of database instance for testing. 64 * 65 * @return a test database 66 */ getDbHelperForTest()67 public static DbHelper getDbHelperForTest() { 68 synchronized (DbHelper.class) { 69 if (sSingleton == null) { 70 sSingleton = 71 new DbHelper(sContext, DATABASE_NAME_FOR_TEST, DbHelper.DATABASE_VERSION_7); 72 } 73 return sSingleton; 74 } 75 } 76 getMeasurementDbHelperForTest()77 public static MeasurementDbHelper getMeasurementDbHelperForTest() { 78 synchronized (MeasurementDbHelper.class) { 79 if (sMsmtSingleton == null) { 80 sMsmtSingleton = 81 new MeasurementDbHelper( 82 sContext, 83 MSMT_DATABASE_NAME_FOR_TEST, 84 MeasurementDbHelper.CURRENT_DATABASE_VERSION, 85 getDbHelperForTest()); 86 } 87 return sMsmtSingleton; 88 } 89 } 90 getSharedDbHelperForTest()91 public static SharedDbHelper getSharedDbHelperForTest() { 92 synchronized (SharedDbHelper.class) { 93 if (sSharedSingleton == null) { 94 sSharedSingleton = 95 new SharedDbHelper( 96 sContext, 97 SHARED_DATABASE_NAME_FOR_TEST, 98 SharedDbHelper.DATABASE_VERSION_V4, 99 getDbHelperForTest()); 100 } 101 return sSharedSingleton; 102 } 103 } 104 105 /** Return true if table exists in the DB and column count matches. */ doesTableExistAndColumnCountMatch( SQLiteDatabase db, String tableName, int columnCount)106 public static boolean doesTableExistAndColumnCountMatch( 107 SQLiteDatabase db, String tableName, int columnCount) { 108 final Set<String> tableColumns = getTableColumns(db, tableName); 109 int actualCol = tableColumns.size(); 110 Log.d("DbTestUtil_log_test,", " table name: " + tableName + " column count: " + actualCol); 111 return tableColumns.size() == columnCount; 112 } 113 114 /** Returns column names of the table. */ getTableColumns(SQLiteDatabase db, String tableName)115 public static Set<String> getTableColumns(SQLiteDatabase db, String tableName) { 116 String query = 117 "select p.name from sqlite_master s " 118 + "join pragma_table_info(s.name) p " 119 + "where s.tbl_name = '" 120 + tableName 121 + "'"; 122 Cursor cursor = db.rawQuery(query, null); 123 if (cursor == null) { 124 throw new IllegalArgumentException("Cursor is null."); 125 } 126 127 ImmutableSet.Builder<String> tableColumns = ImmutableSet.builder(); 128 while (cursor.moveToNext()) { 129 tableColumns.add(cursor.getString(0)); 130 } 131 132 return tableColumns.build(); 133 } 134 135 /** Return true if the given index exists in the DB. */ doesIndexExist(SQLiteDatabase db, String index)136 public static boolean doesIndexExist(SQLiteDatabase db, String index) { 137 String query = "SELECT * FROM sqlite_master WHERE type='index' and name='" + index + "'"; 138 Cursor cursor = db.rawQuery(query, null); 139 return cursor != null && cursor.getCount() > 0; 140 } 141 doesTableExist(SQLiteDatabase db, String table)142 public static boolean doesTableExist(SQLiteDatabase db, String table) { 143 String query = "SELECT * FROM sqlite_master WHERE type='table' and name='" + table + "'"; 144 Cursor cursor = db.rawQuery(query, null); 145 return cursor != null && cursor.getCount() > 0; 146 } 147 assertDatabasesEqual(SQLiteDatabase expectedDb, SQLiteDatabase actualDb)148 public static void assertDatabasesEqual(SQLiteDatabase expectedDb, SQLiteDatabase actualDb) { 149 List<String> expectedTables = getTables(expectedDb); 150 List<String> actualTables = getTables(actualDb); 151 assertArrayEquals(expectedTables.toArray(), actualTables.toArray()); 152 assertTableSchemaEqual(expectedDb, actualDb, expectedTables); 153 assertIndexesEqual(expectedDb, actualDb, expectedTables); 154 } 155 assertMeasurementTablesDoNotExist(SQLiteDatabase db)156 public static void assertMeasurementTablesDoNotExist(SQLiteDatabase db) { 157 assertFalse(doesTableExist(db, "msmt_source")); 158 assertFalse(doesTableExist(db, "msmt_trigger")); 159 assertFalse(doesTableExist(db, "msmt_async_registration_contract")); 160 assertFalse(doesTableExist(db, "msmt_event_report")); 161 assertFalse(doesTableExist(db, "msmt_attribution")); 162 assertFalse(doesTableExist(db, "msmt_aggregate_report")); 163 assertFalse(doesTableExist(db, "msmt_aggregate_encryption_key")); 164 assertFalse(doesTableExist(db, "msmt_debug_report")); 165 assertFalse(doesTableExist(db, "msmt_xna_ignored_sources")); 166 } 167 assertTableSchemaEqual( SQLiteDatabase expectedDb, SQLiteDatabase actualDb, List<String> tableNames)168 private static void assertTableSchemaEqual( 169 SQLiteDatabase expectedDb, SQLiteDatabase actualDb, List<String> tableNames) { 170 for (String tableName : tableNames) { 171 Cursor columnsCursorExpected = 172 expectedDb.rawQuery("PRAGMA TABLE_INFO(" + tableName + ")", null); 173 Cursor columnsCursorActual = 174 actualDb.rawQuery("PRAGMA TABLE_INFO(" + tableName + ")", null); 175 assertEquals( 176 "Table columns mismatch for " + tableName, 177 columnsCursorExpected.getCount(), 178 columnsCursorActual.getCount()); 179 180 // Checks the columns in order. Newly created columns should be inserted as the end. 181 while (columnsCursorExpected.moveToNext() && columnsCursorActual.moveToNext()) { 182 assertEquals( 183 "Column mismatch for " + tableName, 184 columnsCursorExpected.getString( 185 columnsCursorExpected.getColumnIndex("name")), 186 columnsCursorActual.getString(columnsCursorActual.getColumnIndex("name"))); 187 assertEquals( 188 "Column mismatch for " + tableName, 189 columnsCursorExpected.getString( 190 columnsCursorExpected.getColumnIndex("type")), 191 columnsCursorActual.getString(columnsCursorActual.getColumnIndex("type"))); 192 assertEquals( 193 "Column mismatch for " + tableName, 194 columnsCursorExpected.getInt( 195 columnsCursorExpected.getColumnIndex("notnull")), 196 columnsCursorActual.getInt(columnsCursorActual.getColumnIndex("notnull"))); 197 assertEquals( 198 "Column mismatch for " + tableName, 199 columnsCursorExpected.getString( 200 columnsCursorExpected.getColumnIndex("dflt_value")), 201 columnsCursorActual.getString( 202 columnsCursorActual.getColumnIndex("dflt_value"))); 203 assertEquals( 204 "Column mismatch for " + tableName, 205 columnsCursorExpected.getInt(columnsCursorExpected.getColumnIndex("pk")), 206 columnsCursorActual.getInt(columnsCursorActual.getColumnIndex("pk"))); 207 } 208 209 columnsCursorExpected.close(); 210 columnsCursorActual.close(); 211 } 212 } 213 getTables(SQLiteDatabase db)214 private static List<String> getTables(SQLiteDatabase db) { 215 String listTableQuery = "SELECT name FROM sqlite_master where type = 'table'"; 216 List<String> tables = new ArrayList<>(); 217 try (Cursor cursor = db.rawQuery(listTableQuery, null)) { 218 while (cursor.moveToNext()) { 219 tables.add(cursor.getString(cursor.getColumnIndex("name"))); 220 } 221 } 222 Collections.sort(tables); 223 return tables; 224 } 225 assertIndexesEqual( SQLiteDatabase expectedDb, SQLiteDatabase actualDb, List<String> tables)226 private static void assertIndexesEqual( 227 SQLiteDatabase expectedDb, SQLiteDatabase actualDb, List<String> tables) { 228 for (String tableName : tables) { 229 String indexListQuery = 230 "SELECT name FROM sqlite_master where type = 'index' AND tbl_name = '" 231 + tableName 232 + "' ORDER BY name ASC"; 233 Cursor indexListCursorExpected = expectedDb.rawQuery(indexListQuery, null); 234 Cursor indexListCursorActual = actualDb.rawQuery(indexListQuery, null); 235 assertEquals( 236 "Table indexes mismatch for " + tableName, 237 indexListCursorExpected.getCount(), 238 indexListCursorActual.getCount()); 239 240 while (indexListCursorExpected.moveToNext() && indexListCursorActual.moveToNext()) { 241 String expectedIndexName = 242 indexListCursorExpected.getString( 243 indexListCursorExpected.getColumnIndex("name")); 244 assertEquals( 245 "Index mismatch for " + tableName, 246 expectedIndexName, 247 indexListCursorActual.getString( 248 indexListCursorActual.getColumnIndex("name"))); 249 250 assertIndexInfoEqual(expectedDb, actualDb, expectedIndexName); 251 } 252 253 indexListCursorExpected.close(); 254 indexListCursorActual.close(); 255 } 256 } 257 assertIndexInfoEqual( SQLiteDatabase expectedDb, SQLiteDatabase actualDb, String indexName)258 private static void assertIndexInfoEqual( 259 SQLiteDatabase expectedDb, SQLiteDatabase actualDb, String indexName) { 260 Cursor indexInfoCursorExpected = 261 expectedDb.rawQuery("PRAGMA main.INDEX_INFO (" + indexName + ")", null); 262 Cursor indexInfoCursorActual = 263 actualDb.rawQuery("PRAGMA main.INDEX_INFO (" + indexName + ")", null); 264 assertEquals( 265 "Index columns count mismatch for " + indexName, 266 indexInfoCursorExpected.getCount(), 267 indexInfoCursorActual.getCount()); 268 269 while (indexInfoCursorExpected.moveToNext() && indexInfoCursorActual.moveToNext()) { 270 assertEquals( 271 "Index info mismatch for " + indexName, 272 indexInfoCursorExpected.getInt(indexInfoCursorExpected.getColumnIndex("seqno")), 273 indexInfoCursorActual.getInt(indexInfoCursorActual.getColumnIndex("seqno"))); 274 assertEquals( 275 "Index info mismatch for " + indexName, 276 indexInfoCursorExpected.getInt(indexInfoCursorExpected.getColumnIndex("cid")), 277 indexInfoCursorActual.getInt(indexInfoCursorActual.getColumnIndex("cid"))); 278 assertEquals( 279 "Index info mismatch for " + indexName, 280 indexInfoCursorExpected.getString( 281 indexInfoCursorExpected.getColumnIndex("name")), 282 indexInfoCursorActual.getString(indexInfoCursorActual.getColumnIndex("name"))); 283 } 284 285 indexInfoCursorExpected.close(); 286 indexInfoCursorActual.close(); 287 } 288 } 289