1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // This source code is licensed under the BSD-style license found in the
5 // LICENSE file in the root directory of this source tree.
6
7 #pragma once
8 #define PY_SSIZE_T_CLEAN
9 #include <Python.h>
10 #include <utility>
11 #include <ostream>
12 #include <memory>
13
// PY_BEGIN / PY_END bracket the body of a C++ function exposed to CPython.
// Code inside may throw mpy::exception_set once a Python error indicator has
// been set (e.g. by a failing C-API call); PY_END turns that C++ exception back
// into CPython's error-return convention by returning `v` (typically nullptr
// or -1).
#define PY_BEGIN try {
#define PY_END(v) } catch(const mpy::exception_set&) { return (v); }

// PY_VECTORCALL names the vectorcall entry point for the CPython being built
// against. All variants take (callable, args, nargsf, kwnames).
//  - < 3.8:  vectorcall did not exist; _PyObject_FastCallKeywords has the same shape.
//  - 3.8:    only the provisional, underscore-prefixed _PyObject_Vectorcall exists.
//  - >= 3.9: PyObject_Vectorcall is public API; the private alias is deprecated
//            and may not be available in newer releases.
#if PY_VERSION_HEX < 0x03080000
#define PY_VECTORCALL _PyObject_FastCallKeywords
#elif PY_VERSION_HEX < 0x03090000
#define PY_VECTORCALL _PyObject_Vectorcall
#else
#define PY_VECTORCALL PyObject_Vectorcall
#endif
22
// Lightweight integer range for range-based for loops:
//   for (auto i : irange(n))          -> 0, 1, ..., n-1
//   for (auto i : irange(b, e, step)) -> b, b+step, ... while not past e
// The object doubles as its own iterator: begin() returns a copy of the range
// and end() returns a range positioned at end_.
struct irange {
public:
    // Half-open range [0, end) with step 1.
    irange(int64_t end)
    : irange(0, end, 1) {}
    // Half-open range from `begin` toward `end` (exclusive) in increments of `step`.
    irange(int64_t begin, int64_t end, int64_t step = 1)
    : begin_(begin), end_(end), step_(step) {}
    int64_t operator*() const {
        return begin_;
    }
    irange& operator++() {
        begin_ += step_;
        return *this;
    }
    // "Not finished yet": true while the cursor has not reached or passed the end
    // position in the direction of travel. An ordering test (instead of plain
    // inequality) makes loops terminate when (end - begin) is not a multiple of
    // step, and yields an empty loop when end lies behind begin.
    bool operator!=(const irange& other) const {
        return step_ >= 0 ? begin_ < other.begin_ : begin_ > other.begin_;
    }
    irange begin() const {
        return *this;
    }
    irange end() const {
        return irange {end_, end_, step_};
    }
private:
    int64_t begin_;
    int64_t end_;
    int64_t step_;
};
50
51 namespace mpy {
52
// Payload-free exception thrown to signal that a CPython error indicator is
// (expected to be) set; PY_END catches it and returns the error sentinel.
struct exception_set {};

// Forward declarations: handle's member declarations refer to these types
// before they are defined.
struct object;
struct vector_args;
58
59 struct handle {
handlehandle60 handle(PyObject* ptr)
61 : ptr_(ptr) {}
62 handle() = default;
63
64
ptrhandle65 PyObject* ptr() const {
66 return ptr_;
67 }
68 object attr(const char* key);
69 bool hasattr(const char* key);
typehandle70 handle type() const {
71 return (PyObject*) Py_TYPE(ptr());
72 }
73
74 template<typename... Args>
75 object call(Args&&... args);
76 object call_object(mpy::handle args);
77 object call_object(mpy::handle args, mpy::handle kwargs);
78 object call_vector(mpy::handle* begin, Py_ssize_t nargs, mpy::handle kwnames);
79 object call_vector(vector_args args);
80 bool operator==(handle rhs) {
81 return ptr_ == rhs.ptr_;
82 }
83
checkedhandle84 static handle checked(PyObject* ptr) {
85 if (!ptr) {
86 throw exception_set();
87 }
88 return ptr;
89 }
90
91 protected:
92 PyObject* ptr_ = nullptr;
93 };
94
95
96 template<typename T>
97 struct obj;
98
99 template<typename T>
100 struct hdl : public handle {
ptrhdl101 T* ptr() {
102 return (T*) handle::ptr();
103 }
104 T* operator->() {
105 return ptr();
106 }
hdlhdl107 hdl(T* ptr)
108 : hdl((PyObject*) ptr) {}
hdlhdl109 hdl(const obj<T>& o)
110 : hdl(o.ptr()) {}
111 private:
hdlhdl112 hdl(handle h) : handle(h) {}
113 };
114
115 struct object : public handle {
116 object() = default;
objectobject117 object(const object& other)
118 : handle(other.ptr_) {
119 Py_XINCREF(ptr_);
120 }
objectobject121 object(object&& other) noexcept
122 : handle(other.ptr_) {
123 other.ptr_ = nullptr;
124 }
125 object& operator=(const object& other) {
126 return *this = object(other);
127 }
128 object& operator=(object&& other) noexcept {
129 PyObject* tmp = ptr_;
130 ptr_ = other.ptr_;
131 other.ptr_ = tmp;
132 return *this;
133 }
~objectobject134 ~object() {
135 Py_XDECREF(ptr_);
136 }
stealobject137 static object steal(handle o) {
138 return object(o.ptr());
139 }
checked_stealobject140 static object checked_steal(handle o) {
141 if (!o.ptr()) {
142 throw exception_set();
143 }
144 return steal(o);
145 }
borrowobject146 static object borrow(handle o) {
147 Py_XINCREF(o.ptr());
148 return steal(o);
149 }
releaseobject150 PyObject* release() {
151 auto tmp = ptr_;
152 ptr_ = nullptr;
153 return tmp;
154 }
155 protected:
objectobject156 explicit object(PyObject* ptr)
157 : handle(ptr) {}
158 };
159
160 template<typename T>
161 struct obj : public object {
162 obj() = default;
objobj163 obj(const obj& other)
164 : object(other.ptr_) {
165 Py_XINCREF(ptr_);
166 }
objobj167 obj(obj&& other) noexcept
168 : object(other.ptr_) {
169 other.ptr_ = nullptr;
170 }
171 obj& operator=(const obj& other) {
172 return *this = obj(other);
173 }
174 obj& operator=(obj&& other) noexcept {
175 PyObject* tmp = ptr_;
176 ptr_ = other.ptr_;
177 other.ptr_ = tmp;
178 return *this;
179 }
stealobj180 static obj steal(hdl<T> o) {
181 return obj(o.ptr());
182 }
checked_stealobj183 static obj checked_steal(hdl<T> o) {
184 if (!o.ptr()) {
185 throw exception_set();
186 }
187 return steal(o);
188 }
borrowobj189 static obj borrow(hdl<T> o) {
190 Py_XINCREF(o.ptr());
191 return steal(o);
192 }
ptrobj193 T* ptr() const {
194 return (T*) object::ptr();
195 }
196 T* operator->() {
197 return ptr();
198 }
199 protected:
objobj200 explicit obj(T* ptr)
201 : object((PyObject*)ptr) {}
202 };
203
204
isinstance(handle h,handle c)205 static bool isinstance(handle h, handle c) {
206 return PyObject_IsInstance(h.ptr(), c.ptr());
207 }
208
raise_error(handle exception,const char * format,...)209 [[ noreturn ]] inline void raise_error(handle exception, const char *format, ...) {
210 va_list args;
211 va_start(args, format);
212 PyErr_FormatV(exception.ptr(), format, args);
213 va_end(args);
214 throw exception_set();
215 }
216
217 template<typename T>
218 struct base {
219 PyObject_HEAD
ptrbase220 PyObject* ptr() const {
221 return (PyObject*) this;
222 }
223 static obj<T> alloc(PyTypeObject* type = nullptr) {
224 if (!type) {
225 type = &T::Type;
226 }
227 auto self = (T*) type->tp_alloc(type, 0);
228 if (!self) {
229 throw mpy::exception_set();
230 }
231 new (self) T;
232 return obj<T>::steal(self);
233 }
234 template<typename ... Args>
createbase235 static obj<T> create(Args ... args) {
236 auto self = alloc();
237 self->init(std::forward<Args>(args)...);
238 return self;
239 }
checkbase240 static bool check(handle v) {
241 return isinstance(v, (PyObject*)&T::Type);
242 }
243
unchecked_wrapbase244 static hdl<T> unchecked_wrap(handle self_) {
245 return hdl<T>((T*)self_.ptr());
246 }
wrapbase247 static hdl<T> wrap(handle self_) {
248 if (!check(self_)) {
249 raise_error(PyExc_ValueError, "not an instance of %S", &T::Type);
250 }
251 return unchecked_wrap(self_);
252 }
253
unchecked_wrapbase254 static obj<T> unchecked_wrap(object self_) {
255 return obj<T>::steal(unchecked_wrap(self_.release()));
256 }
wrapbase257 static obj<T> wrap(object self_) {
258 return obj<T>::steal(wrap(self_.release()));
259 }
260
new_stubbase261 static PyObject* new_stub(PyTypeObject *type, PyObject *args, PyObject *kwds) {
262 PY_BEGIN
263 return (PyObject*) alloc(type).release();
264 PY_END(nullptr)
265 }
dealloc_stubbase266 static void dealloc_stub(PyObject *self) {
267 ((T*)self)->~T();
268 Py_TYPE(self)->tp_free(self);
269 }
readybase270 static void ready(mpy::handle mod, const char* name) {
271 if (PyType_Ready(&T::Type)) {
272 throw exception_set();
273 }
274 if(PyModule_AddObject(mod.ptr(), name, (PyObject*) &T::Type) < 0) {
275 throw exception_set();
276 }
277 }
278 };
279
attr(const char * key)280 inline object handle::attr(const char* key) {
281 return object::checked_steal(PyObject_GetAttrString(ptr(), key));
282 }
283
hasattr(const char * key)284 inline bool handle::hasattr(const char* key) {
285 return PyObject_HasAttrString(ptr(), key);
286 }
287
import(const char * module)288 inline object import(const char* module) {
289 return object::checked_steal(PyImport_ImportModule(module));
290 }
291
292 template<typename... Args>
call(Args &&...args)293 inline object handle::call(Args&&... args) {
294 return object::checked_steal(PyObject_CallFunctionObjArgs(ptr_, args.ptr()..., nullptr));
295 }
296
call_object(mpy::handle args)297 inline object handle::call_object(mpy::handle args) {
298 return object::checked_steal(PyObject_CallObject(ptr(), args.ptr()));
299 }
300
301
call_object(mpy::handle args,mpy::handle kwargs)302 inline object handle::call_object(mpy::handle args, mpy::handle kwargs) {
303 return object::checked_steal(PyObject_Call(ptr(), args.ptr(), kwargs.ptr()));
304 }
305
call_vector(mpy::handle * begin,Py_ssize_t nargs,mpy::handle kwnames)306 inline object handle::call_vector(mpy::handle* begin, Py_ssize_t nargs, mpy::handle kwnames) {
307 return object::checked_steal(PY_VECTORCALL(ptr(), (PyObject*const*) begin, nargs, kwnames.ptr()));
308 }
309
310 struct tuple : public object {
settuple311 void set(int i, object v) {
312 PyTuple_SET_ITEM(ptr_, i, v.release());
313 }
tupletuple314 tuple(int size)
315 : object(checked_steal(PyTuple_New(size))) {}
316 };
317
318 struct list : public object {
setlist319 void set(int i, object v) {
320 PyList_SET_ITEM(ptr_, i, v.release());
321 }
listlist322 list(int size)
323 : object(checked_steal(PyList_New(size))) {}
324 };
325
326 namespace{
unicode_from_format(const char * format,...)327 mpy::object unicode_from_format(const char* format, ...) {
328 va_list args;
329 va_start(args, format);
330 auto r = PyUnicode_FromFormatV(format, args);
331 va_end(args);
332 return mpy::object::checked_steal(r);
333 }
unicode_from_string(const char * str)334 mpy::object unicode_from_string(const char * str) {
335 return mpy::object::checked_steal(PyUnicode_FromString(str));
336 }
337
from_int(Py_ssize_t s)338 mpy::object from_int(Py_ssize_t s) {
339 return mpy::object::checked_steal(PyLong_FromSsize_t(s));
340 }
from_bool(bool b)341 mpy::object from_bool(bool b) {
342 return mpy::object::borrow(b ? Py_True : Py_False);
343 }
344
is_sequence(handle h)345 bool is_sequence(handle h) {
346 return PySequence_Check(h.ptr());
347 }
348 }
349
350 struct sequence_view : public handle {
sequence_viewsequence_view351 sequence_view(handle h)
352 : handle(h) {}
sizesequence_view353 Py_ssize_t size() const {
354 auto r = PySequence_Size(ptr());
355 if (r == -1 && PyErr_Occurred()) {
356 throw mpy::exception_set();
357 }
358 return r;
359 }
enumeratesequence_view360 irange enumerate() const {
361 return irange(size());
362 }
wrapsequence_view363 static sequence_view wrap(handle h) {
364 if (!is_sequence(h)) {
365 raise_error(PyExc_ValueError, "expected a sequence");
366 }
367 return sequence_view(h);
368 }
369 mpy::object operator[](Py_ssize_t i) const {
370 return mpy::object::checked_steal(PySequence_GetItem(ptr(), i));
371 }
372 };
373
374 namespace {
repr(handle h)375 mpy::object repr(handle h) {
376 return mpy::object::checked_steal(PyObject_Repr(h.ptr()));
377 }
378
str(handle h)379 mpy::object str(handle h) {
380 return mpy::object::checked_steal(PyObject_Str(h.ptr()));
381 }
382
383
is_int(handle h)384 bool is_int(handle h) {
385 return PyLong_Check(h.ptr());
386 }
387
is_none(handle h)388 bool is_none(handle h) {
389 return h.ptr() == Py_None;
390 }
391
to_int(handle h)392 Py_ssize_t to_int(handle h) {
393 Py_ssize_t r = PyLong_AsSsize_t(h.ptr());
394 if (r == -1 && PyErr_Occurred()) {
395 throw mpy::exception_set();
396 }
397 return r;
398 }
399
to_bool(handle h)400 bool to_bool(handle h) {
401 return PyObject_IsTrue(h.ptr()) != 0;
402 }
403 }
404
405 struct slice_view {
slice_viewslice_view406 slice_view(handle h, Py_ssize_t size) {
407 if(PySlice_Unpack(h.ptr(), &start, &stop, &step) == -1) {
408 throw mpy::exception_set();
409 }
410 slicelength = PySlice_AdjustIndices(size, &start, &stop, step);
411 }
412 Py_ssize_t start, stop, step, slicelength;
413 };
414
is_slice(handle h)415 static bool is_slice(handle h) {
416 return PySlice_Check(h.ptr());
417 }
418
419 inline std::ostream& operator<<(std::ostream& ss, handle h) {
420 ss << PyUnicode_AsUTF8(str(h).ptr());
421 return ss;
422 }
423
424 struct tuple_view : public handle {
425 tuple_view() = default;
tuple_viewtuple_view426 tuple_view(handle h) : handle(h) {}
427
sizetuple_view428 Py_ssize_t size() const {
429 return PyTuple_GET_SIZE(ptr());
430 }
431
enumeratetuple_view432 irange enumerate() const {
433 return irange(size());
434 }
435
436 handle operator[](Py_ssize_t i) {
437 return PyTuple_GET_ITEM(ptr(), i);
438 }
439
checktuple_view440 static bool check(handle h) {
441 return PyTuple_Check(h.ptr());
442 }
443 };
444
445 struct list_view : public handle {
446 list_view() = default;
list_viewlist_view447 list_view(handle h) : handle(h) {}
sizelist_view448 Py_ssize_t size() const {
449 return PyList_GET_SIZE(ptr());
450 }
451
enumeratelist_view452 irange enumerate() const {
453 return irange(size());
454 }
455
456 handle operator[](Py_ssize_t i) {
457 return PyList_GET_ITEM(ptr(), i);
458 }
459
checklist_view460 static bool check(handle h) {
461 return PyList_Check(h.ptr());
462 }
463 };
464
465 struct dict_view : public handle {
466 dict_view() = default;
dict_viewdict_view467 dict_view(handle h) : handle(h) {}
keysdict_view468 object keys() const {
469 return mpy::object::checked_steal(PyDict_Keys(ptr()));
470 }
valuesdict_view471 object values() const {
472 return mpy::object::checked_steal(PyDict_Values(ptr()));
473 }
itemsdict_view474 object items() const {
475 return mpy::object::checked_steal(PyDict_Items(ptr()));
476 }
containsdict_view477 bool contains(handle k) const {
478 return PyDict_Contains(ptr(), k.ptr());
479 }
480 handle operator[](handle k) {
481 return mpy::handle::checked(PyDict_GetItem(ptr(), k.ptr()));
482 }
checkdict_view483 static bool check(handle h) {
484 return PyDict_Check(h.ptr());
485 }
nextdict_view486 bool next(Py_ssize_t* pos, mpy::handle* key, mpy::handle* value) {
487 PyObject *k = nullptr, *v = nullptr;
488 auto r = PyDict_Next(ptr(), pos, &k, &v);
489 *key = k;
490 *value = v;
491 return r;
492 }
setdict_view493 void set(handle k, handle v) {
494 if (-1 == PyDict_SetItem(ptr(), k.ptr(), v.ptr())) {
495 throw exception_set();
496 }
497 }
498 };
499
500
501 struct kwnames_view : public handle {
502 kwnames_view() = default;
kwnames_viewkwnames_view503 kwnames_view(handle h) : handle(h) {}
504
sizekwnames_view505 Py_ssize_t size() const {
506 return PyTuple_GET_SIZE(ptr());
507 }
508
enumeratekwnames_view509 irange enumerate() const {
510 return irange(size());
511 }
512
513 const char* operator[](Py_ssize_t i) const {
514 PyObject* obj = PyTuple_GET_ITEM(ptr(), i);
515 return PyUnicode_AsUTF8(obj);
516 }
517
checkkwnames_view518 static bool check(handle h) {
519 return PyTuple_Check(h.ptr());
520 }
521 };
522
funcname(mpy::handle func)523 inline mpy::object funcname(mpy::handle func) {
524 if (func.hasattr("__name__")) {
525 return func.attr("__name__");
526 } else {
527 return mpy::str(func);
528 }
529 }
530
// Non-owning view over vectorcall-style arguments: `args` points at `nargs`
// positional values followed by one value per entry of `kwnames` (a tuple of
// keyword names, or null when no keywords were passed).
struct vector_args {
    vector_args(PyObject *const *a,
                Py_ssize_t n,
                PyObject *k)
    : vector_args((mpy::handle*)a, n, k) {}
    vector_args(mpy::handle* a,
                Py_ssize_t n,
                mpy::handle k)
    : args((mpy::handle*)a), nargs(n), kwnames(k) {}
    mpy::handle* args;
    Py_ssize_t nargs;
    kwnames_view kwnames;

    // Iterate over every value: positional first, then keyword values.
    mpy::handle* begin() {
        return args;
    }
    mpy::handle* end() {
        return args + size();
    }

    mpy::handle operator[](int64_t i) const {
        return args[i];
    }
    bool has_keywords() const {
        return kwnames.ptr();
    }
    // Indices 0 .. nargs-1.
    irange enumerate_positional() {
        return irange(nargs);
    }
    // Indices of every value (positional + keyword).
    irange enumerate_all() {
        return irange(size());
    }
    // Total number of values in `args`: positional plus keyword.
    int64_t size() const {
        return nargs + (has_keywords() ? kwnames.size() : 0);
    }

    // TODO: bind a test function so this can be tested (first two args for
    // required/kwonly, then return what was parsed). Cases to cover:
    // - provide an unknown (wrong) kwarg
    // - don't provide a required arg
    // - don't provide an optional arg
    // - provide a kwarg that is the name of an already-provided positional
    // - provide a kwonly argument positionally
    // - provide keyword arguments in the wrong order
    // - provide only keyword arguments
    //
    // Match this call's arguments against the parameter list `names` and write
    // each supplied value through the corresponding pointer in `values`;
    // parameters that were not supplied are left untouched.
    //   fname_cstr: function name used in error messages.
    //   names/values: parameter names and output slots in declaration order
    //     (presumably always the same length — confirm at call sites).
    //   required: the first `required` parameters must be supplied.
    //   kwonly: the last `kwonly` parameters can only be passed by keyword.
    // On any mismatch, `error` hands the same arguments to CPython's internal
    // argument parser (expected to set a descriptive error) and then throws
    // exception_set.
    void parse(const char * fname_cstr, std::initializer_list<const char*> names, std::initializer_list<mpy::handle*> values, int required, int kwonly=0) {
        auto error = [&]() {
            // rather than try to match the slower infrastructure with error messages exactly, once we have detected an error, just use that
            // infrastructure to format it and throw it

            // have to leak this, because python expects these to last
            // (the _PyArg_Parser below is likewise intentionally never freed)
            const char** names_buf = new const char*[names.size() + 1];
            std::copy(names.begin(), names.end(), &names_buf[0]);
            names_buf[names.size()] = nullptr;

#if PY_VERSION_HEX < 0x03080000
            // Build a PyArg format string: one 'O' per name, '|' before the
            // first optional parameter and '$' before the first keyword-only one.
            // Buffer size: names.size() 'O's + '|' + '$' + '\0'.
            char* format_str = new char[names.size() + 3];
            int i = 0;
            char* format_it = format_str;
            for (auto it = names.begin(); it != names.end(); ++it, ++i) {
                if (i == required) {
                    *format_it++ = '|';
                }
                if (i == (int)names.size() - kwonly) {
                    *format_it++ = '$';
                }
                *format_it++ = 'O';
            }
            *format_it++ = '\0';
            _PyArg_Parser* _parser = new _PyArg_Parser{format_str, &names_buf[0], fname_cstr, 0};
            PyObject *dummy = NULL;
            // NOTE(review): only five output slots are passed; with more than five
            // names the varargs list is shorter than the format string — confirm
            // the parser always fails before writing past them.
            _PyArg_ParseStackAndKeywords((PyObject*const*)args, nargs, kwnames.ptr(), _parser, &dummy, &dummy, &dummy, &dummy, &dummy);
#else
            _PyArg_Parser* _parser = new _PyArg_Parser{NULL, &names_buf[0], fname_cstr, 0};
            std::unique_ptr<PyObject*[]> buf(new PyObject*[names.size()]);
            _PyArg_UnpackKeywords((PyObject*const*)args, nargs, NULL, kwnames.ptr(), _parser, required, (Py_ssize_t)values.size() - kwonly, 0, &buf[0]);
#endif
            throw exception_set();
        };

        auto values_it = values.begin();
        auto names_it = names.begin();
        auto npositional = values.size() - kwonly;

        if (nargs > (Py_ssize_t)npositional) {
            // TOO MANY ARGUMENTS
            error();
        }
        // Positional arguments fill the leading parameters in order.
        for (auto i : irange(nargs)) {
            *(*values_it++) = args[i];
            ++names_it;
        }

        if (!kwnames.ptr()) {
            if (nargs < required) {
                // not enough positional arguments
                error();
            }
        } else {
            // Look up each remaining parameter by name among the keywords.
            int consumed = 0;
            for (auto i : irange(nargs, values.size())) {
                // Optional parameters (index >= required) may be absent.
                bool success = i >= required;
                const char* target_name = *(names_it++);
                for (auto j : kwnames.enumerate()) {
                    if (!strcmp(target_name,kwnames[j])) {
                        *(*values_it) = args[nargs + j];
                        ++consumed;
                        success = true;
                        break;
                    }
                }
                ++values_it;
                if (!success) {
                    // REQUIRED ARGUMENT NOT SPECIFIED
                    error();
                }
            }
            // Any keyword left over is unknown or duplicates a positional argument.
            if (consumed != kwnames.size()) {
                // NOT ALL KWNAMES ARGUMENTS WERE USED
                error();
            }
        }
    }
    // Index into `args` of the parameter `name` declared at position `pos`:
    // `pos` itself if it was passed positionally, nargs + j if it was passed as
    // keyword j, or -1 if it was not supplied.
    int index(const char* name, int pos) {
        if (pos < nargs) {
            return pos;
        }
        if (kwnames.ptr()) {
            for (auto j : kwnames.enumerate()) {
                if (!strcmp(name, kwnames[j])) {
                    return nargs + j;
                }
            }
        }
        return -1;
    }
};
668
call_vector(vector_args args)669 inline object handle::call_vector(vector_args args) {
670 return object::checked_steal(PY_VECTORCALL(ptr(), (PyObject*const*) args.args, args.nargs, args.kwnames.ptr()));
671 }
672
673
674 }
675
// Argument-declaration helpers. FORALL_ARGS is a caller-supplied X-macro that
// applies its parameter to each (type, name) pair, e.g.
//   #define FORALL_ARGS(_) _(mpy::handle, self) _(int, dim)
// Each MPY_ARGS_* macro expands a single entry.
#define MPY_ARGS_NAME(typ, name) #name ,
#define MPY_ARGS_DECLARE(typ, name) typ name;
#define MPY_ARGS_POINTER(typ, name) &name ,
// Declare one local per argument and fill them with PyArg_ParseTupleAndKeywords
// from `args` / `kwargs` (which must be in scope); throws exception_set on failure.
// kwlist holds string literals, so it is declared const (string literal to char*
// is ill-formed in C++11+); the const_cast only adapts to the C API's parameter type.
#define MPY_PARSE_ARGS_KWARGS(fmt, FORALL_ARGS) \
    static const char* kwlist[] = { FORALL_ARGS(MPY_ARGS_NAME) nullptr}; \
    FORALL_ARGS(MPY_ARGS_DECLARE) \
    if (!PyArg_ParseTupleAndKeywords(args, kwargs, fmt, const_cast<char**>(kwlist), FORALL_ARGS(MPY_ARGS_POINTER) nullptr)) { \
        throw mpy::exception_set(); \
    }

// Vectorcall counterpart: expects `args`, `nargs` and `kwnames` in scope and uses
// CPython's private _PyArg_Parser / _PyArg_ParseStackAndKeywords.
// NOTE(review): the brace initializer assumes the struct begins with the format
// and keywords fields; its layout is private and varies across CPython
// versions — confirm for every supported version.
#define MPY_PARSE_ARGS_KWNAMES(fmt, FORALL_ARGS) \
    static const char * const kwlist[] = { FORALL_ARGS(MPY_ARGS_NAME) nullptr}; \
    FORALL_ARGS(MPY_ARGS_DECLARE) \
    static _PyArg_Parser parser = {fmt, kwlist, 0}; \
    if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, FORALL_ARGS(MPY_ARGS_POINTER) nullptr)) { \
        throw mpy::exception_set(); \
    }
693