xref: /aosp_15_r20/external/executorch/backends/apple/coreml/runtime/kvstore/database.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 //
2 //  database.cpp
3 //  kvstore
4 //
5 // Copyright © 2024 Apple Inc. All rights reserved.
6 //
7 // Please refer to the license found in the LICENSE file in the root directory of the source tree.
8 
9 #include <database.hpp>
10 
11 #include <sqlite_error.hpp>
12 
13 namespace {
14 using namespace executorchcoreml::sqlite;
15 
getPreparedStatement(sqlite3 * database,const std::string & statement,std::error_code & error)16 sqlite3_stmt* getPreparedStatement(sqlite3* database, const std::string& statement, std::error_code& error) {
17     sqlite3_stmt* handle = nullptr;
18     const int status = sqlite3_prepare_v2(database, statement.c_str(), -1, &handle, nullptr);
19     if (!process_sqlite_status(status, error)) {
20         return nullptr;
21     }
22 
23     return handle;
24 }
25 
26 /// Returns the sqlite pragma from
toString(Database::SynchronousMode mode)27 std::string toString(Database::SynchronousMode mode) {
28     switch (mode) {
29         case Database::SynchronousMode::Full: {
30             return "FULL";
31         }
32         case Database::SynchronousMode::Extra: {
33             return "EXTRA";
34         }
35         case Database::SynchronousMode::Normal: {
36             return "NORMAL";
37         }
38         case Database::SynchronousMode::Off: {
39             return "OFF";
40         }
41     }
42 }
43 
44 /// Returns the sqlite statement for a specified transaction behavior.
getTransactionStatement(Database::TransactionBehavior behavior)45 std::string getTransactionStatement(Database::TransactionBehavior behavior) {
46     switch (behavior) {
47         case Database::TransactionBehavior::Deferred: {
48             return "BEGIN DEFERRED";
49         }
50         case Database::TransactionBehavior::Immediate: {
51             return "BEGIN IMMEDIATE";
52         }
53         case Database::TransactionBehavior::Exclusive: {
54             return "BEGIN EXCLUSIVE";
55         }
56     }
57 }
58 } // namespace
59 
60 namespace executorchcoreml {
61 namespace sqlite {
62 
get_sqlite_flags() const63 int Database::OpenOptions::get_sqlite_flags() const noexcept {
64     int flags = 0;
65     if (is_read_only_option_enabled()) {
66         flags |= SQLITE_OPEN_READONLY;
67     }
68 
69     if (is_read_write_option_enabled()) {
70         flags |= SQLITE_OPEN_READWRITE;
71     }
72 
73     if (is_create_option_enabled()) {
74         flags |= SQLITE_OPEN_CREATE;
75     }
76 
77     if (is_memory_option_enabled()) {
78         flags |= SQLITE_OPEN_MEMORY;
79     }
80 
81     if (is_no_mutex_option_enabled()) {
82         flags |= SQLITE_OPEN_NOMUTEX;
83     }
84 
85     if (is_full_mutex_option_enabled()) {
86         flags |= SQLITE_OPEN_FULLMUTEX;
87     }
88 
89     if (is_shared_cache_option_enabled()) {
90         flags |= SQLITE_OPEN_SHAREDCACHE;
91     }
92 
93     if (is_shared_cache_option_enabled()) {
94         flags |= SQLITE_OPEN_SHAREDCACHE;
95     }
96 
97     if (is_uri_option_enabled()) {
98         flags |= SQLITE_OPEN_URI;
99     }
100 
101     return flags;
102 }
103 
open(OpenOptions options,SynchronousMode mode,int busy_timeout_ms,std::error_code & error)104 bool Database::open(OpenOptions options, SynchronousMode mode, int busy_timeout_ms, std::error_code& error) noexcept {
105     sqlite3* handle = nullptr;
106     const int status = sqlite3_open_v2(file_path_.c_str(), &handle, options.get_sqlite_flags(), nullptr);
107     sqlite_database_.reset(handle);
108     if (!process_sqlite_status(status, error)) {
109         return false;
110     }
111 
112     if (!set_busy_timeout(busy_timeout_ms, error)) {
113         return false;
114     }
115 
116     if (!execute("pragma journal_mode = WAL", error)) {
117         return false;
118     }
119 
120     if (!execute("pragma auto_vacuum = FULL", error)) {
121         return false;
122     }
123 
124     if (!execute("pragma synchronous = " + toString(mode), error)) {
125         return false;
126     }
127 
128     return true;
129 }
130 
is_open() const131 bool Database::is_open() const noexcept { return sqlite_database_ != nullptr; }
132 
table_exists(const std::string & tableName,std::error_code & error) const133 bool Database::table_exists(const std::string& tableName, std::error_code& error) const noexcept {
134     auto statement = prepare_statement("SELECT COUNT(*) FROM sqlite_master WHERE TYPE='table' AND NAME=?", error);
135     if (!statement) {
136         return false;
137     }
138 
139     if (!statement->bind(1, UnOwnedString(tableName), error)) {
140         return false;
141     }
142 
143     if (!statement->step(error)) {
144         return false;
145     }
146 
147     auto value = statement->get_column_value(0, error);
148     if (error) {
149         return false;
150     }
151 
152     return (std::get<int64_t>(value) == 1);
153 }
154 
drop_table(const std::string & tableName,std::error_code & error) const155 bool Database::drop_table(const std::string& tableName, std::error_code& error) const noexcept {
156     std::string statement = "DROP TABLE IF EXISTS " + tableName;
157     return execute(statement, error);
158 }
159 
get_row_count(const std::string & tableName,std::error_code & error) const160 int64_t Database::get_row_count(const std::string& tableName, std::error_code& error) const noexcept {
161     auto statement = prepare_statement("SELECT COUNT(*) FROM " + tableName, error);
162     if (!statement) {
163         return -1;
164     }
165 
166     if (!statement->step(error)) {
167         return -1;
168     }
169 
170     auto value = statement->get_column_value(0, error);
171     return std::get<int64_t>(value);
172 }
173 
set_busy_timeout(int busy_timeout_ms,std::error_code & error) const174 bool Database::set_busy_timeout(int busy_timeout_ms, std::error_code& error) const noexcept {
175     const int status = sqlite3_busy_timeout(get_underlying_database(), busy_timeout_ms);
176     return process_sqlite_status(status, error);
177 }
178 
execute(const std::string & statements,std::error_code & error) const179 bool Database::execute(const std::string& statements, std::error_code& error) const noexcept {
180     const int status = sqlite3_exec(get_underlying_database(), statements.c_str(), nullptr, nullptr, nullptr);
181     return process_sqlite_status(status, error);
182 }
183 
get_updated_row_count() const184 int Database::get_updated_row_count() const noexcept { return sqlite3_changes(get_underlying_database()); }
185 
get_last_error_message() const186 std::string Database::get_last_error_message() const noexcept { return sqlite3_errmsg(get_underlying_database()); }
187 
prepare_statement(const std::string & statement,std::error_code & error) const188 std::unique_ptr<PreparedStatement> Database::prepare_statement(const std::string& statement,
189                                                                std::error_code& error) const noexcept {
190     sqlite3_stmt* handle = getPreparedStatement(get_underlying_database(), statement, error);
191     return std::make_unique<PreparedStatement>(std::unique_ptr<sqlite3_stmt, StatementDeleter>(handle));
192 }
193 
get_last_inserted_row_id() const194 int64_t Database::get_last_inserted_row_id() const noexcept {
195     return sqlite3_last_insert_rowid(get_underlying_database());
196 }
197 
get_last_error_code() const198 std::error_code Database::get_last_error_code() const noexcept {
199     int code = sqlite3_errcode(get_underlying_database());
200     return static_cast<ErrorCode>(code);
201 }
202 
get_last_extended_error_code() const203 std::error_code Database::get_last_extended_error_code() const noexcept {
204     int code = sqlite3_extended_errcode(get_underlying_database());
205     return static_cast<ErrorCode>(code);
206 }
207 
begin_transaction(TransactionBehavior behavior,std::error_code & error) const208 bool Database::begin_transaction(TransactionBehavior behavior, std::error_code& error) const noexcept {
209     return execute(getTransactionStatement(behavior), error);
210 }
211 
commit_transaction(std::error_code & error) const212 bool Database::commit_transaction(std::error_code& error) const noexcept {
213     return execute("COMMIT TRANSACTION", error);
214 }
215 
rollback_transaction(std::error_code & error) const216 bool Database::rollback_transaction(std::error_code& error) const noexcept {
217     return execute("ROLLBACK TRANSACTION", error);
218 }
219 
transaction(const std::function<bool (void)> & fn,TransactionBehavior behavior,std::error_code & error)220 bool Database::transaction(const std::function<bool(void)>& fn,
221                            TransactionBehavior behavior,
222                            std::error_code& error) noexcept {
223     if (!begin_transaction(behavior, error)) {
224         return false;
225     }
226 
227     bool status = fn();
228     if (status) {
229         return commit_transaction(error);
230     } else {
231         rollback_transaction(error);
232         return false;
233     }
234 }
235 
make_inmemory(SynchronousMode mode,int busy_timeout_ms,std::error_code & error)236 std::shared_ptr<Database> Database::make_inmemory(SynchronousMode mode, int busy_timeout_ms, std::error_code& error) {
237     auto database = std::make_shared<Database>(":memory:");
238     OpenOptions options;
239     options.set_read_write_option(true);
240     if (database->open(options, mode, busy_timeout_ms, error)) {
241         return database;
242     }
243 
244     return nullptr;
245 }
246 
make(const std::string & filePath,OpenOptions options,SynchronousMode mode,int busy_timeout_ms,std::error_code & error)247 std::shared_ptr<Database> Database::make(const std::string& filePath,
248                                          OpenOptions options,
249                                          SynchronousMode mode,
250                                          int busy_timeout_ms,
251                                          std::error_code& error) {
252     auto database = std::make_shared<Database>(filePath);
253     if (database->open(options, mode, busy_timeout_ms, error)) {
254         return database;
255     }
256 
257     return nullptr;
258 }
259 
260 } // namespace sqlite
261 } // namespace executorchcoreml
262