xref: /aosp_15_r20/external/executorch/backends/apple/coreml/runtime/kvstore/key_value_store.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 //
2 // KeyValueStore.cpp
3 //
4 // Copyright © 2024 Apple Inc. All rights reserved.
5 //
6 // Please refer to the license found in the LICENSE file in the root directory of the source tree.
7 
8 #include "key_value_store.hpp"
9 
10 #include <iostream>
11 #include <sstream>
12 
13 namespace {
14 using namespace executorchcoreml::sqlite;
15 
16 constexpr std::string_view kAccessCountColumnName = "ENTRY_ACCESS_COUNT";
17 constexpr std::string_view kAccessTimeColumnName = "ENTRY_ACCESS_TIME";
18 
to_string(StorageType storage_type)19 std::string to_string(StorageType storage_type) {
20     switch (storage_type) {
21         case StorageType::Text: {
22             return "TEXT";
23         }
24         case StorageType::Integer: {
25             return "INTEGER";
26         }
27         case StorageType::Double: {
28             return "REAL";
29         }
30         case StorageType::Blob: {
31             return "BLOB";
32         }
33         case StorageType::Null: {
34             return "NULL";
35         }
36     }
37 }
38 
39 std::string
get_create_store_statement(std::string_view store_name,StorageType key_storage_type,StorageType value_storage_type)40 get_create_store_statement(std::string_view store_name, StorageType key_storage_type, StorageType value_storage_type) {
41     std::stringstream ss;
42     ss << "CREATE TABLE IF NOT EXISTS ";
43     ss << store_name << " ";
44     ss << "(";
45     ss << "ENTRY_KEY " << to_string(key_storage_type) << "PRIMARY KEY UNIQUE, ";
46     ss << "ENTRY_VALUE " << to_string(value_storage_type) << ", ";
47     ss << "ENTRY_ACCESS_COUNT " << to_string(StorageType::Integer) << ", ";
48     ss << "ENTRY_ACCESS_TIME " << to_string(StorageType::Integer);
49     ss << ")";
50 
51     return ss.str();
52 }
53 
/// Builds a `CREATE INDEX IF NOT EXISTS` statement named `<column>_INDEX` over a
/// single column of the store table. Both names are interpolated verbatim.
std::string get_create_index_statement(std::string_view store_name, std::string_view column_name) {
    std::string sql("CREATE INDEX IF NOT EXISTS ");
    sql.append(column_name)
        .append("_INDEX ON ")
        .append(store_name)
        .append("(")
        .append(column_name)
        .append(")");
    return sql;
}
60 
/// Builds the upsert statement with four positional parameters:
/// (1) key, (2) value, (3) access count, (4) access time.
std::string get_insert_or_replace_statement(std::string_view store_name) {
    constexpr std::string_view kPrefix = "INSERT OR REPLACE INTO ";
    constexpr std::string_view kSuffix =
        "(ENTRY_KEY, ENTRY_VALUE, ENTRY_ACCESS_COUNT, ENTRY_ACCESS_TIME) VALUES (?, ?, ?, ?)";

    std::string sql;
    sql.reserve(kPrefix.size() + store_name.size() + kSuffix.size());
    sql.append(kPrefix).append(store_name).append(kSuffix);
    return sql;
}
68 
/// Builds the statement that deletes the row whose key is bound to parameter 1.
std::string get_remove_statement(std::string_view store_name) {
    std::string sql = "DELETE FROM ";
    sql += store_name;
    sql += " WHERE ENTRY_KEY = ?";
    return sql;
}
75 
/// Builds the lookup statement: for the key bound to parameter 1 it selects
/// ENTRY_VALUE (column 0) and ENTRY_ACCESS_COUNT (column 1).
std::string getQueryStatement(std::string_view store_name) {
    const std::string_view head = "SELECT ENTRY_VALUE, ENTRY_ACCESS_COUNT FROM ";
    const std::string_view tail = " WHERE ENTRY_KEY = ?";
    std::string sql(head);
    sql.append(store_name);
    sql.append(tail);
    return sql;
}
82 
/// Builds the statement that counts rows whose key equals parameter 1
/// (used to test key existence).
std::string get_key_count_statement(std::string_view store_name) {
    std::string sql;
    sql.append("SELECT COUNT(*) FROM ");
    sql.append(store_name.data(), store_name.size());
    sql.append(" WHERE ENTRY_KEY = ?");
    return sql;
}
89 
/// Builds the statement that updates the access statistics of one entry.
/// Parameters: (1) new access count, (2) new access time, (3) key.
std::string get_update_entry_access_statement(std::string_view store_name) {
    std::string sql{"UPDATE "};
    sql += store_name;
    sql += " SET ENTRY_ACCESS_COUNT = ?, ENTRY_ACCESS_TIME = ?";
    sql += " WHERE ENTRY_KEY = ?";
    return sql;
}
96 
to_string(SortOrder order)97 std::string to_string(SortOrder order) {
98     switch (order) {
99         case SortOrder::Ascending: {
100             return "ASC";
101         }
102         case SortOrder::Descending: {
103             return "DESC";
104         }
105     }
106 }
107 
108 std::string
get_keys_sorted_by_column_statement(std::string_view storeName,std::string_view columnName,SortOrder order)109 get_keys_sorted_by_column_statement(std::string_view storeName, std::string_view columnName, SortOrder order) {
110     std::stringstream ss;
111     ss << "SELECT ENTRY_KEY, ENTRY_ACCESS_COUNT, ENTRY_ACCESS_TIME FROM " << storeName << " ORDER BY " << columnName
112        << " ";
113     ss << to_string(order);
114 
115     return ss.str();
116 }
117 
bind_value(PreparedStatement * statement,StorageType type,const Value & value,size_t index,std::error_code & error)118 bool bind_value(
119     PreparedStatement* statement, StorageType type, const Value& value, size_t index, std::error_code& error) {
120     switch (type) {
121         case StorageType::Text: {
122             return statement->bind(index, std::get<std::string>(value), error);
123         }
124         case StorageType::Integer: {
125             return statement->bind(index, std::get<int64_t>(value), error);
126         }
127         case StorageType::Double: {
128             return statement->bind(index, std::get<double>(value), error);
129         }
130         case StorageType::Blob: {
131             return statement->bind(index, std::get<Blob>(value).toUnOwned(), error);
132         }
133         default: {
134             return false;
135         }
136     }
137 }
138 
execute(Database * database,const std::string & query,size_t columnIndex,const std::function<bool (const UnOwnedValue &)> & fn,std::error_code & error)139 bool execute(Database* database,
140              const std::string& query,
141              size_t columnIndex,
142              const std::function<bool(const UnOwnedValue&)>& fn,
143              std::error_code& error) {
144     auto statement = database->prepare_statement(query, error);
145     if (!statement) {
146         return false;
147     }
148 
149     while (statement->step(error)) {
150         auto columnValue = statement->get_column_value_no_copy(columnIndex, error);
151         if (error || !fn(columnValue)) {
152             break;
153         }
154     }
155 
156     return !(error.operator bool());
157 }
158 
get_last_access_time(Database * database,std::string_view storeName,std::error_code & error)159 int64_t get_last_access_time(Database* database, std::string_view storeName, std::error_code& error) {
160     int64_t latestAccessTime = 0;
161     auto statement = get_keys_sorted_by_column_statement(storeName, kAccessTimeColumnName, SortOrder::Descending);
162     std::function<bool(const UnOwnedValue&)> fn = [&latestAccessTime](const UnOwnedValue& value) {
163         latestAccessTime = std::get<int64_t>(value);
164         return false;
165     };
166 
167     return execute(database, statement, 1, fn, error);
168 }
169 
170 } // namespace
171 
172 namespace executorchcoreml {
173 namespace sqlite {
174 
init(std::error_code & error)175 bool KeyValueStoreImpl::init(std::error_code& error) noexcept {
176     if (!database_->execute(get_create_store_statement(name_, get_key_storage_type_, get_value_storage_type_), error)) {
177         return false;
178     }
179 
180     if (!database_->execute(get_create_index_statement(name_, kAccessCountColumnName), error)) {
181         return false;
182     }
183 
184     if (!database_->execute(get_create_index_statement(name_, kAccessTimeColumnName), error)) {
185         return false;
186     }
187 
188     int64_t lastAccessTime = get_last_access_time(database_.get(), name_, error);
189     if (error) {
190         return false;
191     }
192 
193     lastAccessTime_.store(lastAccessTime, std::memory_order_seq_cst);
194     return true;
195 }
196 
exists(const Value & key,std::error_code & error)197 bool KeyValueStoreImpl::exists(const Value& key, std::error_code& error) noexcept {
198     if (error) {
199         return false;
200     }
201 
202     auto query = database_->prepare_statement(get_key_count_statement(name_), error);
203     if (!query) {
204         return false;
205     }
206 
207     if (!bind_value(query.get(), get_key_storage_type(), key, 1, error)) {
208         return false;
209     }
210 
211     if (!query->step(error)) {
212         return false;
213     }
214 
215     return std::get<int64_t>(query->get_column_value(0, error)) > 0;
216 }
217 
updateValueAccessCountAndTime(const Value & key,int64_t accessCount,std::error_code & error)218 bool KeyValueStoreImpl::updateValueAccessCountAndTime(const Value& key,
219                                                       int64_t accessCount,
220                                                       std::error_code& error) noexcept {
221     auto update = database_->prepare_statement(get_update_entry_access_statement(name_), error);
222     if (!update) {
223         return false;
224     }
225 
226     if (!bind_value(update.get(), StorageType::Integer, accessCount + 1, 1, error)) {
227         return false;
228     }
229 
230     if (!bind_value(update.get(), StorageType::Integer, lastAccessTime_, 2, error)) {
231         return false;
232     }
233 
234     if (!bind_value(update.get(), get_key_storage_type(), key, 3, error)) {
235         return false;
236     }
237 
238     bool result = update->execute(error);
239     if (result) {
240         lastAccessTime_ += 1;
241     }
242     return result;
243 }
244 
get(const Value & key,const std::function<void (const UnOwnedValue &)> & fn,std::error_code & error,bool updateAccessStatistics)245 bool KeyValueStoreImpl::get(const Value& key,
246                             const std::function<void(const UnOwnedValue&)>& fn,
247                             std::error_code& error,
248                             bool updateAccessStatistics) noexcept {
249     auto query = database_->prepare_statement(getQueryStatement(name_), error);
250     if (!query) {
251         return false;
252     }
253 
254     if (!bind_value(query.get(), get_key_storage_type(), key, 1, error)) {
255         return false;
256     }
257 
258     if (!query->step(error)) {
259         return false;
260     }
261 
262     auto value = query->get_column_value_no_copy(0, error);
263     fn(value);
264 
265     if (updateAccessStatistics) {
266         int64_t accessCount = std::get<int64_t>(query->get_column_value(1, error));
267         return updateValueAccessCountAndTime(key, accessCount, error);
268     }
269 
270     return true;
271 }
272 
put(const Value & key,const Value & value,std::error_code & error)273 bool KeyValueStoreImpl::put(const Value& key, const Value& value, std::error_code& error) noexcept {
274     auto statement = database_->prepare_statement(get_insert_or_replace_statement(name_), error);
275     if (!statement) {
276         return false;
277     }
278 
279     if (!bind_value(statement.get(), get_key_storage_type(), key, 1, error)) {
280         return false;
281     }
282 
283     if (!bind_value(statement.get(), get_value_storage_type(), value, 2, error)) {
284         return false;
285     }
286 
287     if (!bind_value(statement.get(), StorageType::Integer, int64_t(1), 3, error)) {
288         return false;
289     }
290 
291     if (!bind_value(statement.get(), StorageType::Integer, lastAccessTime_.load(std::memory_order_acquire), 4, error)) {
292         return false;
293     }
294 
295     lastAccessTime_ += 1;
296     return statement->execute(error);
297 }
298 
remove(const Value & key,std::error_code & error)299 bool KeyValueStoreImpl::remove(const Value& key, std::error_code& error) noexcept {
300     auto statement = database_->prepare_statement(get_remove_statement(name_), error);
301     if (!bind_value(statement.get(), get_key_storage_type(), key, 1, error)) {
302         return false;
303     }
304 
305     return statement->execute(error);
306 }
307 
get_keys_sorted_by_access_count(const std::function<bool (const UnOwnedValue &)> & fn,SortOrder order,std::error_code & error)308 bool KeyValueStoreImpl::get_keys_sorted_by_access_count(const std::function<bool(const UnOwnedValue&)>& fn,
309                                                         SortOrder order,
310                                                         std::error_code& error) noexcept {
311     auto statement = get_keys_sorted_by_column_statement(name(), kAccessCountColumnName, order);
312     return execute(database_.get(), statement, 0, fn, error);
313 }
314 
get_keys_sorted_by_access_time(const std::function<bool (const UnOwnedValue &)> & fn,SortOrder order,std::error_code & error)315 bool KeyValueStoreImpl::get_keys_sorted_by_access_time(const std::function<bool(const UnOwnedValue&)>& fn,
316                                                        SortOrder order,
317                                                        std::error_code& error) noexcept {
318     auto statement = get_keys_sorted_by_column_statement(name(), kAccessTimeColumnName, order);
319     return execute(database_.get(), statement, 0, fn, error);
320 }
321 
size(std::error_code & error)322 std::optional<size_t> KeyValueStoreImpl::size(std::error_code& error) noexcept {
323     int64_t count = database_->get_row_count(name_, error);
324     return count < 0 ? std::nullopt : std::optional<size_t>(count);
325 }
326 
purge(std::error_code & error)327 bool KeyValueStoreImpl::purge(std::error_code& error) noexcept {
328     if (!database_->drop_table(name_, error)) {
329         return false;
330     }
331 
332     return init(error);
333 }
334 
335 } // namespace sqlite
336 } // namespace executorchcoreml
337