// Source: /aosp_15_r20/external/pytorch/functorch/csrc/dim/minpybind.h
// (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

7 #pragma once
8 #define PY_SSIZE_T_CLEAN
9 #include <Python.h>
10 #include <utility>
11 #include <ostream>
12 #include <memory>
13 
// PY_BEGIN / PY_END bracket the body of a function called from CPython.
// Any mpy::exception_set thrown inside (in this header it is thrown after a
// failing C-API call or raise_error, i.e. with a Python error pending) is
// caught and converted into the failure return value `v` (e.g. nullptr).
#define PY_BEGIN try {
#define PY_END(v) } catch(const mpy::exception_set &) { return (v); }

// PY_VECTORCALL names the vectorcall entry point for the CPython being built
// against: the pre-3.8 fast-call API, the provisional private spelling used by
// 3.8, and the public PyObject_Vectorcall that exists from 3.9 onward (prefer
// it there rather than relying on the private alias remaining available).
#if PY_VERSION_HEX < 0x03080000
    #define PY_VECTORCALL _PyObject_FastCallKeywords
#elif PY_VERSION_HEX < 0x03090000
    #define PY_VECTORCALL _PyObject_Vectorcall
#else
    #define PY_VECTORCALL PyObject_Vectorcall
#endif
22 
// Half-open integer range [begin, end) advanced by `step`, usable directly in
// range-for loops: `for (auto i : irange(n))` yields 0, 1, ..., n-1.
// The range object doubles as its own iterator; operator* is the current value.
struct irange {
 public:
    // Range [0, end) with unit step.
    irange(int64_t end)
    : irange(0, end, 1) {}
    // Range [begin, end) advanced by `step`; a negative step counts down.
    irange(int64_t begin, int64_t end, int64_t step = 1)
    : begin_(begin), end_(end), step_(step) {}
    int64_t operator*() const {
        return begin_;
    }
    irange& operator++() {
        begin_ += step_;
        return *this;
    }
    // Loop-termination test. Compare by direction instead of equality so that
    // empty ranges (begin >= end) and steps that do not evenly divide the span
    // stop instead of stepping past `end` forever.
    bool operator!=(const irange& other) const {
        return step_ < 0 ? begin_ > other.begin_ : begin_ < other.begin_;
    }
    irange begin() const {
        return *this;
    }
    irange end() const {
        return irange {end_, end_, step_};
    }
 private:
    int64_t begin_;
    int64_t end_;
    int64_t step_;
};
50 
51 namespace mpy {
52 
// Payload-free exception used to unwind C++ frames once a Python error is
// pending; the error details live in the interpreter's error state and PY_END
// converts the exception back into a C-API failure return.
struct exception_set {};

// Forward declarations needed by handle's member function signatures.
struct object;
struct vector_args;
58 
59 struct handle {
handlehandle60     handle(PyObject* ptr)
61     : ptr_(ptr) {}
62     handle() = default;
63 
64 
ptrhandle65     PyObject* ptr() const {
66         return ptr_;
67     }
68     object attr(const char* key);
69     bool hasattr(const char* key);
typehandle70     handle type() const {
71         return (PyObject*) Py_TYPE(ptr());
72     }
73 
74     template<typename... Args>
75     object call(Args&&... args);
76     object call_object(mpy::handle args);
77     object call_object(mpy::handle args, mpy::handle kwargs);
78     object call_vector(mpy::handle* begin, Py_ssize_t nargs, mpy::handle kwnames);
79     object call_vector(vector_args args);
80     bool operator==(handle rhs) {
81         return ptr_ == rhs.ptr_;
82     }
83 
checkedhandle84     static handle checked(PyObject* ptr) {
85         if (!ptr) {
86             throw exception_set();
87         }
88         return ptr;
89     }
90 
91 protected:
92     PyObject* ptr_ = nullptr;
93 };
94 
95 
96 template<typename T>
97 struct obj;
98 
99 template<typename T>
100 struct hdl : public handle {
ptrhdl101     T* ptr() {
102         return  (T*) handle::ptr();
103     }
104     T* operator->() {
105         return ptr();
106     }
hdlhdl107     hdl(T* ptr)
108     : hdl((PyObject*) ptr) {}
hdlhdl109     hdl(const obj<T>& o)
110     : hdl(o.ptr()) {}
111 private:
hdlhdl112     hdl(handle h) : handle(h) {}
113 };
114 
115 struct object : public handle {
116     object() = default;
objectobject117     object(const object& other)
118     : handle(other.ptr_) {
119         Py_XINCREF(ptr_);
120     }
objectobject121     object(object&& other) noexcept
122     : handle(other.ptr_) {
123         other.ptr_ = nullptr;
124     }
125     object& operator=(const object& other) {
126         return *this = object(other);
127     }
128     object& operator=(object&& other) noexcept {
129         PyObject* tmp = ptr_;
130         ptr_ = other.ptr_;
131         other.ptr_ = tmp;
132         return *this;
133     }
~objectobject134     ~object() {
135         Py_XDECREF(ptr_);
136     }
stealobject137     static object steal(handle o) {
138         return object(o.ptr());
139     }
checked_stealobject140     static object checked_steal(handle o) {
141         if (!o.ptr()) {
142             throw exception_set();
143         }
144         return steal(o);
145     }
borrowobject146     static object borrow(handle o) {
147         Py_XINCREF(o.ptr());
148         return steal(o);
149     }
releaseobject150     PyObject* release() {
151         auto tmp = ptr_;
152         ptr_ = nullptr;
153         return tmp;
154     }
155 protected:
objectobject156     explicit object(PyObject* ptr)
157     : handle(ptr) {}
158 };
159 
160 template<typename T>
161 struct obj : public object {
162     obj() = default;
objobj163     obj(const obj& other)
164     : object(other.ptr_) {
165         Py_XINCREF(ptr_);
166     }
objobj167     obj(obj&& other) noexcept
168     : object(other.ptr_) {
169         other.ptr_ = nullptr;
170     }
171     obj& operator=(const obj& other) {
172         return *this = obj(other);
173     }
174     obj& operator=(obj&& other) noexcept {
175         PyObject* tmp = ptr_;
176         ptr_ = other.ptr_;
177         other.ptr_ = tmp;
178         return *this;
179     }
stealobj180     static obj steal(hdl<T> o) {
181         return obj(o.ptr());
182     }
checked_stealobj183     static obj checked_steal(hdl<T> o) {
184         if (!o.ptr()) {
185             throw exception_set();
186         }
187         return steal(o);
188     }
borrowobj189     static obj borrow(hdl<T> o) {
190         Py_XINCREF(o.ptr());
191         return steal(o);
192     }
ptrobj193     T* ptr() const {
194         return (T*) object::ptr();
195     }
196     T* operator->() {
197         return ptr();
198     }
199 protected:
objobj200     explicit obj(T* ptr)
201     : object((PyObject*)ptr) {}
202 };
203 
204 
isinstance(handle h,handle c)205 static bool isinstance(handle h, handle c) {
206     return PyObject_IsInstance(h.ptr(), c.ptr());
207 }
208 
raise_error(handle exception,const char * format,...)209 [[ noreturn ]] inline void raise_error(handle exception, const char *format, ...) {
210     va_list args;
211     va_start(args, format);
212     PyErr_FormatV(exception.ptr(), format, args);
213     va_end(args);
214     throw exception_set();
215 }
216 
217 template<typename T>
218 struct base {
219     PyObject_HEAD
ptrbase220     PyObject* ptr() const {
221         return (PyObject*) this;
222     }
223     static obj<T> alloc(PyTypeObject* type = nullptr) {
224         if (!type) {
225             type = &T::Type;
226         }
227         auto self = (T*) type->tp_alloc(type, 0);
228         if (!self) {
229             throw mpy::exception_set();
230         }
231         new (self) T;
232         return obj<T>::steal(self);
233     }
234     template<typename ... Args>
createbase235     static obj<T> create(Args ... args) {
236         auto self = alloc();
237         self->init(std::forward<Args>(args)...);
238         return self;
239     }
checkbase240     static bool check(handle v) {
241         return isinstance(v, (PyObject*)&T::Type);
242     }
243 
unchecked_wrapbase244     static hdl<T> unchecked_wrap(handle self_) {
245         return hdl<T>((T*)self_.ptr());
246     }
wrapbase247     static hdl<T> wrap(handle self_) {
248         if (!check(self_)) {
249             raise_error(PyExc_ValueError, "not an instance of %S", &T::Type);
250         }
251         return unchecked_wrap(self_);
252     }
253 
unchecked_wrapbase254     static obj<T> unchecked_wrap(object self_) {
255         return obj<T>::steal(unchecked_wrap(self_.release()));
256     }
wrapbase257     static obj<T> wrap(object self_) {
258         return obj<T>::steal(wrap(self_.release()));
259     }
260 
new_stubbase261     static PyObject* new_stub(PyTypeObject *type, PyObject *args, PyObject *kwds) {
262         PY_BEGIN
263         return (PyObject*) alloc(type).release();
264         PY_END(nullptr)
265     }
dealloc_stubbase266     static void dealloc_stub(PyObject *self) {
267         ((T*)self)->~T();
268         Py_TYPE(self)->tp_free(self);
269     }
readybase270     static void ready(mpy::handle mod, const char* name) {
271         if (PyType_Ready(&T::Type)) {
272             throw exception_set();
273         }
274         if(PyModule_AddObject(mod.ptr(), name, (PyObject*) &T::Type) < 0) {
275             throw exception_set();
276         }
277     }
278 };
279 
attr(const char * key)280 inline object handle::attr(const char* key) {
281     return object::checked_steal(PyObject_GetAttrString(ptr(), key));
282 }
283 
hasattr(const char * key)284 inline bool handle::hasattr(const char* key) {
285     return PyObject_HasAttrString(ptr(), key);
286 }
287 
import(const char * module)288 inline object import(const char* module) {
289     return object::checked_steal(PyImport_ImportModule(module));
290 }
291 
292 template<typename... Args>
call(Args &&...args)293 inline object handle::call(Args&&... args) {
294     return object::checked_steal(PyObject_CallFunctionObjArgs(ptr_, args.ptr()..., nullptr));
295 }
296 
call_object(mpy::handle args)297 inline object handle::call_object(mpy::handle args) {
298     return object::checked_steal(PyObject_CallObject(ptr(), args.ptr()));
299 }
300 
301 
call_object(mpy::handle args,mpy::handle kwargs)302 inline object handle::call_object(mpy::handle args, mpy::handle kwargs) {
303     return object::checked_steal(PyObject_Call(ptr(), args.ptr(), kwargs.ptr()));
304 }
305 
call_vector(mpy::handle * begin,Py_ssize_t nargs,mpy::handle kwnames)306 inline object handle::call_vector(mpy::handle* begin, Py_ssize_t nargs, mpy::handle kwnames) {
307     return object::checked_steal(PY_VECTORCALL(ptr(), (PyObject*const*) begin, nargs, kwnames.ptr()));
308 }
309 
310 struct tuple : public object {
settuple311     void set(int i, object v) {
312         PyTuple_SET_ITEM(ptr_, i, v.release());
313     }
tupletuple314     tuple(int size)
315     : object(checked_steal(PyTuple_New(size))) {}
316 };
317 
318 struct list : public object {
setlist319     void set(int i, object v) {
320         PyList_SET_ITEM(ptr_, i, v.release());
321     }
listlist322     list(int size)
323     : object(checked_steal(PyList_New(size))) {}
324 };
325 
326 namespace{
unicode_from_format(const char * format,...)327 mpy::object unicode_from_format(const char* format, ...) {
328     va_list args;
329     va_start(args, format);
330     auto r = PyUnicode_FromFormatV(format, args);
331     va_end(args);
332     return mpy::object::checked_steal(r);
333 }
unicode_from_string(const char * str)334 mpy::object unicode_from_string(const char * str) {
335     return mpy::object::checked_steal(PyUnicode_FromString(str));
336 }
337 
from_int(Py_ssize_t s)338 mpy::object from_int(Py_ssize_t s) {
339     return mpy::object::checked_steal(PyLong_FromSsize_t(s));
340 }
from_bool(bool b)341 mpy::object from_bool(bool b) {
342     return mpy::object::borrow(b ? Py_True : Py_False);
343 }
344 
is_sequence(handle h)345 bool is_sequence(handle h) {
346     return PySequence_Check(h.ptr());
347 }
348 }
349 
350 struct sequence_view : public handle {
sequence_viewsequence_view351     sequence_view(handle h)
352     : handle(h) {}
sizesequence_view353     Py_ssize_t size() const {
354         auto r = PySequence_Size(ptr());
355         if (r == -1 && PyErr_Occurred()) {
356             throw mpy::exception_set();
357         }
358         return r;
359     }
enumeratesequence_view360     irange enumerate() const {
361         return irange(size());
362     }
wrapsequence_view363     static sequence_view wrap(handle h) {
364         if (!is_sequence(h)) {
365             raise_error(PyExc_ValueError, "expected a sequence");
366         }
367         return sequence_view(h);
368     }
369     mpy::object operator[](Py_ssize_t i) const {
370         return mpy::object::checked_steal(PySequence_GetItem(ptr(), i));
371     }
372 };
373 
374 namespace {
repr(handle h)375 mpy::object repr(handle h) {
376     return mpy::object::checked_steal(PyObject_Repr(h.ptr()));
377 }
378 
str(handle h)379 mpy::object str(handle h) {
380     return mpy::object::checked_steal(PyObject_Str(h.ptr()));
381 }
382 
383 
is_int(handle h)384 bool is_int(handle h) {
385     return PyLong_Check(h.ptr());
386 }
387 
is_none(handle h)388 bool is_none(handle h) {
389     return h.ptr() == Py_None;
390 }
391 
to_int(handle h)392 Py_ssize_t to_int(handle h) {
393     Py_ssize_t r = PyLong_AsSsize_t(h.ptr());
394     if (r == -1 && PyErr_Occurred()) {
395         throw mpy::exception_set();
396     }
397     return r;
398 }
399 
to_bool(handle h)400 bool to_bool(handle h) {
401     return PyObject_IsTrue(h.ptr()) != 0;
402 }
403 }
404 
405 struct slice_view {
slice_viewslice_view406     slice_view(handle h, Py_ssize_t size)  {
407         if(PySlice_Unpack(h.ptr(), &start, &stop, &step) == -1) {
408             throw mpy::exception_set();
409         }
410         slicelength = PySlice_AdjustIndices(size, &start, &stop, step);
411     }
412     Py_ssize_t start, stop, step, slicelength;
413 };
414 
is_slice(handle h)415 static bool is_slice(handle h) {
416     return PySlice_Check(h.ptr());
417 }
418 
419 inline std::ostream& operator<<(std::ostream& ss, handle h) {
420     ss << PyUnicode_AsUTF8(str(h).ptr());
421     return ss;
422 }
423 
424 struct tuple_view : public handle {
425     tuple_view() = default;
tuple_viewtuple_view426     tuple_view(handle h) : handle(h) {}
427 
sizetuple_view428     Py_ssize_t size() const {
429         return PyTuple_GET_SIZE(ptr());
430     }
431 
enumeratetuple_view432     irange enumerate() const {
433         return irange(size());
434     }
435 
436     handle operator[](Py_ssize_t i) {
437         return PyTuple_GET_ITEM(ptr(), i);
438     }
439 
checktuple_view440     static bool check(handle h) {
441         return PyTuple_Check(h.ptr());
442     }
443 };
444 
445 struct list_view : public handle {
446     list_view() = default;
list_viewlist_view447     list_view(handle h) : handle(h) {}
sizelist_view448     Py_ssize_t size() const {
449         return PyList_GET_SIZE(ptr());
450     }
451 
enumeratelist_view452     irange enumerate() const {
453         return irange(size());
454     }
455 
456     handle operator[](Py_ssize_t i) {
457         return PyList_GET_ITEM(ptr(), i);
458     }
459 
checklist_view460     static bool check(handle h) {
461         return PyList_Check(h.ptr());
462     }
463 };
464 
465 struct dict_view : public handle {
466     dict_view() = default;
dict_viewdict_view467     dict_view(handle h) : handle(h) {}
keysdict_view468     object keys() const {
469         return mpy::object::checked_steal(PyDict_Keys(ptr()));
470     }
valuesdict_view471     object values() const {
472         return mpy::object::checked_steal(PyDict_Values(ptr()));
473     }
itemsdict_view474     object items() const {
475         return mpy::object::checked_steal(PyDict_Items(ptr()));
476     }
containsdict_view477     bool contains(handle k) const {
478         return PyDict_Contains(ptr(), k.ptr());
479     }
480     handle operator[](handle k) {
481         return mpy::handle::checked(PyDict_GetItem(ptr(), k.ptr()));
482     }
checkdict_view483     static bool check(handle h) {
484         return PyDict_Check(h.ptr());
485     }
nextdict_view486     bool next(Py_ssize_t* pos, mpy::handle* key, mpy::handle* value) {
487         PyObject *k = nullptr, *v = nullptr;
488         auto r = PyDict_Next(ptr(), pos, &k, &v);
489         *key = k;
490         *value = v;
491         return r;
492     }
setdict_view493     void set(handle k, handle v) {
494         if (-1 == PyDict_SetItem(ptr(), k.ptr(), v.ptr())) {
495             throw exception_set();
496         }
497     }
498 };
499 
500 
501 struct kwnames_view : public handle {
502     kwnames_view() = default;
kwnames_viewkwnames_view503     kwnames_view(handle h) : handle(h) {}
504 
sizekwnames_view505     Py_ssize_t size() const {
506         return PyTuple_GET_SIZE(ptr());
507     }
508 
enumeratekwnames_view509     irange enumerate() const {
510         return irange(size());
511     }
512 
513     const char* operator[](Py_ssize_t i) const {
514         PyObject* obj = PyTuple_GET_ITEM(ptr(), i);
515         return PyUnicode_AsUTF8(obj);
516     }
517 
checkkwnames_view518     static bool check(handle h) {
519         return PyTuple_Check(h.ptr());
520     }
521 };
522 
funcname(mpy::handle func)523 inline mpy::object funcname(mpy::handle func) {
524     if (func.hasattr("__name__")) {
525         return func.attr("__name__");
526     } else {
527         return mpy::str(func);
528     }
529 }
530 
531 struct vector_args {
vector_argsvector_args532     vector_args(PyObject *const *a,
533                       Py_ssize_t n,
534                       PyObject *k)
535     : vector_args((mpy::handle*)a, n, k) {}
vector_argsvector_args536     vector_args(mpy::handle* a,
537                     Py_ssize_t n,
538                     mpy::handle k)
539     : args((mpy::handle*)a), nargs(n), kwnames(k) {}
540     mpy::handle* args;
541     Py_ssize_t nargs;
542     kwnames_view kwnames;
543 
beginvector_args544     mpy::handle* begin() {
545         return args;
546     }
endvector_args547     mpy::handle* end() {
548         return args + size();
549     }
550 
551     mpy::handle operator[](int64_t i) const {
552         return args[i];
553     }
has_keywordsvector_args554     bool has_keywords() const {
555         return kwnames.ptr();
556     }
enumerate_positionalvector_args557     irange enumerate_positional() {
558         return irange(nargs);
559     }
enumerate_allvector_args560     irange enumerate_all() {
561         return irange(size());
562     }
sizevector_args563     int64_t size() const {
564         return nargs + (has_keywords() ? kwnames.size() : 0);
565     }
566 
567     // bind a test function so this can be tested, first two args for required/kwonly, then return what was parsed...
568 
569     // provide write kwarg
570     // don't provide a required arg
571     // don't provide an optional arg
572     // provide a kwarg that is the name of already provided positional
573     // provide a kwonly argument positionally
574     // provide keyword arguments in the wrong order
575     // provide only keyword arguments
576     void parse(const char * fname_cstr, std::initializer_list<const char*> names, std::initializer_list<mpy::handle*> values, int required, int kwonly=0) {
577         auto error = [&]() {
578             // rather than try to match the slower infrastructure with error messages exactly, once we have detected an error, just use that
579             // infrastructure to format it and throw it
580 
581             // have to leak this, because python expects these to last
582             const char** names_buf = new const char*[names.size() + 1];
583             std::copy(names.begin(), names.end(), &names_buf[0]);
584             names_buf[names.size()] = nullptr;
585 
586 #if PY_VERSION_HEX < 0x03080000
587             char* format_str = new char[names.size() + 3];
588             int i = 0;
589             char* format_it = format_str;
590             for (auto it = names.begin(); it != names.end(); ++it, ++i) {
591                 if (i == required) {
592                     *format_it++ = '|';
593                 }
594                 if (i == (int)names.size() - kwonly) {
595                     *format_it++ = '$';
596                 }
597                 *format_it++ = 'O';
598             }
599             *format_it++ = '\0';
600             _PyArg_Parser* _parser = new _PyArg_Parser{format_str, &names_buf[0], fname_cstr, 0};
601             PyObject *dummy = NULL;
602             _PyArg_ParseStackAndKeywords((PyObject*const*)args, nargs, kwnames.ptr(), _parser, &dummy, &dummy, &dummy, &dummy, &dummy);
603 #else
604             _PyArg_Parser* _parser = new _PyArg_Parser{NULL, &names_buf[0], fname_cstr, 0};
605             std::unique_ptr<PyObject*[]> buf(new PyObject*[names.size()]);
606             _PyArg_UnpackKeywords((PyObject*const*)args, nargs, NULL, kwnames.ptr(), _parser, required, (Py_ssize_t)values.size() - kwonly, 0, &buf[0]);
607 #endif
608             throw exception_set();
609         };
610 
611         auto values_it = values.begin();
612         auto names_it = names.begin();
613         auto npositional = values.size() - kwonly;
614 
615         if (nargs > (Py_ssize_t)npositional) {
616             // TOO MANY ARGUMENTS
617             error();
618         }
619         for (auto i : irange(nargs)) {
620             *(*values_it++) = args[i];
621             ++names_it;
622         }
623 
624         if (!kwnames.ptr()) {
625             if (nargs < required) {
626                 // not enough positional arguments
627                 error();
628             }
629         } else {
630             int consumed = 0;
631             for (auto i : irange(nargs, values.size())) {
632                 bool success = i >= required;
633                 const char* target_name = *(names_it++);
634                 for (auto j : kwnames.enumerate()) {
635                     if (!strcmp(target_name,kwnames[j])) {
636                         *(*values_it) = args[nargs + j];
637                         ++consumed;
638                         success = true;
639                         break;
640                     }
641                 }
642                 ++values_it;
643                 if (!success) {
644                     // REQUIRED ARGUMENT NOT SPECIFIED
645                     error();
646                 }
647             }
648             if (consumed != kwnames.size()) {
649                 // NOT ALL KWNAMES ARGUMENTS WERE USED
650                 error();
651             }
652         }
653     }
indexvector_args654     int index(const char* name, int pos) {
655         if (pos < nargs) {
656             return pos;
657         }
658         if (kwnames.ptr()) {
659             for (auto j : kwnames.enumerate()) {
660                 if (!strcmp(name, kwnames[j])) {
661                     return nargs + j;
662                 }
663             }
664         }
665         return -1;
666     }
667 };
668 
call_vector(vector_args args)669 inline object handle::call_vector(vector_args args) {
670     return object::checked_steal(PY_VECTORCALL(ptr(), (PyObject*const*) args.args, args.nargs, args.kwnames.ptr()));
671 }
672 
673 
674 }
675 
// X-macro helpers: FORALL_ARGS(M) is expected to expand M(type, name) once per
// argument. These turn that list into keyword-name strings, local variable
// declarations, and the out-pointer list a CPython argument parser needs.
#define MPY_ARGS_NAME(typ, name) #name ,
#define MPY_ARGS_DECLARE(typ, name) typ name;
#define MPY_ARGS_POINTER(typ, name) &name ,

