1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.adservices.common;
18 
19 import static org.junit.Assert.assertArrayEquals;
20 import static org.junit.Assert.assertEquals;
21 import static org.junit.Assert.assertFalse;
22 
23 import android.content.Context;
24 import android.database.Cursor;
25 import android.database.sqlite.SQLiteDatabase;
26 import android.util.Log;
27 
28 import androidx.test.core.app.ApplicationProvider;
29 
30 import com.android.adservices.data.DbHelper;
31 import com.android.adservices.data.measurement.MeasurementDbHelper;
32 import com.android.adservices.data.shared.SharedDbHelper;
33 
34 import com.google.common.collect.ImmutableSet;
35 
36 import java.util.ArrayList;
37 import java.util.Collections;
38 import java.util.List;
39 import java.util.Set;
40 
41 public final class DbTestUtil {
42     private static final Context sContext = ApplicationProvider.getApplicationContext();
43     private static final String DATABASE_NAME_FOR_TEST = "adservices_test.db";
44     private static final String MSMT_DATABASE_NAME_FOR_TEST = "adservices_msmt_test.db";
45     private static final String SHARED_DATABASE_NAME_FOR_TEST = "adservices_shared_test.db";
46 
47     private static DbHelper sSingleton;
48     private static MeasurementDbHelper sMsmtSingleton;
49     private static SharedDbHelper sSharedSingleton;
50 
51     /** Erases all data from the table rows */
deleteTable(String tableName)52     public static void deleteTable(String tableName) {
53         try (SQLiteDatabase db = getDbHelperForTest().safeGetWritableDatabase()) {
54             if (db == null) {
55                 return;
56             }
57 
58             db.delete(tableName, /* whereClause= */ null, /* whereArgs= */ null);
59         }
60     }
61 
62     /**
63      * Create an instance of database instance for testing.
64      *
65      * @return a test database
66      */
getDbHelperForTest()67     public static DbHelper getDbHelperForTest() {
68         synchronized (DbHelper.class) {
69             if (sSingleton == null) {
70                 sSingleton =
71                         new DbHelper(sContext, DATABASE_NAME_FOR_TEST, DbHelper.DATABASE_VERSION_7);
72             }
73             return sSingleton;
74         }
75     }
76 
getMeasurementDbHelperForTest()77     public static MeasurementDbHelper getMeasurementDbHelperForTest() {
78         synchronized (MeasurementDbHelper.class) {
79             if (sMsmtSingleton == null) {
80                 sMsmtSingleton =
81                         new MeasurementDbHelper(
82                                 sContext,
83                                 MSMT_DATABASE_NAME_FOR_TEST,
84                                 MeasurementDbHelper.CURRENT_DATABASE_VERSION,
85                                 getDbHelperForTest());
86             }
87             return sMsmtSingleton;
88         }
89     }
90 
getSharedDbHelperForTest()91     public static SharedDbHelper getSharedDbHelperForTest() {
92         synchronized (SharedDbHelper.class) {
93             if (sSharedSingleton == null) {
94                 sSharedSingleton =
95                         new SharedDbHelper(
96                                 sContext,
97                                 SHARED_DATABASE_NAME_FOR_TEST,
98                                 SharedDbHelper.DATABASE_VERSION_V4,
99                                 getDbHelperForTest());
100             }
101             return sSharedSingleton;
102         }
103     }
104 
105     /** Return true if table exists in the DB and column count matches. */
doesTableExistAndColumnCountMatch( SQLiteDatabase db, String tableName, int columnCount)106     public static boolean doesTableExistAndColumnCountMatch(
107             SQLiteDatabase db, String tableName, int columnCount) {
108         final Set<String> tableColumns = getTableColumns(db, tableName);
109         int actualCol = tableColumns.size();
110         Log.d("DbTestUtil_log_test,", " table name: " + tableName + " column count: " + actualCol);
111         return tableColumns.size() == columnCount;
112     }
113 
114     /** Returns column names of the table. */
getTableColumns(SQLiteDatabase db, String tableName)115     public static Set<String> getTableColumns(SQLiteDatabase db, String tableName) {
116         String query =
117                 "select p.name from sqlite_master s "
118                         + "join pragma_table_info(s.name) p "
119                         + "where s.tbl_name = '"
120                         + tableName
121                         + "'";
122         Cursor cursor = db.rawQuery(query, null);
123         if (cursor == null) {
124             throw new IllegalArgumentException("Cursor is null.");
125         }
126 
127         ImmutableSet.Builder<String> tableColumns = ImmutableSet.builder();
128         while (cursor.moveToNext()) {
129             tableColumns.add(cursor.getString(0));
130         }
131 
132         return tableColumns.build();
133     }
134 
135     /** Return true if the given index exists in the DB. */
doesIndexExist(SQLiteDatabase db, String index)136     public static boolean doesIndexExist(SQLiteDatabase db, String index) {
137         String query = "SELECT * FROM sqlite_master WHERE type='index' and name='" + index + "'";
138         Cursor cursor = db.rawQuery(query, null);
139         return cursor != null && cursor.getCount() > 0;
140     }
141 
doesTableExist(SQLiteDatabase db, String table)142     public static boolean doesTableExist(SQLiteDatabase db, String table) {
143         String query = "SELECT * FROM sqlite_master WHERE type='table' and name='" + table + "'";
144         Cursor cursor = db.rawQuery(query, null);
145         return cursor != null && cursor.getCount() > 0;
146     }
147 
assertDatabasesEqual(SQLiteDatabase expectedDb, SQLiteDatabase actualDb)148     public static void assertDatabasesEqual(SQLiteDatabase expectedDb, SQLiteDatabase actualDb) {
149         List<String> expectedTables = getTables(expectedDb);
150         List<String> actualTables = getTables(actualDb);
151         assertArrayEquals(expectedTables.toArray(), actualTables.toArray());
152         assertTableSchemaEqual(expectedDb, actualDb, expectedTables);
153         assertIndexesEqual(expectedDb, actualDb, expectedTables);
154     }
155 
assertMeasurementTablesDoNotExist(SQLiteDatabase db)156     public static void assertMeasurementTablesDoNotExist(SQLiteDatabase db) {
157         assertFalse(doesTableExist(db, "msmt_source"));
158         assertFalse(doesTableExist(db, "msmt_trigger"));
159         assertFalse(doesTableExist(db, "msmt_async_registration_contract"));
160         assertFalse(doesTableExist(db, "msmt_event_report"));
161         assertFalse(doesTableExist(db, "msmt_attribution"));
162         assertFalse(doesTableExist(db, "msmt_aggregate_report"));
163         assertFalse(doesTableExist(db, "msmt_aggregate_encryption_key"));
164         assertFalse(doesTableExist(db, "msmt_debug_report"));
165         assertFalse(doesTableExist(db, "msmt_xna_ignored_sources"));
166     }
167 
assertTableSchemaEqual( SQLiteDatabase expectedDb, SQLiteDatabase actualDb, List<String> tableNames)168     private static void assertTableSchemaEqual(
169             SQLiteDatabase expectedDb, SQLiteDatabase actualDb, List<String> tableNames) {
170         for (String tableName : tableNames) {
171             Cursor columnsCursorExpected =
172                     expectedDb.rawQuery("PRAGMA TABLE_INFO(" + tableName + ")", null);
173             Cursor columnsCursorActual =
174                     actualDb.rawQuery("PRAGMA TABLE_INFO(" + tableName + ")", null);
175             assertEquals(
176                     "Table columns mismatch for " + tableName,
177                     columnsCursorExpected.getCount(),
178                     columnsCursorActual.getCount());
179 
180             // Checks the columns in order. Newly created columns should be inserted as the end.
181             while (columnsCursorExpected.moveToNext() && columnsCursorActual.moveToNext()) {
182                 assertEquals(
183                         "Column mismatch for " + tableName,
184                         columnsCursorExpected.getString(
185                                 columnsCursorExpected.getColumnIndex("name")),
186                         columnsCursorActual.getString(columnsCursorActual.getColumnIndex("name")));
187                 assertEquals(
188                         "Column mismatch for " + tableName,
189                         columnsCursorExpected.getString(
190                                 columnsCursorExpected.getColumnIndex("type")),
191                         columnsCursorActual.getString(columnsCursorActual.getColumnIndex("type")));
192                 assertEquals(
193                         "Column mismatch for " + tableName,
194                         columnsCursorExpected.getInt(
195                                 columnsCursorExpected.getColumnIndex("notnull")),
196                         columnsCursorActual.getInt(columnsCursorActual.getColumnIndex("notnull")));
197                 assertEquals(
198                         "Column mismatch for " + tableName,
199                         columnsCursorExpected.getString(
200                                 columnsCursorExpected.getColumnIndex("dflt_value")),
201                         columnsCursorActual.getString(
202                                 columnsCursorActual.getColumnIndex("dflt_value")));
203                 assertEquals(
204                         "Column mismatch for " + tableName,
205                         columnsCursorExpected.getInt(columnsCursorExpected.getColumnIndex("pk")),
206                         columnsCursorActual.getInt(columnsCursorActual.getColumnIndex("pk")));
207             }
208 
209             columnsCursorExpected.close();
210             columnsCursorActual.close();
211         }
212     }
213 
getTables(SQLiteDatabase db)214     private static List<String> getTables(SQLiteDatabase db) {
215         String listTableQuery = "SELECT name FROM sqlite_master where type = 'table'";
216         List<String> tables = new ArrayList<>();
217         try (Cursor cursor = db.rawQuery(listTableQuery, null)) {
218             while (cursor.moveToNext()) {
219                 tables.add(cursor.getString(cursor.getColumnIndex("name")));
220             }
221         }
222         Collections.sort(tables);
223         return tables;
224     }
225 
assertIndexesEqual( SQLiteDatabase expectedDb, SQLiteDatabase actualDb, List<String> tables)226     private static void assertIndexesEqual(
227             SQLiteDatabase expectedDb, SQLiteDatabase actualDb, List<String> tables) {
228         for (String tableName : tables) {
229             String indexListQuery =
230                     "SELECT name FROM sqlite_master where type = 'index' AND tbl_name = '"
231                             + tableName
232                             + "' ORDER BY name ASC";
233             Cursor indexListCursorExpected = expectedDb.rawQuery(indexListQuery, null);
234             Cursor indexListCursorActual = actualDb.rawQuery(indexListQuery, null);
235             assertEquals(
236                     "Table indexes mismatch for " + tableName,
237                     indexListCursorExpected.getCount(),
238                     indexListCursorActual.getCount());
239 
240             while (indexListCursorExpected.moveToNext() && indexListCursorActual.moveToNext()) {
241                 String expectedIndexName =
242                         indexListCursorExpected.getString(
243                                 indexListCursorExpected.getColumnIndex("name"));
244                 assertEquals(
245                         "Index mismatch for " + tableName,
246                         expectedIndexName,
247                         indexListCursorActual.getString(
248                                 indexListCursorActual.getColumnIndex("name")));
249 
250                 assertIndexInfoEqual(expectedDb, actualDb, expectedIndexName);
251             }
252 
253             indexListCursorExpected.close();
254             indexListCursorActual.close();
255         }
256     }
257 
assertIndexInfoEqual( SQLiteDatabase expectedDb, SQLiteDatabase actualDb, String indexName)258     private static void assertIndexInfoEqual(
259             SQLiteDatabase expectedDb, SQLiteDatabase actualDb, String indexName) {
260         Cursor indexInfoCursorExpected =
261                 expectedDb.rawQuery("PRAGMA main.INDEX_INFO (" + indexName + ")", null);
262         Cursor indexInfoCursorActual =
263                 actualDb.rawQuery("PRAGMA main.INDEX_INFO (" + indexName + ")", null);
264         assertEquals(
265                 "Index columns count mismatch for " + indexName,
266                 indexInfoCursorExpected.getCount(),
267                 indexInfoCursorActual.getCount());
268 
269         while (indexInfoCursorExpected.moveToNext() && indexInfoCursorActual.moveToNext()) {
270             assertEquals(
271                     "Index info mismatch for " + indexName,
272                     indexInfoCursorExpected.getInt(indexInfoCursorExpected.getColumnIndex("seqno")),
273                     indexInfoCursorActual.getInt(indexInfoCursorActual.getColumnIndex("seqno")));
274             assertEquals(
275                     "Index info mismatch for " + indexName,
276                     indexInfoCursorExpected.getInt(indexInfoCursorExpected.getColumnIndex("cid")),
277                     indexInfoCursorActual.getInt(indexInfoCursorActual.getColumnIndex("cid")));
278             assertEquals(
279                     "Index info mismatch for " + indexName,
280                     indexInfoCursorExpected.getString(
281                             indexInfoCursorExpected.getColumnIndex("name")),
282                     indexInfoCursorActual.getString(indexInfoCursorActual.getColumnIndex("name")));
283         }
284 
285         indexInfoCursorExpected.close();
286         indexInfoCursorActual.close();
287     }
288 }
289