xref: /aosp_15_r20/external/executorch/backends/apple/coreml/runtime/kvstore/statement.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 //
2 // Statement.cpp
3 //
4 // Copyright © 2024 Apple Inc. All rights reserved.
5 //
6 // Please refer to the license found in the LICENSE file in the root directory of the source tree.
7 
8 #include "statement.hpp"
9 
10 #include <string_view>
11 
12 #include <sqlite_error.hpp>
13 
14 namespace {
15 
16 using namespace executorchcoreml::sqlite;
17 
get_column_count(sqlite3_stmt * stmt)18 size_t get_column_count(sqlite3_stmt* stmt) { return static_cast<size_t>(sqlite3_column_count(stmt)); }
19 
get_column_names(sqlite3_stmt * stmt,size_t column_count)20 std::vector<std::string> get_column_names(sqlite3_stmt* stmt, size_t column_count) {
21     std::vector<std::string> result;
22     result.reserve(static_cast<size_t>(column_count));
23     for (int i = 0; i < column_count; i++) {
24         const char* name = sqlite3_column_name(stmt, i);
25         if (!name) {
26             return {};
27         }
28         result.emplace_back(name);
29     }
30 
31     return result;
32 }
33 
get_parameter_index(sqlite3_stmt * stmt,const std::string & name)34 int get_parameter_index(sqlite3_stmt* stmt, const std::string& name) {
35     return sqlite3_bind_parameter_index(stmt, name.c_str());
36 }
37 
bind_unowned_string(sqlite3_stmt * stmt,size_t index,UnOwnedString value,bool copy,std::error_code & error)38 bool bind_unowned_string(sqlite3_stmt* stmt, size_t index, UnOwnedString value, bool copy, std::error_code& error) {
39     if (error) {
40         return false;
41     }
42     auto destructor = copy ? SQLITE_TRANSIENT : SQLITE_STATIC;
43     const int status =
44         sqlite3_bind_text(stmt, static_cast<int>(index), value.data, static_cast<int>(value.size), destructor);
45     return process_sqlite_status(status, error);
46 }
47 
bind_blob(sqlite3_stmt * stmt,size_t index,const UnOwnedBlob & value,bool copy,std::error_code & error)48 bool bind_blob(sqlite3_stmt* stmt, size_t index, const UnOwnedBlob& value, bool copy, std::error_code& error) {
49     if (error) {
50         return false;
51     }
52     auto destructor = copy ? SQLITE_TRANSIENT : SQLITE_STATIC;
53     const int status =
54         sqlite3_bind_blob(stmt, static_cast<int>(index), value.data, static_cast<int>(value.size), destructor);
55     return process_sqlite_status(status, error);
56 }
57 
get_column_storage_type(sqlite3_stmt * stmt,int index)58 StorageType get_column_storage_type(sqlite3_stmt* stmt, int index) {
59     switch (sqlite3_column_type(stmt, index)) {
60         case SQLITE_INTEGER: {
61             return StorageType::Integer;
62         }
63         case SQLITE_FLOAT: {
64             return StorageType::Double;
65         }
66         case SQLITE_TEXT: {
67             return StorageType::Text;
68         }
69         case SQLITE_BLOB: {
70             return StorageType::Blob;
71         }
72         case SQLITE_NULL: {
73             return StorageType::Null;
74         }
75         default: {
76             return StorageType::Null;
77         }
78     }
79 }
80 
get_column_storage_types(sqlite3_stmt * stmt,size_t columnCount)81 std::vector<StorageType> get_column_storage_types(sqlite3_stmt* stmt, size_t columnCount) {
82     std::vector<StorageType> result;
83     result.reserve(static_cast<size_t>(columnCount));
84     for (int i = 0; i < columnCount; i++) {
85         result.emplace_back(get_column_storage_type(stmt, i));
86     }
87 
88     return result;
89 }
90 
get_int64_value(sqlite3_stmt * stmt,size_t index)91 int64_t get_int64_value(sqlite3_stmt* stmt, size_t index) {
92     return sqlite3_column_int64(stmt, static_cast<int>(index));
93 }
94 
get_double_value(sqlite3_stmt * stmt,size_t index)95 int64_t get_double_value(sqlite3_stmt* stmt, size_t index) {
96     return sqlite3_column_double(stmt, static_cast<int>(index));
97 }
98 
get_string_value(sqlite3_stmt * stmt,size_t index)99 std::string get_string_value(sqlite3_stmt* stmt, size_t index) {
100     auto data = static_cast<const char*>(sqlite3_column_blob(stmt, static_cast<int>(index)));
101     return std::string(data, sqlite3_column_bytes(stmt, static_cast<int>(index)));
102 }
103 
get_unowned_string_value(sqlite3_stmt * stmt,size_t index)104 UnOwnedString get_unowned_string_value(sqlite3_stmt* stmt, size_t index) {
105     auto data = static_cast<const char*>(sqlite3_column_blob(stmt, static_cast<int>(index)));
106     return UnOwnedString(data, sqlite3_column_bytes(stmt, static_cast<int>(index)));
107 }
108 
get_stored_blob_value(sqlite3_stmt * stmt,size_t index)109 std::pair<const void*, size_t> get_stored_blob_value(sqlite3_stmt* stmt, size_t index) {
110     const void* data = sqlite3_column_blob(stmt, static_cast<int>(index));
111     int n = sqlite3_column_bytes(stmt, static_cast<int>(index));
112     return { data, static_cast<size_t>(n) };
113 }
114 
get_blob_value(sqlite3_stmt * stmt,size_t index)115 Blob get_blob_value(sqlite3_stmt* stmt, size_t index) {
116     const auto& pair = get_stored_blob_value(stmt, index);
117     return Blob(pair.first, pair.second);
118 }
119 
get_unowned_blob_value(sqlite3_stmt * stmt,size_t index)120 UnOwnedBlob get_unowned_blob_value(sqlite3_stmt* stmt, size_t index) {
121     const auto& pair = get_stored_blob_value(stmt, index);
122     return UnOwnedBlob(pair.first, pair.second);
123 }
124 } // namespace
125 
126 namespace executorchcoreml {
127 namespace sqlite {
128 
PreparedStatement(std::unique_ptr<sqlite3_stmt,StatementDeleter> prepared_statement)129 PreparedStatement::PreparedStatement(std::unique_ptr<sqlite3_stmt, StatementDeleter> prepared_statement) noexcept
130     : column_count_(::get_column_count(prepared_statement.get())),
131       column_names_(::get_column_names(prepared_statement.get(), column_count_)),
132       prepared_statement_(std::move(prepared_statement)) { }
133 
bind(size_t index,int64_t value,std::error_code & error) const134 bool PreparedStatement::bind(size_t index, int64_t value, std::error_code& error) const noexcept {
135     if (error) {
136         return false;
137     }
138     const int status = sqlite3_bind_int64(get_underlying_statement(), static_cast<int>(index), value);
139     return process_sqlite_status(status, error);
140 }
141 
bind_name(const std::string & name,int64_t value,std::error_code & error) const142 bool PreparedStatement::bind_name(const std::string& name, int64_t value, std::error_code& error) const noexcept {
143     return bind(get_parameter_index(get_underlying_statement(), name), value, error);
144 }
145 
bind(size_t index,double value,std::error_code & error) const146 bool PreparedStatement::bind(size_t index, double value, std::error_code& error) const noexcept {
147     if (error) {
148         return false;
149     }
150     const int status = sqlite3_bind_double(get_underlying_statement(), static_cast<int>(index), value);
151     return process_sqlite_status(status, error);
152 }
153 
bind(size_t index,UnOwnedString value,std::error_code & error) const154 bool PreparedStatement::bind(size_t index, UnOwnedString value, std::error_code& error) const noexcept {
155     return bind_unowned_string(get_underlying_statement(), index, value, true, error);
156 }
157 
bind_no_copy(size_t index,UnOwnedString value,std::error_code & error) const158 bool PreparedStatement::bind_no_copy(size_t index, UnOwnedString value, std::error_code& error) const noexcept {
159     return bind_unowned_string(get_underlying_statement(), index, value, false, error);
160 }
161 
bind_name(const std::string & name,UnOwnedString value,std::error_code & error) const162 bool PreparedStatement::bind_name(const std::string& name, UnOwnedString value, std::error_code& error) const noexcept {
163     size_t index = get_parameter_index(get_underlying_statement(), name);
164     return bind_unowned_string(get_underlying_statement(), index, value, true, error);
165 }
166 
bind_name_no_copy(const std::string & name,UnOwnedString value,std::error_code & error) const167 bool PreparedStatement::bind_name_no_copy(const std::string& name,
168                                           UnOwnedString value,
169                                           std::error_code& error) const noexcept {
170     size_t index = get_parameter_index(get_underlying_statement(), name);
171     return bind_unowned_string(get_underlying_statement(), index, value, false, error);
172 }
173 
bind(size_t index,const UnOwnedBlob & value,std::error_code & error) const174 bool PreparedStatement::bind(size_t index, const UnOwnedBlob& value, std::error_code& error) const noexcept {
175     return bind_blob(get_underlying_statement(), index, value, true, error);
176 }
177 
bind_name(const std::string & name,const UnOwnedBlob & value,std::error_code & error) const178 bool PreparedStatement::bind_name(const std::string& name,
179                                   const UnOwnedBlob& value,
180                                   std::error_code& error) const noexcept {
181     size_t index = get_parameter_index(get_underlying_statement(), name);
182     return bind_blob(get_underlying_statement(), index, value, true, error);
183 }
184 
bind_no_copy(size_t index,const UnOwnedBlob & value,std::error_code & error) const185 bool PreparedStatement::bind_no_copy(size_t index, const UnOwnedBlob& value, std::error_code& error) const noexcept {
186     return bind_blob(get_underlying_statement(), index, value, false, error);
187 }
188 
bind_name_no_copy(const std::string & name,const UnOwnedBlob & value,std::error_code & error) const189 bool PreparedStatement::bind_name_no_copy(const std::string& name,
190                                           const UnOwnedBlob& value,
191                                           std::error_code& error) const noexcept {
192     size_t index = get_parameter_index(get_underlying_statement(), name);
193     return bind_blob(get_underlying_statement(), index, value, false, error);
194 }
195 
reset(std::error_code & error) const196 bool PreparedStatement::reset(std::error_code& error) const noexcept {
197     if (error) {
198         return false;
199     }
200     const int status = sqlite3_reset(get_underlying_statement());
201     return process_sqlite_status(status, error);
202 }
203 
step(std::error_code & error) const204 bool PreparedStatement::step(std::error_code& error) const noexcept {
205     if (error) {
206         return false;
207     }
208     const int status = sqlite3_step(get_underlying_statement());
209     if (status == SQLITE_ROW) {
210         return true;
211     } else if (status == SQLITE_DONE) {
212         return false;
213     } else {
214         return process_sqlite_status(status, error);
215     }
216 }
217 
get_column_storage_types()218 const std::vector<StorageType>& PreparedStatement::get_column_storage_types() noexcept {
219     if (column_storage_types_.empty()) {
220         column_storage_types_ = ::get_column_storage_types(get_underlying_statement(), get_column_count());
221     }
222     return column_storage_types_;
223 }
224 
get_column_value(size_t index,std::error_code & error)225 Value PreparedStatement::get_column_value(size_t index, std::error_code& error) noexcept {
226     if (error) {
227         return Null();
228     }
229     switch (get_column_storage_type(index)) {
230         case StorageType::Integer: {
231             return get_int64_value(get_underlying_statement(), index);
232         }
233         case StorageType::Double: {
234             return get_double_value(get_underlying_statement(), index);
235         }
236         case StorageType::Text: {
237             return get_string_value(get_underlying_statement(), index);
238         }
239         case StorageType::Blob: {
240             return get_blob_value(get_underlying_statement(), index);
241         }
242         case StorageType::Null: {
243             return Null();
244         }
245     }
246 }
247 
get_column_value_no_copy(size_t index,std::error_code & error)248 UnOwnedValue PreparedStatement::get_column_value_no_copy(size_t index, std::error_code& error) noexcept {
249     if (error) {
250         return Null();
251     }
252     switch (get_column_storage_type(index)) {
253         case StorageType::Integer: {
254             return get_int64_value(get_underlying_statement(), index);
255         }
256         case StorageType::Double: {
257             return get_double_value(get_underlying_statement(), index);
258         }
259         case StorageType::Text: {
260             return get_unowned_string_value(get_underlying_statement(), index);
261         }
262         case StorageType::Blob: {
263             return get_unowned_blob_value(get_underlying_statement(), index);
264         }
265         case StorageType::Null: {
266             return Null();
267         }
268     }
269 }
270 
execute(std::error_code & error) const271 bool PreparedStatement::execute(std::error_code& error) const noexcept {
272     if (error) {
273         return false;
274     }
275     const int status = sqlite3_step(get_underlying_statement());
276     if (status == SQLITE_DONE) {
277         return true;
278     } else {
279         return process_sqlite_status(status, error);
280     }
281 }
282 
283 } // namespace sqlite
284 } // namespace executorchcoreml
285