1 //
2 // KeyValueStore.cpp
3 //
4 // Copyright © 2024 Apple Inc. All rights reserved.
5 //
6 // Please refer to the license found in the LICENSE file in the root directory of the source tree.
7
8 #include "key_value_store.hpp"
9
10 #include <iostream>
11 #include <sstream>
12
13 namespace {
14 using namespace executorchcoreml::sqlite;
15
// Column names used for the per-entry access count and access time.
constexpr std::string_view kAccessCountColumnName{"ENTRY_ACCESS_COUNT"};
constexpr std::string_view kAccessTimeColumnName{"ENTRY_ACCESS_TIME"};
18
to_string(StorageType storage_type)19 std::string to_string(StorageType storage_type) {
20 switch (storage_type) {
21 case StorageType::Text: {
22 return "TEXT";
23 }
24 case StorageType::Integer: {
25 return "INTEGER";
26 }
27 case StorageType::Double: {
28 return "REAL";
29 }
30 case StorageType::Blob: {
31 return "BLOB";
32 }
33 case StorageType::Null: {
34 return "NULL";
35 }
36 }
37 }
38
39 std::string
get_create_store_statement(std::string_view store_name,StorageType key_storage_type,StorageType value_storage_type)40 get_create_store_statement(std::string_view store_name, StorageType key_storage_type, StorageType value_storage_type) {
41 std::stringstream ss;
42 ss << "CREATE TABLE IF NOT EXISTS ";
43 ss << store_name << " ";
44 ss << "(";
45 ss << "ENTRY_KEY " << to_string(key_storage_type) << "PRIMARY KEY UNIQUE, ";
46 ss << "ENTRY_VALUE " << to_string(value_storage_type) << ", ";
47 ss << "ENTRY_ACCESS_COUNT " << to_string(StorageType::Integer) << ", ";
48 ss << "ENTRY_ACCESS_TIME " << to_string(StorageType::Integer);
49 ss << ")";
50
51 return ss.str();
52 }
53
// Builds a `CREATE INDEX IF NOT EXISTS` statement that indexes `column_name` of `store_name`.
// The index is named "<column_name>_INDEX".
// NOTE(review): the index name does not include the store name - confirm this is intended
// if more than one store can live in the same database.
std::string get_create_index_statement(std::string_view store_name, std::string_view column_name) {
    std::string statement{"CREATE INDEX IF NOT EXISTS "};
    statement.append(column_name);
    statement.append("_INDEX");
    statement.append(" ON ");
    statement.append(store_name);
    statement.append("(");
    statement.append(column_name);
    statement.append(")");
    return statement;
}
60
// Builds the parameterized `INSERT OR REPLACE` statement for a store.
// Placeholders, in order: key, value, access count, access time.
std::string get_insert_or_replace_statement(std::string_view store_name) {
    std::string statement{"INSERT OR REPLACE INTO "};
    statement += store_name;
    statement += "(ENTRY_KEY, ENTRY_VALUE, ENTRY_ACCESS_COUNT, ENTRY_ACCESS_TIME) VALUES (?, ?, ?, ?)";
    return statement;
}
68
// Builds the parameterized `DELETE` statement that removes the row whose ENTRY_KEY
// equals the single placeholder.
std::string get_remove_statement(std::string_view store_name) {
    std::string statement{"DELETE FROM "};
    statement.append(store_name);
    statement.append(" WHERE ENTRY_KEY = ?");
    return statement;
}
75
// Builds the parameterized `SELECT` statement that fetches the value (column 0) and the
// access count (column 1) of the row whose ENTRY_KEY equals the single placeholder.
std::string getQueryStatement(std::string_view store_name) {
    std::string statement{"SELECT ENTRY_VALUE, ENTRY_ACCESS_COUNT FROM "};
    statement += store_name;
    statement += " WHERE ENTRY_KEY = ?";
    return statement;
}
82
// Builds the parameterized `SELECT COUNT(*)` statement that counts the rows whose
// ENTRY_KEY equals the single placeholder.
std::string get_key_count_statement(std::string_view store_name) {
    std::string statement{"SELECT COUNT(*) FROM "};
    statement.append(store_name);
    statement.append(" WHERE ENTRY_KEY = ?");
    return statement;
}
89
// Builds the parameterized `UPDATE` statement that rewrites an entry's access statistics.
// Placeholders, in order: access count, access time, key.
std::string get_update_entry_access_statement(std::string_view store_name) {
    std::string statement{"UPDATE "};
    statement += store_name;
    statement += " SET ENTRY_ACCESS_COUNT = ?, ENTRY_ACCESS_TIME = ? WHERE ENTRY_KEY = ?";
    return statement;
}
96
to_string(SortOrder order)97 std::string to_string(SortOrder order) {
98 switch (order) {
99 case SortOrder::Ascending: {
100 return "ASC";
101 }
102 case SortOrder::Descending: {
103 return "DESC";
104 }
105 }
106 }
107
108 std::string
get_keys_sorted_by_column_statement(std::string_view storeName,std::string_view columnName,SortOrder order)109 get_keys_sorted_by_column_statement(std::string_view storeName, std::string_view columnName, SortOrder order) {
110 std::stringstream ss;
111 ss << "SELECT ENTRY_KEY, ENTRY_ACCESS_COUNT, ENTRY_ACCESS_TIME FROM " << storeName << " ORDER BY " << columnName
112 << " ";
113 ss << to_string(order);
114
115 return ss.str();
116 }
117
bind_value(PreparedStatement * statement,StorageType type,const Value & value,size_t index,std::error_code & error)118 bool bind_value(
119 PreparedStatement* statement, StorageType type, const Value& value, size_t index, std::error_code& error) {
120 switch (type) {
121 case StorageType::Text: {
122 return statement->bind(index, std::get<std::string>(value), error);
123 }
124 case StorageType::Integer: {
125 return statement->bind(index, std::get<int64_t>(value), error);
126 }
127 case StorageType::Double: {
128 return statement->bind(index, std::get<double>(value), error);
129 }
130 case StorageType::Blob: {
131 return statement->bind(index, std::get<Blob>(value).toUnOwned(), error);
132 }
133 default: {
134 return false;
135 }
136 }
137 }
138
execute(Database * database,const std::string & query,size_t columnIndex,const std::function<bool (const UnOwnedValue &)> & fn,std::error_code & error)139 bool execute(Database* database,
140 const std::string& query,
141 size_t columnIndex,
142 const std::function<bool(const UnOwnedValue&)>& fn,
143 std::error_code& error) {
144 auto statement = database->prepare_statement(query, error);
145 if (!statement) {
146 return false;
147 }
148
149 while (statement->step(error)) {
150 auto columnValue = statement->get_column_value_no_copy(columnIndex, error);
151 if (error || !fn(columnValue)) {
152 break;
153 }
154 }
155
156 return !(error.operator bool());
157 }
158
get_last_access_time(Database * database,std::string_view storeName,std::error_code & error)159 int64_t get_last_access_time(Database* database, std::string_view storeName, std::error_code& error) {
160 int64_t latestAccessTime = 0;
161 auto statement = get_keys_sorted_by_column_statement(storeName, kAccessTimeColumnName, SortOrder::Descending);
162 std::function<bool(const UnOwnedValue&)> fn = [&latestAccessTime](const UnOwnedValue& value) {
163 latestAccessTime = std::get<int64_t>(value);
164 return false;
165 };
166
167 return execute(database, statement, 1, fn, error);
168 }
169
170 } // namespace
171
172 namespace executorchcoreml {
173 namespace sqlite {
174
init(std::error_code & error)175 bool KeyValueStoreImpl::init(std::error_code& error) noexcept {
176 if (!database_->execute(get_create_store_statement(name_, get_key_storage_type_, get_value_storage_type_), error)) {
177 return false;
178 }
179
180 if (!database_->execute(get_create_index_statement(name_, kAccessCountColumnName), error)) {
181 return false;
182 }
183
184 if (!database_->execute(get_create_index_statement(name_, kAccessTimeColumnName), error)) {
185 return false;
186 }
187
188 int64_t lastAccessTime = get_last_access_time(database_.get(), name_, error);
189 if (error) {
190 return false;
191 }
192
193 lastAccessTime_.store(lastAccessTime, std::memory_order_seq_cst);
194 return true;
195 }
196
exists(const Value & key,std::error_code & error)197 bool KeyValueStoreImpl::exists(const Value& key, std::error_code& error) noexcept {
198 if (error) {
199 return false;
200 }
201
202 auto query = database_->prepare_statement(get_key_count_statement(name_), error);
203 if (!query) {
204 return false;
205 }
206
207 if (!bind_value(query.get(), get_key_storage_type(), key, 1, error)) {
208 return false;
209 }
210
211 if (!query->step(error)) {
212 return false;
213 }
214
215 return std::get<int64_t>(query->get_column_value(0, error)) > 0;
216 }
217
updateValueAccessCountAndTime(const Value & key,int64_t accessCount,std::error_code & error)218 bool KeyValueStoreImpl::updateValueAccessCountAndTime(const Value& key,
219 int64_t accessCount,
220 std::error_code& error) noexcept {
221 auto update = database_->prepare_statement(get_update_entry_access_statement(name_), error);
222 if (!update) {
223 return false;
224 }
225
226 if (!bind_value(update.get(), StorageType::Integer, accessCount + 1, 1, error)) {
227 return false;
228 }
229
230 if (!bind_value(update.get(), StorageType::Integer, lastAccessTime_, 2, error)) {
231 return false;
232 }
233
234 if (!bind_value(update.get(), get_key_storage_type(), key, 3, error)) {
235 return false;
236 }
237
238 bool result = update->execute(error);
239 if (result) {
240 lastAccessTime_ += 1;
241 }
242 return result;
243 }
244
get(const Value & key,const std::function<void (const UnOwnedValue &)> & fn,std::error_code & error,bool updateAccessStatistics)245 bool KeyValueStoreImpl::get(const Value& key,
246 const std::function<void(const UnOwnedValue&)>& fn,
247 std::error_code& error,
248 bool updateAccessStatistics) noexcept {
249 auto query = database_->prepare_statement(getQueryStatement(name_), error);
250 if (!query) {
251 return false;
252 }
253
254 if (!bind_value(query.get(), get_key_storage_type(), key, 1, error)) {
255 return false;
256 }
257
258 if (!query->step(error)) {
259 return false;
260 }
261
262 auto value = query->get_column_value_no_copy(0, error);
263 fn(value);
264
265 if (updateAccessStatistics) {
266 int64_t accessCount = std::get<int64_t>(query->get_column_value(1, error));
267 return updateValueAccessCountAndTime(key, accessCount, error);
268 }
269
270 return true;
271 }
272
put(const Value & key,const Value & value,std::error_code & error)273 bool KeyValueStoreImpl::put(const Value& key, const Value& value, std::error_code& error) noexcept {
274 auto statement = database_->prepare_statement(get_insert_or_replace_statement(name_), error);
275 if (!statement) {
276 return false;
277 }
278
279 if (!bind_value(statement.get(), get_key_storage_type(), key, 1, error)) {
280 return false;
281 }
282
283 if (!bind_value(statement.get(), get_value_storage_type(), value, 2, error)) {
284 return false;
285 }
286
287 if (!bind_value(statement.get(), StorageType::Integer, int64_t(1), 3, error)) {
288 return false;
289 }
290
291 if (!bind_value(statement.get(), StorageType::Integer, lastAccessTime_.load(std::memory_order_acquire), 4, error)) {
292 return false;
293 }
294
295 lastAccessTime_ += 1;
296 return statement->execute(error);
297 }
298
remove(const Value & key,std::error_code & error)299 bool KeyValueStoreImpl::remove(const Value& key, std::error_code& error) noexcept {
300 auto statement = database_->prepare_statement(get_remove_statement(name_), error);
301 if (!bind_value(statement.get(), get_key_storage_type(), key, 1, error)) {
302 return false;
303 }
304
305 return statement->execute(error);
306 }
307
get_keys_sorted_by_access_count(const std::function<bool (const UnOwnedValue &)> & fn,SortOrder order,std::error_code & error)308 bool KeyValueStoreImpl::get_keys_sorted_by_access_count(const std::function<bool(const UnOwnedValue&)>& fn,
309 SortOrder order,
310 std::error_code& error) noexcept {
311 auto statement = get_keys_sorted_by_column_statement(name(), kAccessCountColumnName, order);
312 return execute(database_.get(), statement, 0, fn, error);
313 }
314
get_keys_sorted_by_access_time(const std::function<bool (const UnOwnedValue &)> & fn,SortOrder order,std::error_code & error)315 bool KeyValueStoreImpl::get_keys_sorted_by_access_time(const std::function<bool(const UnOwnedValue&)>& fn,
316 SortOrder order,
317 std::error_code& error) noexcept {
318 auto statement = get_keys_sorted_by_column_statement(name(), kAccessTimeColumnName, order);
319 return execute(database_.get(), statement, 0, fn, error);
320 }
321
size(std::error_code & error)322 std::optional<size_t> KeyValueStoreImpl::size(std::error_code& error) noexcept {
323 int64_t count = database_->get_row_count(name_, error);
324 return count < 0 ? std::nullopt : std::optional<size_t>(count);
325 }
326
purge(std::error_code & error)327 bool KeyValueStoreImpl::purge(std::error_code& error) noexcept {
328 if (!database_->drop_table(name_, error)) {
329 return false;
330 }
331
332 return init(error);
333 }
334
335 } // namespace sqlite
336 } // namespace executorchcoreml
337