1 //
2 // Statement.cpp
3 //
4 // Copyright © 2024 Apple Inc. All rights reserved.
5 //
6 // Please refer to the license found in the LICENSE file in the root directory of the source tree.
7
8 #include "statement.hpp"
9
10 #include <string_view>
11
12 #include <sqlite_error.hpp>
13
14 namespace {
15
16 using namespace executorchcoreml::sqlite;
17
get_column_count(sqlite3_stmt * stmt)18 size_t get_column_count(sqlite3_stmt* stmt) { return static_cast<size_t>(sqlite3_column_count(stmt)); }
19
get_column_names(sqlite3_stmt * stmt,size_t column_count)20 std::vector<std::string> get_column_names(sqlite3_stmt* stmt, size_t column_count) {
21 std::vector<std::string> result;
22 result.reserve(static_cast<size_t>(column_count));
23 for (int i = 0; i < column_count; i++) {
24 const char* name = sqlite3_column_name(stmt, i);
25 if (!name) {
26 return {};
27 }
28 result.emplace_back(name);
29 }
30
31 return result;
32 }
33
get_parameter_index(sqlite3_stmt * stmt,const std::string & name)34 int get_parameter_index(sqlite3_stmt* stmt, const std::string& name) {
35 return sqlite3_bind_parameter_index(stmt, name.c_str());
36 }
37
bind_unowned_string(sqlite3_stmt * stmt,size_t index,UnOwnedString value,bool copy,std::error_code & error)38 bool bind_unowned_string(sqlite3_stmt* stmt, size_t index, UnOwnedString value, bool copy, std::error_code& error) {
39 if (error) {
40 return false;
41 }
42 auto destructor = copy ? SQLITE_TRANSIENT : SQLITE_STATIC;
43 const int status =
44 sqlite3_bind_text(stmt, static_cast<int>(index), value.data, static_cast<int>(value.size), destructor);
45 return process_sqlite_status(status, error);
46 }
47
bind_blob(sqlite3_stmt * stmt,size_t index,const UnOwnedBlob & value,bool copy,std::error_code & error)48 bool bind_blob(sqlite3_stmt* stmt, size_t index, const UnOwnedBlob& value, bool copy, std::error_code& error) {
49 if (error) {
50 return false;
51 }
52 auto destructor = copy ? SQLITE_TRANSIENT : SQLITE_STATIC;
53 const int status =
54 sqlite3_bind_blob(stmt, static_cast<int>(index), value.data, static_cast<int>(value.size), destructor);
55 return process_sqlite_status(status, error);
56 }
57
get_column_storage_type(sqlite3_stmt * stmt,int index)58 StorageType get_column_storage_type(sqlite3_stmt* stmt, int index) {
59 switch (sqlite3_column_type(stmt, index)) {
60 case SQLITE_INTEGER: {
61 return StorageType::Integer;
62 }
63 case SQLITE_FLOAT: {
64 return StorageType::Double;
65 }
66 case SQLITE_TEXT: {
67 return StorageType::Text;
68 }
69 case SQLITE_BLOB: {
70 return StorageType::Blob;
71 }
72 case SQLITE_NULL: {
73 return StorageType::Null;
74 }
75 default: {
76 return StorageType::Null;
77 }
78 }
79 }
80
get_column_storage_types(sqlite3_stmt * stmt,size_t columnCount)81 std::vector<StorageType> get_column_storage_types(sqlite3_stmt* stmt, size_t columnCount) {
82 std::vector<StorageType> result;
83 result.reserve(static_cast<size_t>(columnCount));
84 for (int i = 0; i < columnCount; i++) {
85 result.emplace_back(get_column_storage_type(stmt, i));
86 }
87
88 return result;
89 }
90
get_int64_value(sqlite3_stmt * stmt,size_t index)91 int64_t get_int64_value(sqlite3_stmt* stmt, size_t index) {
92 return sqlite3_column_int64(stmt, static_cast<int>(index));
93 }
94
get_double_value(sqlite3_stmt * stmt,size_t index)95 int64_t get_double_value(sqlite3_stmt* stmt, size_t index) {
96 return sqlite3_column_double(stmt, static_cast<int>(index));
97 }
98
get_string_value(sqlite3_stmt * stmt,size_t index)99 std::string get_string_value(sqlite3_stmt* stmt, size_t index) {
100 auto data = static_cast<const char*>(sqlite3_column_blob(stmt, static_cast<int>(index)));
101 return std::string(data, sqlite3_column_bytes(stmt, static_cast<int>(index)));
102 }
103
get_unowned_string_value(sqlite3_stmt * stmt,size_t index)104 UnOwnedString get_unowned_string_value(sqlite3_stmt* stmt, size_t index) {
105 auto data = static_cast<const char*>(sqlite3_column_blob(stmt, static_cast<int>(index)));
106 return UnOwnedString(data, sqlite3_column_bytes(stmt, static_cast<int>(index)));
107 }
108
get_stored_blob_value(sqlite3_stmt * stmt,size_t index)109 std::pair<const void*, size_t> get_stored_blob_value(sqlite3_stmt* stmt, size_t index) {
110 const void* data = sqlite3_column_blob(stmt, static_cast<int>(index));
111 int n = sqlite3_column_bytes(stmt, static_cast<int>(index));
112 return { data, static_cast<size_t>(n) };
113 }
114
get_blob_value(sqlite3_stmt * stmt,size_t index)115 Blob get_blob_value(sqlite3_stmt* stmt, size_t index) {
116 const auto& pair = get_stored_blob_value(stmt, index);
117 return Blob(pair.first, pair.second);
118 }
119
get_unowned_blob_value(sqlite3_stmt * stmt,size_t index)120 UnOwnedBlob get_unowned_blob_value(sqlite3_stmt* stmt, size_t index) {
121 const auto& pair = get_stored_blob_value(stmt, index);
122 return UnOwnedBlob(pair.first, pair.second);
123 }
124 } // namespace
125
126 namespace executorchcoreml {
127 namespace sqlite {
128
// Takes ownership of `prepared_statement` and caches its result-column count and column names.
//
// NOTE(review): the first two initializers read the `prepared_statement` *parameter* before it is
// moved into `prepared_statement_`, and `column_names_` reads `column_count_`. C++ initializes
// members in declaration order, not initializer-list order, so this is only correct if the header
// declares column_count_, column_names_, prepared_statement_ in that order. Confirm against
// statement.hpp.
// NOTE(review): get_column_names allocates. An allocation failure inside this noexcept constructor
// would call std::terminate.
PreparedStatement::PreparedStatement(std::unique_ptr<sqlite3_stmt, StatementDeleter> prepared_statement) noexcept
    : column_count_(::get_column_count(prepared_statement.get())),
      column_names_(::get_column_names(prepared_statement.get(), column_count_)),
      prepared_statement_(std::move(prepared_statement)) { }
133
bind(size_t index,int64_t value,std::error_code & error) const134 bool PreparedStatement::bind(size_t index, int64_t value, std::error_code& error) const noexcept {
135 if (error) {
136 return false;
137 }
138 const int status = sqlite3_bind_int64(get_underlying_statement(), static_cast<int>(index), value);
139 return process_sqlite_status(status, error);
140 }
141
bind_name(const std::string & name,int64_t value,std::error_code & error) const142 bool PreparedStatement::bind_name(const std::string& name, int64_t value, std::error_code& error) const noexcept {
143 return bind(get_parameter_index(get_underlying_statement(), name), value, error);
144 }
145
bind(size_t index,double value,std::error_code & error) const146 bool PreparedStatement::bind(size_t index, double value, std::error_code& error) const noexcept {
147 if (error) {
148 return false;
149 }
150 const int status = sqlite3_bind_double(get_underlying_statement(), static_cast<int>(index), value);
151 return process_sqlite_status(status, error);
152 }
153
bind(size_t index,UnOwnedString value,std::error_code & error) const154 bool PreparedStatement::bind(size_t index, UnOwnedString value, std::error_code& error) const noexcept {
155 return bind_unowned_string(get_underlying_statement(), index, value, true, error);
156 }
157
bind_no_copy(size_t index,UnOwnedString value,std::error_code & error) const158 bool PreparedStatement::bind_no_copy(size_t index, UnOwnedString value, std::error_code& error) const noexcept {
159 return bind_unowned_string(get_underlying_statement(), index, value, false, error);
160 }
161
bind_name(const std::string & name,UnOwnedString value,std::error_code & error) const162 bool PreparedStatement::bind_name(const std::string& name, UnOwnedString value, std::error_code& error) const noexcept {
163 size_t index = get_parameter_index(get_underlying_statement(), name);
164 return bind_unowned_string(get_underlying_statement(), index, value, true, error);
165 }
166
bind_name_no_copy(const std::string & name,UnOwnedString value,std::error_code & error) const167 bool PreparedStatement::bind_name_no_copy(const std::string& name,
168 UnOwnedString value,
169 std::error_code& error) const noexcept {
170 size_t index = get_parameter_index(get_underlying_statement(), name);
171 return bind_unowned_string(get_underlying_statement(), index, value, false, error);
172 }
173
bind(size_t index,const UnOwnedBlob & value,std::error_code & error) const174 bool PreparedStatement::bind(size_t index, const UnOwnedBlob& value, std::error_code& error) const noexcept {
175 return bind_blob(get_underlying_statement(), index, value, true, error);
176 }
177
bind_name(const std::string & name,const UnOwnedBlob & value,std::error_code & error) const178 bool PreparedStatement::bind_name(const std::string& name,
179 const UnOwnedBlob& value,
180 std::error_code& error) const noexcept {
181 size_t index = get_parameter_index(get_underlying_statement(), name);
182 return bind_blob(get_underlying_statement(), index, value, true, error);
183 }
184
bind_no_copy(size_t index,const UnOwnedBlob & value,std::error_code & error) const185 bool PreparedStatement::bind_no_copy(size_t index, const UnOwnedBlob& value, std::error_code& error) const noexcept {
186 return bind_blob(get_underlying_statement(), index, value, false, error);
187 }
188
bind_name_no_copy(const std::string & name,const UnOwnedBlob & value,std::error_code & error) const189 bool PreparedStatement::bind_name_no_copy(const std::string& name,
190 const UnOwnedBlob& value,
191 std::error_code& error) const noexcept {
192 size_t index = get_parameter_index(get_underlying_statement(), name);
193 return bind_blob(get_underlying_statement(), index, value, false, error);
194 }
195
reset(std::error_code & error) const196 bool PreparedStatement::reset(std::error_code& error) const noexcept {
197 if (error) {
198 return false;
199 }
200 const int status = sqlite3_reset(get_underlying_statement());
201 return process_sqlite_status(status, error);
202 }
203
step(std::error_code & error) const204 bool PreparedStatement::step(std::error_code& error) const noexcept {
205 if (error) {
206 return false;
207 }
208 const int status = sqlite3_step(get_underlying_statement());
209 if (status == SQLITE_ROW) {
210 return true;
211 } else if (status == SQLITE_DONE) {
212 return false;
213 } else {
214 return process_sqlite_status(status, error);
215 }
216 }
217
get_column_storage_types()218 const std::vector<StorageType>& PreparedStatement::get_column_storage_types() noexcept {
219 if (column_storage_types_.empty()) {
220 column_storage_types_ = ::get_column_storage_types(get_underlying_statement(), get_column_count());
221 }
222 return column_storage_types_;
223 }
224
get_column_value(size_t index,std::error_code & error)225 Value PreparedStatement::get_column_value(size_t index, std::error_code& error) noexcept {
226 if (error) {
227 return Null();
228 }
229 switch (get_column_storage_type(index)) {
230 case StorageType::Integer: {
231 return get_int64_value(get_underlying_statement(), index);
232 }
233 case StorageType::Double: {
234 return get_double_value(get_underlying_statement(), index);
235 }
236 case StorageType::Text: {
237 return get_string_value(get_underlying_statement(), index);
238 }
239 case StorageType::Blob: {
240 return get_blob_value(get_underlying_statement(), index);
241 }
242 case StorageType::Null: {
243 return Null();
244 }
245 }
246 }
247
get_column_value_no_copy(size_t index,std::error_code & error)248 UnOwnedValue PreparedStatement::get_column_value_no_copy(size_t index, std::error_code& error) noexcept {
249 if (error) {
250 return Null();
251 }
252 switch (get_column_storage_type(index)) {
253 case StorageType::Integer: {
254 return get_int64_value(get_underlying_statement(), index);
255 }
256 case StorageType::Double: {
257 return get_double_value(get_underlying_statement(), index);
258 }
259 case StorageType::Text: {
260 return get_unowned_string_value(get_underlying_statement(), index);
261 }
262 case StorageType::Blob: {
263 return get_unowned_blob_value(get_underlying_statement(), index);
264 }
265 case StorageType::Null: {
266 return Null();
267 }
268 }
269 }
270
execute(std::error_code & error) const271 bool PreparedStatement::execute(std::error_code& error) const noexcept {
272 if (error) {
273 return false;
274 }
275 const int status = sqlite3_step(get_underlying_statement());
276 if (status == SQLITE_DONE) {
277 return true;
278 } else {
279 return process_sqlite_status(status, error);
280 }
281 }
282
283 } // namespace sqlite
284 } // namespace executorchcoreml
285