// Parses (args, kwargs) with PyArg_ParseTupleAndKeywords into freshly declared
// locals; throws mpy::exception_set on failure. The keyword list holds string
// literals, so it is declared const (binding a literal to char* is ill-formed
// since C++11) and only const_cast at the call, whose parameter type is not
// const-qualified in older CPython releases.
#define MPY_PARSE_ARGS_KWARGS(fmt, FORALL_ARGS) \
    static const char* kwlist[] = { FORALL_ARGS(MPY_ARGS_NAME) nullptr}; \
    FORALL_ARGS(MPY_ARGS_DECLARE) \
    if (!PyArg_ParseTupleAndKeywords(args, kwargs, fmt, const_cast<char**>(kwlist), FORALL_ARGS(MPY_ARGS_POINTER) nullptr)) { \
        throw mpy::exception_set(); \
    }

// Same for the vectorcall convention (args, nargs, kwnames) via CPython's
// private _PyArg_ParseStackAndKeywords; the parser struct is a function-local
// static, so it is initialised once per expansion site.
#define MPY_PARSE_ARGS_KWNAMES(fmt, FORALL_ARGS) \
    static const char * const kwlist[] = { FORALL_ARGS(MPY_ARGS_NAME) nullptr}; \
    FORALL_ARGS(MPY_ARGS_DECLARE) \
    static _PyArg_Parser parser = {fmt, kwlist, 0}; \
    if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, FORALL_ARGS(MPY_ARGS_POINTER) nullptr)) { \
        throw mpy::exception_set(); \
    }
693