1 //
2 // database.cpp
3 // kvstore
4 //
5 // Copyright © 2024 Apple Inc. All rights reserved.
6 //
7 // Please refer to the license found in the LICENSE file in the root directory of the source tree.
8
9 #include <database.hpp>
10
11 #include <sqlite_error.hpp>
12
13 namespace {
14 using namespace executorchcoreml::sqlite;
15
getPreparedStatement(sqlite3 * database,const std::string & statement,std::error_code & error)16 sqlite3_stmt* getPreparedStatement(sqlite3* database, const std::string& statement, std::error_code& error) {
17 sqlite3_stmt* handle = nullptr;
18 const int status = sqlite3_prepare_v2(database, statement.c_str(), -1, &handle, nullptr);
19 if (!process_sqlite_status(status, error)) {
20 return nullptr;
21 }
22
23 return handle;
24 }
25
26 /// Returns the sqlite pragma from
toString(Database::SynchronousMode mode)27 std::string toString(Database::SynchronousMode mode) {
28 switch (mode) {
29 case Database::SynchronousMode::Full: {
30 return "FULL";
31 }
32 case Database::SynchronousMode::Extra: {
33 return "EXTRA";
34 }
35 case Database::SynchronousMode::Normal: {
36 return "NORMAL";
37 }
38 case Database::SynchronousMode::Off: {
39 return "OFF";
40 }
41 }
42 }
43
44 /// Returns the sqlite statement for a specified transaction behavior.
getTransactionStatement(Database::TransactionBehavior behavior)45 std::string getTransactionStatement(Database::TransactionBehavior behavior) {
46 switch (behavior) {
47 case Database::TransactionBehavior::Deferred: {
48 return "BEGIN DEFERRED";
49 }
50 case Database::TransactionBehavior::Immediate: {
51 return "BEGIN IMMEDIATE";
52 }
53 case Database::TransactionBehavior::Exclusive: {
54 return "BEGIN EXCLUSIVE";
55 }
56 }
57 }
58 } // namespace
59
60 namespace executorchcoreml {
61 namespace sqlite {
62
get_sqlite_flags() const63 int Database::OpenOptions::get_sqlite_flags() const noexcept {
64 int flags = 0;
65 if (is_read_only_option_enabled()) {
66 flags |= SQLITE_OPEN_READONLY;
67 }
68
69 if (is_read_write_option_enabled()) {
70 flags |= SQLITE_OPEN_READWRITE;
71 }
72
73 if (is_create_option_enabled()) {
74 flags |= SQLITE_OPEN_CREATE;
75 }
76
77 if (is_memory_option_enabled()) {
78 flags |= SQLITE_OPEN_MEMORY;
79 }
80
81 if (is_no_mutex_option_enabled()) {
82 flags |= SQLITE_OPEN_NOMUTEX;
83 }
84
85 if (is_full_mutex_option_enabled()) {
86 flags |= SQLITE_OPEN_FULLMUTEX;
87 }
88
89 if (is_shared_cache_option_enabled()) {
90 flags |= SQLITE_OPEN_SHAREDCACHE;
91 }
92
93 if (is_shared_cache_option_enabled()) {
94 flags |= SQLITE_OPEN_SHAREDCACHE;
95 }
96
97 if (is_uri_option_enabled()) {
98 flags |= SQLITE_OPEN_URI;
99 }
100
101 return flags;
102 }
103
open(OpenOptions options,SynchronousMode mode,int busy_timeout_ms,std::error_code & error)104 bool Database::open(OpenOptions options, SynchronousMode mode, int busy_timeout_ms, std::error_code& error) noexcept {
105 sqlite3* handle = nullptr;
106 const int status = sqlite3_open_v2(file_path_.c_str(), &handle, options.get_sqlite_flags(), nullptr);
107 sqlite_database_.reset(handle);
108 if (!process_sqlite_status(status, error)) {
109 return false;
110 }
111
112 if (!set_busy_timeout(busy_timeout_ms, error)) {
113 return false;
114 }
115
116 if (!execute("pragma journal_mode = WAL", error)) {
117 return false;
118 }
119
120 if (!execute("pragma auto_vacuum = FULL", error)) {
121 return false;
122 }
123
124 if (!execute("pragma synchronous = " + toString(mode), error)) {
125 return false;
126 }
127
128 return true;
129 }
130
is_open() const131 bool Database::is_open() const noexcept { return sqlite_database_ != nullptr; }
132
table_exists(const std::string & tableName,std::error_code & error) const133 bool Database::table_exists(const std::string& tableName, std::error_code& error) const noexcept {
134 auto statement = prepare_statement("SELECT COUNT(*) FROM sqlite_master WHERE TYPE='table' AND NAME=?", error);
135 if (!statement) {
136 return false;
137 }
138
139 if (!statement->bind(1, UnOwnedString(tableName), error)) {
140 return false;
141 }
142
143 if (!statement->step(error)) {
144 return false;
145 }
146
147 auto value = statement->get_column_value(0, error);
148 if (error) {
149 return false;
150 }
151
152 return (std::get<int64_t>(value) == 1);
153 }
154
drop_table(const std::string & tableName,std::error_code & error) const155 bool Database::drop_table(const std::string& tableName, std::error_code& error) const noexcept {
156 std::string statement = "DROP TABLE IF EXISTS " + tableName;
157 return execute(statement, error);
158 }
159
get_row_count(const std::string & tableName,std::error_code & error) const160 int64_t Database::get_row_count(const std::string& tableName, std::error_code& error) const noexcept {
161 auto statement = prepare_statement("SELECT COUNT(*) FROM " + tableName, error);
162 if (!statement) {
163 return -1;
164 }
165
166 if (!statement->step(error)) {
167 return -1;
168 }
169
170 auto value = statement->get_column_value(0, error);
171 return std::get<int64_t>(value);
172 }
173
set_busy_timeout(int busy_timeout_ms,std::error_code & error) const174 bool Database::set_busy_timeout(int busy_timeout_ms, std::error_code& error) const noexcept {
175 const int status = sqlite3_busy_timeout(get_underlying_database(), busy_timeout_ms);
176 return process_sqlite_status(status, error);
177 }
178
execute(const std::string & statements,std::error_code & error) const179 bool Database::execute(const std::string& statements, std::error_code& error) const noexcept {
180 const int status = sqlite3_exec(get_underlying_database(), statements.c_str(), nullptr, nullptr, nullptr);
181 return process_sqlite_status(status, error);
182 }
183
get_updated_row_count() const184 int Database::get_updated_row_count() const noexcept { return sqlite3_changes(get_underlying_database()); }
185
get_last_error_message() const186 std::string Database::get_last_error_message() const noexcept { return sqlite3_errmsg(get_underlying_database()); }
187
prepare_statement(const std::string & statement,std::error_code & error) const188 std::unique_ptr<PreparedStatement> Database::prepare_statement(const std::string& statement,
189 std::error_code& error) const noexcept {
190 sqlite3_stmt* handle = getPreparedStatement(get_underlying_database(), statement, error);
191 return std::make_unique<PreparedStatement>(std::unique_ptr<sqlite3_stmt, StatementDeleter>(handle));
192 }
193
get_last_inserted_row_id() const194 int64_t Database::get_last_inserted_row_id() const noexcept {
195 return sqlite3_last_insert_rowid(get_underlying_database());
196 }
197
get_last_error_code() const198 std::error_code Database::get_last_error_code() const noexcept {
199 int code = sqlite3_errcode(get_underlying_database());
200 return static_cast<ErrorCode>(code);
201 }
202
get_last_extended_error_code() const203 std::error_code Database::get_last_extended_error_code() const noexcept {
204 int code = sqlite3_extended_errcode(get_underlying_database());
205 return static_cast<ErrorCode>(code);
206 }
207
begin_transaction(TransactionBehavior behavior,std::error_code & error) const208 bool Database::begin_transaction(TransactionBehavior behavior, std::error_code& error) const noexcept {
209 return execute(getTransactionStatement(behavior), error);
210 }
211
commit_transaction(std::error_code & error) const212 bool Database::commit_transaction(std::error_code& error) const noexcept {
213 return execute("COMMIT TRANSACTION", error);
214 }
215
rollback_transaction(std::error_code & error) const216 bool Database::rollback_transaction(std::error_code& error) const noexcept {
217 return execute("ROLLBACK TRANSACTION", error);
218 }
219
transaction(const std::function<bool (void)> & fn,TransactionBehavior behavior,std::error_code & error)220 bool Database::transaction(const std::function<bool(void)>& fn,
221 TransactionBehavior behavior,
222 std::error_code& error) noexcept {
223 if (!begin_transaction(behavior, error)) {
224 return false;
225 }
226
227 bool status = fn();
228 if (status) {
229 return commit_transaction(error);
230 } else {
231 rollback_transaction(error);
232 return false;
233 }
234 }
235
make_inmemory(SynchronousMode mode,int busy_timeout_ms,std::error_code & error)236 std::shared_ptr<Database> Database::make_inmemory(SynchronousMode mode, int busy_timeout_ms, std::error_code& error) {
237 auto database = std::make_shared<Database>(":memory:");
238 OpenOptions options;
239 options.set_read_write_option(true);
240 if (database->open(options, mode, busy_timeout_ms, error)) {
241 return database;
242 }
243
244 return nullptr;
245 }
246
make(const std::string & filePath,OpenOptions options,SynchronousMode mode,int busy_timeout_ms,std::error_code & error)247 std::shared_ptr<Database> Database::make(const std::string& filePath,
248 OpenOptions options,
249 SynchronousMode mode,
250 int busy_timeout_ms,
251 std::error_code& error) {
252 auto database = std::make_shared<Database>(filePath);
253 if (database->open(options, mode, busy_timeout_ms, error)) {
254 return database;
255 }
256
257 return nullptr;
258 }
259
260 } // namespace sqlite
261 } // namespace executorchcoreml
262