// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <torch/csrc/utils/python_compat.h>


// Many APIs have changed/don't exist anymore
#if IS_PYTHON_3_12_PLUS

#include "dim.h"

// Re-enable this some day
PyObject* Dim_init() {
    PyErr_SetString(PyExc_RuntimeError, "First class dim doesn't work with python 3.12");
    return nullptr;
}

#else

#include "minpybind.h"
#include <frameobject.h>
#include <opcode.h>
#include <utility>
#include <new>
#include <iostream>
#include <vector>
//#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/Export.h>
#include <ATen/functorch/BatchedTensorImpl.h>
#include <ATen/functorch/DynamicLayer.h>
#include <ATen/ATen.h>
#include <memory>
#include "arena.h"
#include "dim.h"
#include "python_variable_simple.h"

#if IS_PYTHON_3_11_PLUS
#define Py_BUILD_CORE
#include "internal/pycore_opcode.h"
#undef Py_BUILD_CORE
#endif

// C++ API functions for objects to
// * construct the object, returning a ref-counted handle
// * The actual API, with methods that take/return C-typed values

// extend minpybind.h to include
// * typed handles so that -> can get to their raw API
// * object/handle distinction for the typed handles

// class Dim: ---------------
mpy::handle torch_Tensor___mul__;
mpy::handle _Tensor;
mpy::handle _Tensor_sum;
mpy::handle NamedTuple;
mpy::dict_view pointwise;
mpy::handle torch_Tensor_expand;
binaryfunc THPVariable_getitem;
objobjargproc THPVariable_setitem;
mpy::handle no_slice;
PyTypeObject* torch_Tensor;
mpy::handle torch_Tensor_copy_;
mpy::handle torch_Tensor_split;
bool pointwise_optimize = true;
PyTypeObject* DimType = nullptr;

PyObject* Tensor_getitem(PyObject* self, PyObject* index);
int Tensor_setitem(PyObject* self, PyObject* index, PyObject* value);

namespace{
void maybeInitializeGlobals() {
    // globals that depend on the python dim library,
    // which we can't look up until we finish initializing the _C module
    if (_Tensor.ptr()) {
        return;
    }
    auto dim = mpy::import("functorch.dim");
    _Tensor = dim.attr("_Tensor");
    pointwise = dim.attr("pointwise");
    _Tensor_sum = _Tensor.attr("sum");
    DimType = (PyTypeObject*) mpy::import("functorch.dim").attr("Dim").ptr();
}

void replaceMappingIfMatches(mpy::handle tp) {
    auto T = (PyTypeObject*) tp.ptr();
    bool recurse = false;
    if (T->tp_as_mapping->mp_subscript == THPVariable_getitem) {
        T->tp_as_mapping->mp_subscript = Tensor_getitem;
        recurse = true;
    }
    if (T->tp_as_mapping->mp_ass_subscript == THPVariable_setitem) {
        T->tp_as_mapping->mp_ass_subscript = Tensor_setitem;
        recurse = true;
    }
    if (recurse) {
        auto result = tp.attr("__subclasses__").call();
        mpy::list_view lv(result);
        for (auto i : lv.enumerate()) {
            replaceMappingIfMatches(lv[i]);
        }
    }
}

void initializeGlobals(Arena & A) {
    auto torch = mpy::import("torch");
    torch_Tensor = (PyTypeObject*) torch.attr("Tensor").ptr();
    torch_Tensor___mul__ = torch.attr("Tensor").attr("__mul__");

    torch_Tensor_expand = torch.attr("_C").attr("TensorBase").attr("expand");
    torch_Tensor_split = torch.attr("_C").attr("TensorBase").attr("split");
    torch_Tensor_copy_ = torch.attr("Tensor").attr("copy_");
    auto py_TensorBase = torch.attr("_C").attr("TensorBase");
    auto TensorBase = (PyTypeObject*) py_TensorBase.ptr();
    THPVariable_getitem = TensorBase->tp_as_mapping->mp_subscript;
    THPVariable_setitem = TensorBase->tp_as_mapping->mp_ass_subscript;
    NamedTuple = mpy::import("typing").attr("NamedTuple");
    no_slice = PySlice_New(NULL, NULL, NULL);
}

mpy::handle DimensionBindError_;
mpy::handle DimensionBindError() {
    if(!DimensionBindError_.ptr()) {
        DimensionBindError_ = mpy::import("functorch.dim").attr("DimensionBindError");
    }
    return DimensionBindError_;
}

static int64_t n_dims_created = 65;

struct Dim : public mpy::base<Dim> {
    int64_t level_; // for stable comparisons in prototype
    mpy::object name_;
    Dim()
    : level_(n_dims_created++) {}
    void init(mpy::object name, int64_t s = -1) {
        name_ = std::move(name);
        size_ = s;
    }

    static bool check_exact(mpy::handle v) {
        return Py_TYPE(v.ptr()) == DimType;
    }

    int64_t size() const {
        if (size_ == -1) {
            mpy::raise_error(PyExc_ValueError, "dimension %S is unbound", name_.ptr());
        }
        return size_;
    }
    void set_size(int64_t v) {
        if (size_ == -1) {
            size_ = v;
        } else if(size_ != v) {
            mpy::raise_error(DimensionBindError(), "Dim '%R' previously bound to a dimension of size %lld cannot bind to a dimension of size %lld", this, this->size_, v);
        }
    }
    bool is_bound() const {
        return size_ != -1;
    }
    static mpy::obj<Dim> create(mpy::object name, int64_t s = -1) {
        if (!DimType) {
            maybeInitializeGlobals();
        }
        auto r = Dim::alloc(DimType);
        r->init(std::move(name), s);
        return r;
    }
    static PyTypeObject Type;
    const at::Tensor& range() {
        if (!range_.defined()) {
            range_ = at::arange(size());
        }
        return range_;
    }
    const at::Tensor& batchtensor() {
        if (!batchtensor_.defined()) {
            batchtensor_ = at::functorch::addBatchDim(range(), 0, level_);
        }
        return batchtensor_;
    }
private:
    int64_t size_{-1};
    at::Tensor range_;
    at::Tensor batchtensor_;
};
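
// Illustrative sketch (not part of the original file): the intended C++
// lifecycle of a Dim, assuming the mpy helpers used above. A Dim starts
// unbound (size_ == -1), binds on the first set_size, and errors on a
// conflicting rebind:
//
//   auto d = Dim::create(mpy::unicode_from_string("batch"));
//   AT_ASSERT(!d->is_bound());
//   d->set_size(4);             // binds the dim to size 4
//   d->set_size(4);             // ok: same size
//   // d->set_size(8);          // would raise DimensionBindError
//   at::Tensor r = d->range();  // arange(4), cached in range_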


struct DimEntry {
    // Holds either a negative number indicating which positional dimension
    // this is, counted from the right-hand side, or a pointer to a
    // first-class dimension.
    // Pointers do not have their highest bit set, so a negative value tells
    // us that this entry is not a dim.
    bool is_positional() const {
        return data_ < 0;
    }
    bool is_none() const {
        return data_ == 0;
    }
    int64_t position() const {
        return data_;
    }
    mpy::hdl<Dim> dim() const {
        Dim* result;
        std::memcpy(&result, &data_, sizeof(Dim*));
        return mpy::hdl<Dim>(result);
    }

    DimEntry()
    : data_(0) {}

    DimEntry(int64_t pos)
    : data_(pos) {
        AT_ASSERT(pos < 0);
    }
    DimEntry(mpy::hdl<Dim> d) {
       std::memcpy(&data_, &d, sizeof(int64_t));
    }
    bool operator==(const DimEntry& rhs) const {
        return data_ == rhs.data_;
    }
private:
    int64_t data_;
};
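
// Illustrative sketch (not part of the original file): how the three states
// of a DimEntry are encoded in a single int64_t, under the pointer
// assumption stated above (valid Dim* values never have the sign bit set):
//
//   DimEntry none;                 // data_ == 0      -> is_none()
//   DimEntry pos(-1);              // negative value  -> is_positional(),
//                                  // the rightmost positional dimension
//   auto d = Dim::create(mpy::unicode_from_string("i"));
//   mpy::hdl<Dim> h = d;
//   DimEntry fc(h);                // pointer bits    -> !is_positional(),
//                                  // recovered with fc.dim()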

// Dim wrapper methods
DimEntry _wrap_dim(mpy::handle d, size_t N, bool keepdim) {
    if (Dim::check(d)) {
        if (keepdim) {
            mpy::raise_error(PyExc_ValueError, "cannot preserve first-class dimensions with keepdim=True");
        }
        return Dim::unchecked_wrap(d);
    } else if (mpy::is_int(d)) {
        auto i = mpy::to_int(d);
        while (i >= 0) {
            i -= N;
        }
        return i;
    } else {
        return DimEntry();
    }
}
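
// Illustrative sketch (not part of the original file): for plain ints,
// _wrap_dim normalizes a dimension index into the negative, from-the-right
// convention used by DimEntry. For a tensor with N = 4 dimensions:
//
//   _wrap_dim(2, 4, false)    // 2 -> 2 - 4 == -2
//   _wrap_dim(-1, 4, false)   // already negative, stays -1
//   // non-int, non-Dim handles -> the "none" DimEntry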


int Dim_init(mpy::hdl<Dim> self, PyObject *args, PyObject *kwds) {
    PY_BEGIN
    static constexpr const char* kwlist[] = {"name", "size", nullptr};
    mpy::handle name;
    mpy::handle size = nullptr;
    if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O", const_cast<char **>(kwlist), &name, &size)) {
        return -1;
    }
    self->init(mpy::object::borrow(name), (size.ptr() && !mpy::is_none(size)) ? mpy::to_int(size) : -1);
    return 0;
    PY_END(-1)
}

PyObject* Dim_repr(Dim* self) {
    PY_BEGIN
    mpy::object name = (self->name_.ptr()) ? self->name_ : mpy::unicode_from_string("<uninitialized dim>");
    return name.release();
    PY_END(nullptr)
}


PyObject* Dim_getsize(Dim* self, void*) {
    PY_BEGIN
    return mpy::from_int(self->size()).release();
    PY_END(nullptr)
}

int Dim_setsize(Dim* self, PyObject* size, void*) {
    PY_BEGIN
    self->set_size(mpy::to_int(size));
    return 0;
    PY_END(-1)
}

PyObject* Dim_getis_bound(Dim* self, void*) {
    return PyBool_FromLong(self->is_bound());
}

PyObject* Dim_getlevel(Dim* self, void*) {
    return PyLong_FromLong(self->level_);
}

PyObject* Dim_get_levels(Dim* self, void*) {
    mpy::tuple t(1);
    t.set(0, mpy::object::borrow(self->ptr()));
    return t.release();
}

PyObject* Dim_get_has_device(Dim* self, void*) {
    Py_RETURN_FALSE;
}

PyObject* Dim_get_tensor(Dim* self, void*) {
    return THPVariable_Wrap(self->range());
}

PyObject* Dim_get_batchtensor(Dim* self, void*) {
    return THPVariable_Wrap(self->batchtensor());
}


PyGetSetDef Dim_getsetters[] = {
    {"size", (getter) Dim_getsize, (setter) Dim_setsize,
     "Dimension size", NULL},
    {"is_bound", (getter) Dim_getis_bound, NULL, "is_bound", NULL},
    {"_level", (getter) Dim_getlevel, NULL, "_level", NULL},
    {"_levels", (getter) Dim_get_levels, NULL, "_levels", NULL},
    {"_has_device", (getter) Dim_get_has_device, NULL, "_has_device", NULL},
    {"_tensor", (getter) Dim_get_tensor, NULL, "_tensor", NULL},
    {"_batchtensor", (getter) Dim_get_batchtensor, NULL, "_batchtensor", NULL},
    {"ndim", (getter) [](PyObject* self, void*) -> PyObject* { return mpy::from_int(1).release(); }, NULL, "ndim", NULL},
    {NULL}  /* Sentinel */
};
}
PyTypeObject Dim::Type = {
    PyVarObject_HEAD_INIT(NULL, 0)
    "_C.Dim",                       /* tp_name */
    sizeof(Dim),                    /* tp_basicsize */
    0,                              /* tp_itemsize */
    Dim::dealloc_stub,              /* tp_dealloc */
    0,                              /* tp_vectorcall_offset */
    0,                              /* tp_getattr */
    0,                              /* tp_setattr */
    0,                              /* tp_as_async */
    (reprfunc)Dim_repr,             /* tp_repr */
    0,                              /* tp_as_number */
    0,                              /* tp_as_sequence */
    0,                              /* tp_as_mapping */
    0,                              /* tp_hash */
    0,                              /* tp_call */
    0,                              /* tp_str */
    0,                              /* tp_getattro */
    0,                              /* tp_setattro */
    0,                              /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,  /* tp_flags */
    "Dim Object",                   /* tp_doc */
    0,                              /* tp_traverse */
    0,                              /* tp_clear */
    0,                              /* tp_richcompare */
    0,                              /* tp_weaklistoffset */
    0,                              /* tp_iter */
    0,                              /* tp_iternext */
    0,                              /* tp_methods */
    0,                              /* tp_members */
    Dim_getsetters,                 /* tp_getset */
    0,                              /* tp_base */
    0,                              /* tp_dict */
    0,                              /* tp_descr_get */
    0,                              /* tp_descr_set */
    0,                              /* tp_dictoffset */
    (initproc)(void*)static_cast<int(*)(mpy::hdl<Dim>,PyObject*,PyObject*)>(Dim_init),      /* tp_init */
    0,                              /* tp_alloc */
    Dim::new_stub,                  /* tp_new */
};

// class DimList ------------

struct DimList : public mpy::base<DimList> {
    mpy::object name_;
    std::vector<mpy::obj<Dim>> dims_;
    static PyTypeObject Type;
    void init(mpy::object name) {
        name_ = std::move(name);
    }
    void set_dims(std::vector<mpy::obj<Dim>> dims) {
        bound_ = true;
        dims_ = std::move(dims);
    }
    bool is_bound() {
        return bound_;
    }
    void bind_len(int64_t size) {
        if (bound_) {
            int64_t b_size = dims_.size();
            if (b_size != size) {
                mpy::raise_error(DimensionBindError(), "Dimlist has size %lld but it is being bound to size %lld", b_size, size);
            }
        } else {
            bound_ = true;
            dims_.resize(size);
            for (Py_ssize_t i = 0; i < size; ++i) {
                dims_[i] = Dim::create(mpy::unicode_from_format("%S%i", name_.ptr(), (int)i));
            }
        }
    }
    int64_t size() const {
        if (!bound_) {
            mpy::raise_error(DimensionBindError(), "DimList not bound");
        }
        return dims_.size();
    }
    void set_bound(bool b) {
        bound_ = b;
    }
private:
    bool bound_ = false;
};
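
// Illustrative sketch (not part of the original file): the two stages of
// binding a DimList. Binding only a length creates fresh dims named
// "<name>0", "<name>1", ...; binding sizes afterwards fixes each dim:
//
//   auto dl = DimList::create(mpy::unicode_from_string("batch"));
//   dl->bind_len(2);              // dims_ holds batch0, batch1 (unbound)
//   dl->dims_[0]->set_size(3);
//   dl->dims_[1]->set_size(5);
//   dl->bind_len(2);              // ok: same length
//   // dl->bind_len(3);           // would raise DimensionBindError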


static int DimList_init(DimList *self, PyObject *args, PyObject *kwds);

static PyObject* DimList_repr(DimList* self) {
    PY_BEGIN
    if (self->is_bound()) {
        size_t size = self->dims_.size();
        mpy::tuple t(size);
        for(size_t i = 0; i < size; ++i) {
            t.set(i, self->dims_[i]);
        }
        return mpy::repr(t).release();
    } else if(!mpy::is_none(self->name_)) {
        return mpy::unicode_from_format("*%S", self->name_.ptr()).release();
    } else {
        return mpy::unicode_from_string("<unbound_dimlist>").release();
    }
    PY_END(nullptr)
}

static PyObject* DimList_bind(DimList *self,
                      PyObject *const *args,
                      Py_ssize_t nargs,
                      PyObject *kwnames) {
    PY_BEGIN
    mpy::handle sizes;
    static const char * const _keywords[] = {"sizes", nullptr};
    static _PyArg_Parser parser = {"O", _keywords, 0};
    if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &sizes)) {
        return nullptr;
    }
    if (!mpy::is_sequence(sizes)) {
        mpy::raise_error(PyExc_ValueError, "expected a sequence");
    }
    mpy::sequence_view seq = sizes;
    auto size = seq.size();
    self->bind_len(size);
    for (Py_ssize_t i = 0; i < size; ++i) {
        self->dims_[i]->set_size(mpy::to_int(seq[i]));
    }
    Py_RETURN_NONE;
    PY_END(nullptr)
}

static PyObject* DimList_bind_len(DimList *self,
                      PyObject *const *args,
                      Py_ssize_t nargs,
                      PyObject *kwnames) {
    PY_BEGIN
    int size;
    static const char * const _keywords[] = {"N", nullptr};
    static _PyArg_Parser parser = {"i", _keywords, 0};
    if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &size)) {
        return nullptr;
    }
    self->bind_len(size);
    Py_RETURN_NONE;
    PY_END(nullptr)
}

static PyMethodDef DimList_methods[] = {
    {"bind", (PyCFunction)(void*) DimList_bind, METH_FASTCALL | METH_KEYWORDS},
    {"bind_len", (PyCFunction)(void*) DimList_bind_len, METH_FASTCALL | METH_KEYWORDS},
    {NULL, NULL, 0, NULL}        /* Sentinel */
};


static Py_ssize_t DimList_len(DimList* self) {
    PY_BEGIN
    return self->size();
    PY_END(-1)
}

static PyObject * DimList_item(DimList* self, Py_ssize_t idx) {
    PY_BEGIN
    if (!self->is_bound()) {
        mpy::raise_error(DimensionBindError(), "DimList not bound");
    }
    if (idx < 0 || (size_t) idx >= self->dims_.size()) {
        mpy::raise_error(PyExc_IndexError, "index out of bounds");
    }
    mpy::object r = self->dims_[idx];
    return r.release();
    PY_END(nullptr)
}

PySequenceMethods DimList_seq {
    (lenfunc) DimList_len, //lenfunc sq_length;
    0, //binaryfunc sq_concat;
    0, //ssizeargfunc sq_repeat;
    (ssizeargfunc) DimList_item, //ssizeargfunc sq_item;
    0, //void *was_sq_slice;
    0, //ssizeobjargproc sq_ass_item;
    0, //void *was_sq_ass_slice;
    0, //objobjproc sq_contains;

    0, //binaryfunc sq_inplace_concat;
    0, //ssizeargfunc sq_inplace_repeat;
};

static PyObject* DimList_getis_bound(DimList* self, void*) {
    return PyBool_FromLong(self->is_bound());
}

static PyGetSetDef DimList_getsetters[] = {
    {"is_bound", (getter) DimList_getis_bound, NULL, "is_bound", NULL},
    {NULL}  /* Sentinel */
};


static PyObject* DimList_subscript(DimList* self, mpy::handle idx) {
    PY_BEGIN
    if (mpy::is_int(idx)) {
        return DimList_item(self, mpy::to_int(idx));
    } else if (mpy::is_slice(idx)) {
        if (!self->is_bound()) {
            mpy::raise_error(DimensionBindError(), "DimList not bound");
        }
        mpy::slice_view s(idx, self->dims_.size());
        mpy::tuple r(s.slicelength);
        for (Py_ssize_t i = s.start, j = 0; i < s.stop; i += s.step) {
            r.set(j++,  self->dims_[i]);
        }
        return r.release();
    } else {
        mpy::raise_error(PyExc_ValueError, "expected an int or a slice");
        return nullptr;
    }
    PY_END(nullptr)
}

PyMappingMethods DimList_mapping = {
    0, //lenfunc mp_length;
    (binaryfunc)(void*) DimList_subscript, //binaryfunc mp_subscript;
    0, //objobjargproc mp_ass_subscript;
};



PyTypeObject DimList::Type = {
    PyVarObject_HEAD_INIT(NULL, 0)
    "_C.DimList",                   /* tp_name */
    sizeof(DimList),                /* tp_basicsize */
    0,                              /* tp_itemsize */
    DimList::dealloc_stub,          /* tp_dealloc */
    0,                              /* tp_vectorcall_offset */
    0,                              /* tp_getattr */
    0,                              /* tp_setattr */
    0,                              /* tp_as_async */
    (reprfunc)DimList_repr,         /* tp_repr */
    0,                              /* tp_as_number */
    &DimList_seq,                   /* tp_as_sequence */
    &DimList_mapping,               /* tp_as_mapping */
    0,                              /* tp_hash */
    0,                              /* tp_call */
    0,                              /* tp_str */
    0,                              /* tp_getattro */
    0,                              /* tp_setattro */
    0,                              /* tp_as_buffer */
    0,                              /* tp_flags */
    "DimList Object",               /* tp_doc */
    0,                              /* tp_traverse */
    0,                              /* tp_clear */
    0,                              /* tp_richcompare */
    0,                              /* tp_weaklistoffset */
    0,                              /* tp_iter */
    0,                              /* tp_iternext */
    DimList_methods,                /* tp_methods */
    0,                              /* tp_members */
    DimList_getsetters,             /* tp_getset */
    0,                              /* tp_base */
    0,                              /* tp_dict */
    0,                              /* tp_descr_get */
    0,                              /* tp_descr_set */
    0,                              /* tp_dictoffset */
    (initproc) DimList_init,        /* tp_init */
    0,                              /* tp_alloc */
    DimList::new_stub,              /* tp_new */
};

static int DimList_init(DimList *self, PyObject *args, PyObject *kwds) {
    PY_BEGIN
    static constexpr const char* kwlist[] = {"len_or_dims", "name", nullptr};
    mpy::handle len_or_dims = nullptr;
    PyObject* name = nullptr;
    if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OO", const_cast<char**>(kwlist), &len_or_dims, &name)) {
        return -1;
    }
    self->init(mpy::object::borrow(name ? name : Py_None));
    if (len_or_dims.ptr()) {
        if(mpy::is_int(len_or_dims)) {
            self->bind_len(mpy::to_int(len_or_dims));
        } else if (mpy::is_sequence(len_or_dims)) {
            mpy::sequence_view s(len_or_dims);
            std::vector<mpy::obj<Dim>> dims;
            size_t size = s.size();
            dims.reserve(size);
            for (size_t i = 0; i < size; ++i) {
                auto r = s[i];
                if (mpy::is_int(r)) {
                    dims.emplace_back(Dim::create(mpy::unicode_from_format("%S%i", self->name_.ptr(), (int)i),  mpy::to_int(r)));
                } else {
                    dims.emplace_back(Dim::wrap(r));
                }
            }
            self->set_dims(std::move(dims));
        } else {
            PyErr_Format(PyExc_ValueError, "expected a length or a sequence of dimensions");
            return -1;
        }
        return 0;
    }
    return 0;
    PY_END(-1);
}

// Tensor -----------------------------

PyTypeObject* TensorType = nullptr; // the python wrapper type.
mpy::object run_torch_function(Arena &A, mpy::handle orig, mpy::vector_args args, bool is_pointwise);

namespace{

at::Tensor _add_batch_dims(Arena& A, at::Tensor t, Slice<DimEntry> levels_) {
    auto levels = Slice<DimEntry>();
    levels.extend(A, levels_);
    while (true) {
        int64_t min_real_index = -1;
        int64_t min_index = -1;
        int64_t min_value = INT_MAX;
        int64_t i = 0;
        int64_t r = 0;
        for (auto l : levels) {
            if (!l.is_none()) {
                if (!l.is_positional() && l.dim()->level_ < min_value) {
                    min_value = l.dim()->level_;
                    min_index = i;
                    min_real_index = r;
                }
                ++i;
            }
            ++r;
        }
        if (min_index == -1) {
            return t;
        }
        auto t2 = at::functorch::addBatchDim(std::move(t), min_index, min_value);
        t = std::move(t2);
        levels[min_real_index] = DimEntry();
    }
}



struct DelayedOperator {
    DelayedOperator(mpy::object o, mpy::vector_args a)
    : orig(std::move(o)), args(a) {
        auto all = a.size();
        // this object will outlive the call, so take
        // ownership of the temporaries
        // stored in vector_args
        auto buf = new mpy::handle[all];
        memcpy(buf, args.args, sizeof(mpy::handle)*all);
        args.args = buf;
        for (auto i : args.enumerate_all()) {
            Py_INCREF(args.args[i].ptr());
        }
        Py_XINCREF(args.kwnames.ptr());
    }
    ~DelayedOperator() {
        for (auto i : args.enumerate_all()) {
            Py_DECREF(args[i].ptr());
        }
        if (args.has_keywords()) {
            Py_XDECREF(args.kwnames.ptr());
        }
        delete [] args.args;
    }
    mpy::object orig;
    mpy::vector_args args;
};

void free_levels_dims(Slice<DimEntry> levels) {
    for(auto e : levels) {
        if (!e.is_positional()) {
            mpy::object::steal(e.dim());
        }
    }
}
}

struct Tensor : public mpy::base<Tensor> {
private:
    at::Tensor tensor_;
    at::Tensor batchtensor_;
    OwnedSlice<DimEntry> levels_;
    bool has_device_;
    std::unique_ptr<DelayedOperator> delayed_;
public:

    at::Tensor& tensor(Arena& A) {
        if (C10_UNLIKELY(!tensor_.defined())) {
            AT_ASSERT(delayed_);
            auto t = Tensor::wrap(run_torch_function(A, delayed_->orig, delayed_->args, true));
            tensor_ = t->tensor(A);
            delayed_.reset();
            // don't force creation of the batch tensor if it wasn't already provided.
            batchtensor_ = t->batchtensor_;
            AT_ASSERT(levels() == t->levels());
        }
        return tensor_;
    }
    at::Tensor& batchtensor(Arena& A) {
        if (C10_UNLIKELY(!batchtensor_.defined())) {
            batchtensor_ = _add_batch_dims(A, tensor(A), levels_.slice());
        }
        return batchtensor_;
    }
    Slice<DimEntry> levels() {
        return levels_.slice();
    }
    bool has_device() {
        return has_device_;
    }
    DelayedOperator* delayed() {
        return delayed_.get();
    }
    static PyTypeObject Type;

    static bool check_exact(mpy::handle v) {
       return Py_TYPE(v.ptr()) == TensorType;
    }


    static mpy::obj<Tensor> create() {
        if (!TensorType) {
            TensorType = (PyTypeObject*) mpy::import("functorch.dim").attr("Tensor").ptr();
        }
        return Tensor::alloc(TensorType);
    }
    void capture_levels(Slice<DimEntry> levels) {
        // grab ownership of the dims inside levels
        for (auto l : levels) {
            if (!l.is_positional()) {
                mpy::object::borrow(l.dim()).release();
            }
        }
        levels_.set(levels, free_levels_dims);
    }
    static mpy::object from_positional(Arena & A, at::Tensor tensor, Slice<DimEntry> levels, bool has_device);
    static mpy::obj<Tensor> create_delayed(mpy::object op, mpy::vector_args args, Slice<DimEntry> levels, bool has_device);
    friend struct EnableAllLayers;
};

namespace{
// the version in the header does an unnecessary refcount +/-
at::functorch::BatchedTensorImpl* maybeGetBatchedImpl(const at::Tensor& tensor) {
    if (at::functorch::isBatchedTensor(tensor)) {
        return static_cast<at::functorch::BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
    }
    return nullptr;
}

TensorRef unchecked_tensor_from(mpy::handle p) {
    auto v = (THPVariable*) p.ptr();
    return TensorRef(*v->cdata);
}

static int64_t ndim_of_levels(Slice<DimEntry> levels) {
    int64_t r = 0;
    for (auto l : levels) {
        if (l.is_positional()) {
            ++r;
        }
    }
    return r;
}

struct TensorInfo {
    TensorRef tensor;
    Slice<DimEntry> levels;
    bool has_device;
    TensorRef batchedtensor;
    int64_t ndim() const {
        return ndim_of_levels(levels);
    }
    operator bool() const {
        return tensor;
    }

    static TensorInfo create(Arena& A, mpy::handle h, bool ensure_batched=true, bool ensure_present=true) {
        if (Tensor::check_exact(h)) {
            auto t = Tensor::unchecked_wrap(h);
            return TensorInfo {t->tensor(A), t->levels(), t->has_device(), ensure_batched ? t->batchtensor(A) : TensorRef()};
        } else if (Dim::check_exact(h)) {
            auto d = Dim::unchecked_wrap(h);
            return TensorInfo {d->range(), Slice<DimEntry>(A, DimEntry(d)), false, ensure_batched ? d->batchtensor() : TensorRef()};
        } else if (THPVariable_Check(h.ptr())) {
            TensorRef t = unchecked_tensor_from(h);
            Slice<DimEntry> levels;
            for (auto i : irange(-t->dim(), 0)) {
                levels.append(A, i);
            }
            return TensorInfo {t, levels, true, t};
        } else {
            if (ensure_present) {
                mpy::raise_error(PyExc_ValueError, "expected a tensor object");
            }
            return TensorInfo {};
        }
    }
};

static PyObject* py_Tensor_from_positional(PyObject *self,
                      PyObject *const *args,
                      Py_ssize_t nargs,
                      PyObject *kwnames) {
    Arena A;
    PY_BEGIN
    #define ARGS(_) _(mpy::handle, tensor) _(mpy::handle, py_levels) _(int, has_device)
    MPY_PARSE_ARGS_KWNAMES("OOp", ARGS)
    #undef ARGS

    if (!THPVariable_Check(tensor.ptr())) {
        mpy::raise_error(PyExc_ValueError, "_tensor is not a Tensor?");
    }

    Slice<DimEntry> levels;
    mpy::sequence_view sq(py_levels);
    for (auto i : sq.enumerate()) {
        mpy::object v = sq[i];
        if (mpy::is_int(v)) {
            auto vi = mpy::to_int(v);
            levels.append(A, vi);
        } else {
            auto dim = Dim::wrap(std::move(v));
            mpy::hdl<Dim> hdim = dim;
            levels.append(A, hdim);
        }
    }
    return Tensor::from_positional(A, THPVariable_Unpack(tensor.ptr()), levels, has_device != 0).release();
    PY_END(nullptr)
}
}

mpy::object Tensor::from_positional(Arena & A, at::Tensor tensor, Slice<DimEntry> levels, bool has_device) {
    size_t seen_dims = 0;
    int last = 0;
    //auto sz = tensor.sizes();
    for (auto i : levels.enumerate()) {
        auto l = levels[i];
        if (l.is_positional()) {
            AT_ASSERT(last == 0 || last + 1 == l.position());
            last = l.position();
        } else {
            mpy::object::borrow(l.dim()).release();
            //AT_ASSERT(sz[i] == l.dim()->size());
            ++seen_dims;
        }
    }
    AT_ASSERT(last == 0 || last == -1);
    if (!seen_dims) {
        return mpy::object::steal(THPVariable_Wrap(std::move(tensor)));
    }

    mpy::obj<Tensor> self = Tensor::create();
    self->tensor_ = std::move(tensor);
    AT_ASSERT(self->tensor_.dim() == levels.size());
    self->levels_.set(levels, free_levels_dims);
    self->has_device_ = has_device;
    mpy::object r = std::move(self);
    return r;
}
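
// Illustrative sketch (not part of the original file): the levels convention
// from_positional enforces. There is one DimEntry per physical dimension;
// positional entries count from the right, must be consecutive, and must end
// at -1. E.g. for a 3-d tensor whose middle dimension is the first-class
// dim `i`:
//
//   // levels = { -2, DimEntry(i), -1 }   -> wrapped as a Tensor
//   // levels = { -3, -2, -1 }            -> no first-class dims, so the
//   //                                        plain tensor is returned as-is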


mpy::obj<Tensor> Tensor::create_delayed(mpy::object op, mpy::vector_args args, Slice<DimEntry> levels, bool has_device) {
    mpy::obj<Tensor> self = Tensor::create();
    self->capture_levels(levels);
    self->has_device_ = has_device;
    self->delayed_ = std::make_unique<DelayedOperator>(std::move(op), args);
    return self;
}

namespace{
mpy::list slice_to_list(Slice<mpy::handle> h) {
    mpy::list lst(h.size());
    for (auto i : h.enumerate()) {
        lst.set(i, mpy::object::borrow(h[i]));
    }
    return lst;
}

mpy::tuple slice_to_tuple(Slice<mpy::handle> h) {
    mpy::tuple lst(h.size());
    for (auto i : h.enumerate()) {
        lst.set(i, mpy::object::borrow(h[i]));
    }
    return lst;
}

enum UType {
    U_ELEM,
    U_TUPLE_LIKE,
    U_DICT,
};

struct Unflatten {
    mpy::object operator()(Slice<mpy::handle>& elements) {
        mpy::object r;
        switch (type) {
            case U_ELEM: {
                r = mpy::object::borrow(elements[0]);
                elements = elements.slice(1);
            } break;
            case U_TUPLE_LIKE: {
                mpy::tuple tup(children.size());
                for (auto i : children.enumerate()) {
                    tup.set(i, children[i](elements));
                }
                r = obj.call(tup);
            } break;
            case U_DICT: {
                r = mpy::object::checked_steal(PyDict_New());
                mpy::dict_view rv(r);
                mpy::dict_view d(obj);
                Py_ssize_t pos = 0;
                mpy::handle k, v;
                for (int i = 0; d.next(&pos, &k, &v); ++i) {
                    rv.set(k, children[i](elements));
                }
            } break;
        }
        return r;
    }
    UType type;
    mpy::handle obj;
    Slice<Unflatten> children;
};

Unflatten tree_flatten(Arena& A, mpy::handle agg, Slice<mpy::handle>& flat_elements) {
    Slice<Unflatten> c;
    UType utype;
    mpy::handle obj;
    if (mpy::list_view::check(agg)) {
        obj = agg.type();
        utype = U_TUPLE_LIKE;
        mpy::list_view l(agg);
        for (auto i : l.enumerate()) {
            c.append(A, tree_flatten(A, l[i], flat_elements));
        }
    } else if (mpy::tuple_view::check(agg)) {
        obj = agg.type();
        utype = U_TUPLE_LIKE;
        // includes named tuples
        mpy::tuple_view l(agg);
        for (auto i : l.enumerate()) {
            c.append(A, tree_flatten(A, l[i], flat_elements));
        }
    } else if (mpy::dict_view::check(agg)) {
        utype = U_DICT;
        mpy::dict_view d(agg);
        obj = agg;
        Py_ssize_t pos = 0;
        mpy::handle k, v;
        while (d.next(&pos, &k, &v)) {
            c.append(A, tree_flatten(A, v, flat_elements));
        }
    } else {
        utype = U_ELEM;
        flat_elements.append(A, agg);
    }
    return Unflatten {utype, obj, c};
}
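
// Illustrative sketch (not part of the original file): the round trip these
// two pieces implement. tree_flatten collects leaves into flat_elements and
// returns an Unflatten closure that rebuilds the same container shape:
//
//   Arena A;
//   Slice<mpy::handle> leaves;
//   auto unflatten = tree_flatten(A, some_nested_list_or_dict, leaves);
//   // ... replace each leaves[i] with a transformed handle ...
//   Slice<mpy::handle> it = leaves;
//   mpy::object rebuilt = unflatten(it);  // same structure, new leaves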

struct UnflattenVectorArgs {
    mpy::vector_args operator()(Arena& A, Slice<mpy::handle>& elements) {
        if (!had_nested) {
            auto args = elements.begin();
            elements = Slice<mpy::handle>();
            return mpy::vector_args(args, nargs, kwnames);
        }
        Slice<mpy::handle> args;
        for (auto u : children) {
            args.append(A, A.autorelease(u(elements)));
        }
        return mpy::vector_args(args.begin(), nargs, kwnames);
    }
    Slice<Unflatten> children;
    Py_ssize_t nargs;
    mpy::handle kwnames;
    bool had_nested;
};

UnflattenVectorArgs tree_flatten(Arena& A, mpy::vector_args args, Slice<mpy::handle>& flat_elements) {
    UnflattenVectorArgs r;
    r.kwnames = args.kwnames;
    r.nargs = args.nargs;
    r.had_nested = false;
    auto N = args.size();
    for(auto i : irange(N)) {
        auto typ = Py_TYPE(args[i].ptr());
        // fast check that this argument isn't something that could be nested.
        bool is_element = !typ->tp_as_sequence ||  typ == torch_Tensor || typ == TensorType || typ == DimType;
        if (!is_element) {
            flat_elements.extend(A, args.args, args.args + i);
            for (auto j : irange(i)) {
                (void)j;
                r.children.append(A, Unflatten {U_ELEM});
            }
            for (auto j : irange(i, N)) {
                r.children.append(A, tree_flatten(A, args[j], flat_elements));
                if (r.children.back().type != U_ELEM) {
                    r.had_nested = true;
                }
            }
            return r;
        }
    }
    flat_elements.extend(A, args.args, args.args + N);
    return r;
}


struct UnflattenArena {
    Arena A;
    Unflatten unflatten;
};

PyObject* py_unflatten(PyObject *self,
                      PyObject *const *args,
                      Py_ssize_t nargs,
                      PyObject *kwnames) {
    PY_BEGIN
    #define ARGS(_) _(mpy::handle, ns)
    MPY_PARSE_ARGS_KWNAMES("O", ARGS)
    #undef ARGS
    mpy::sequence_view sv(ns);
    // because we do not have an autorelease pool yet...
    Arena A;
    Slice<mpy::handle> slice;
    mpy::handle Tuple = (PyObject*) &PyTuple_Type;
    auto inputs = Tuple.call(ns);
    mpy::tuple_view tv(inputs);
    for (auto i : tv.enumerate()) {
        slice.append(A, tv[i]);
    }
    auto AA = (UnflattenArena*) PyCapsule_GetPointer(self, "arena");
    auto r = AA->unflatten(slice).release();
    AT_ASSERT(r != nullptr);
    return r;
    PY_END(nullptr)
}

PyMethodDef py_unflatten_def = {"unflatten", (PyCFunction)(void*) py_unflatten, METH_FASTCALL | METH_KEYWORDS};

void free_unflatten_arena(PyObject * pc) {
    delete (UnflattenArena*) PyCapsule_GetPointer(pc, "arena");
}

PyObject* py_tree_flatten(PyObject *self,
                      PyObject *const *args,
                      Py_ssize_t nargs,
                      PyObject *kwnames) {
    PY_BEGIN
    #define ARGS(_) _(mpy::handle, tree)
    MPY_PARSE_ARGS_KWNAMES("O", ARGS)
    #undef ARGS
    auto A = new UnflattenArena;
    Slice<mpy::handle> elements;
    A->unflatten = tree_flatten(A->A, tree, elements);
    auto cap = mpy::object::checked_steal(PyCapsule_New(A, "arena", free_unflatten_arena));
    auto unflatten = mpy::object::checked_steal(PyCFunction_New(&py_unflatten_def, cap.release()));
    mpy::tuple r(2);
    r.set(0, slice_to_list(elements));
    r.set(1, std::move(unflatten));
    return r.release();
    PY_END(nullptr)
}



mpy::object tree_map(Arena& A, const std::function<mpy::handle(mpy::handle)>& fn, mpy::handle agg) {
    Slice<mpy::handle> elements;
    auto unflatten = tree_flatten(A, agg, elements);
    for (auto i : elements.enumerate()) {
        elements[i] = fn(elements[i]);
    }
    return unflatten(elements);
}

// prereq: isinstance(h, _Tensor)
int64_t _Tensor_ndim(mpy::handle h) {
    if (Tensor::check(h)) {
        int64_t r = 0;
        for (auto l : Tensor::unchecked_wrap(h)->levels()) {
            if (l.is_positional()) {
                ++r;
            }
        }
        return r;
    }
    // Dim or DelayedMulTensor
    return 0;
}

mpy::handle handle_from_tensor(Arena& A, TensorRef t) {
    // fast case: the tensor is already live in Python
    std::optional<PyObject*> mb_obj =
        t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter(), /*ignore_hermetic_tls=*/false);
    if (mb_obj.has_value() && !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) {
        return *mb_obj;
    }
    return A.autorelease(mpy::object::checked_steal(THPVariable_Wrap(*t)));
}
}
struct EnableAllLayers {
    EnableAllLayers(Arena& A, Slice<DimEntry> levels) {
        std::vector<std::pair<int64_t, int64_t>> layers;
        layers.reserve(levels.size());
        for (auto l : levels) {
            if (!l.is_positional()) {
                auto d = l.dim();
                levels_to_dim_.append(A, d);
            }
        }
        std::sort(levels_to_dim_.begin(), levels_to_dim_.end(), [](mpy::hdl<Dim> lhs, mpy::hdl<Dim> rhs) { return lhs->level_ < rhs->level_;});

        for (auto i : levels_to_dim_.enumerate()) {
            auto batch_size = levels_to_dim_[i]->size();
            auto level = at::functorch::initAndPushDynamicLayer(at::functorch::TransformType::Vmap, batch_size, at::functorch::RandomnessType::Different);
            if (i == 0) {
                levels_start_ = level;
            }
        }
    }

    ~EnableAllLayers() {
        auto to_remove = levels_start_ + levels_to_dim_.size() - 1;
        for (auto i : levels_to_dim_.enumerate()) {
            AT_ASSERT(at::functorch::popDynamicLayerAndDeleteMetadata().layerId() == to_remove - i);
        }
    }

    mpy::obj<Tensor> from_batched(Arena& A, at::Tensor batchedtensor, bool has_device) {
        Slice<DimEntry> levels;
        for (auto i : irange(-batchedtensor.dim(), 0)) {
            levels.append(A, i);
        }
        TensorRef tensor;
        at::functorch::BatchedTensorImpl * impl = maybeGetBatchedImpl(batchedtensor);
        while(true) {
            auto level = impl->level();
            AT_ASSERT(level >= levels_start_ && level < levels_start_ + levels_to_dim_.size());
            mpy::hdl<Dim> dim = levels_to_dim_[level - levels_start_].ptr();
            levels.insert(A, impl->bdim(), dim);
            at::functorch::BatchedTensorImpl * nimpl = maybeGetBatchedImpl(impl->value());
            if (!nimpl) {
                tensor = impl->value();
                break;
            }
            impl = nimpl;
        }

        mpy::obj<Tensor> self = Tensor::create();
        // grab ownership of the tensors
        self->tensor_ = *tensor;
        self->batchtensor_ = std::move(batchedtensor);
        self->has_device_ = has_device;
        self->capture_levels(levels);
        return self;
    }
    void inplace_update_layers(TensorRef batchtensor, Slice<DimEntry> levels) {
        // XXX - requires a patch to functorch to add set_level
        auto impl = maybeGetBatchedImpl(*batchtensor);
        for (auto i : levels_to_dim_.reversed_enumerate()) {
            if (!impl) {
                break;
            }
            if (levels.contains(levels_to_dim_[i])) {
                impl->_unsafe_set_level(levels_start_ + i);
                impl = maybeGetBatchedImpl(impl->value());
            }
        }
    }
private:
    int64_t levels_start_{};
    Slice<mpy::hdl<Dim>> levels_to_dim_;
};
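
// Illustrative sketch (not part of the original file): the RAII pattern
// EnableAllLayers implements. One vmap layer is pushed per first-class dim
// (sorted by level), and the destructor pops them in reverse order:
//
//   {
//       EnableAllLayers guard(A, result_levels);  // pushes N vmap layers
//       // ... run the operator on batched tensors ...
//       auto out = guard.from_batched(A, batched_result, has_device);
//   }   // destructor pops the N layers, asserting LIFO order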

namespace{
TensorRef _match_levels(Arena& A, TensorRef v, Slice<DimEntry> from_levels, Slice<DimEntry> to_levels, bool drop_levels=false) {
    if (from_levels == to_levels) {
        return v;
    }
    // drop_levels -> if a dim appears in from_levels but not to_levels, it is assumed it has stride 0.
    at::IntArrayRef sz = v->sizes();
    at::IntArrayRef sd = v->strides();
    AT_ASSERT(drop_levels || from_levels.size() <= to_levels.size());
    Slice<int64_t> nsz;
    Slice<int64_t> nsd;
    for (auto l : to_levels) {
        auto oidx = from_levels.index(l);
        if (!oidx) {
            nsz.append(A, l.is_positional() ? 1 : l.dim()->size());
            nsd.append(A, 0);
        } else {
            auto idx = *oidx;
            nsz.append(A, sz[idx]);
            nsd.append(A, sd[idx]);
        }
    }
    return A.autorelease(v->as_strided(at::IntArrayRef(nsz.begin(), nsz.end()), at::IntArrayRef(nsd.begin(), nsd.end()), v->storage_offset()));
}
}
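
// Illustrative sketch (not part of the original file): what _match_levels
// computes. It aligns a tensor to a target level ordering purely with
// as_strided: levels missing from the source become size-1 (positional) or
// dim-sized (first-class) axes with stride 0, i.e. broadcast views, no copy.
//
//   // v has levels {i, -1} and shape [3, 4];
//   // _match_levels(A, v, {i, -1}, {j, i, -1}) yields a view of shape
//   // [j.size(), 3, 4] whose leading axis has stride 0.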
mpy::object run_torch_function(Arena &A, mpy::handle orig, mpy::vector_args args, bool is_pointwise) {
    if (!pointwise_optimize) {
        is_pointwise = false;
    }
    // std::cout << "__torch_function__ " << ((is_pointwise) ? "pointwise" : "functorch") << " " << orig << "\n";

    Slice<mpy::hdl<Dim>> all_dims;
    Slice<mpy::handle> flat_args;
    auto unflatten_args = tree_flatten(A, args, flat_args);
    TensorRef device_holding_tensor;

    Slice<TensorInfo> infos;
    Slice<DimEntry> result_levels;
    for (auto f : flat_args) {
        infos.append(A, TensorInfo::create(A, f, !is_pointwise, false));
        if (infos.back()) {
            TensorInfo& info = infos.back();
            AT_ASSERT(is_pointwise || info.batchedtensor);
            if (!device_holding_tensor && info.has_device) {
                device_holding_tensor = infos.back().tensor;
            }
            for (auto l : info.levels) {
                if (!result_levels.contains(l)) {
                    result_levels.append(A, l);
                }
            }
        }
    }

    if (is_pointwise) {
        for (auto i : flat_args.enumerate()) {
            if (infos[i]) {
                TensorRef tensor = infos[i].tensor;
                if (device_holding_tensor && !infos[i].has_device) {
                    tensor = A.autorelease(tensor->to(device_holding_tensor->device()));
                }
                auto ml = _match_levels(A, tensor, infos[i].levels, result_levels);
                flat_args[i] = handle_from_tensor(A, std::move(ml));
            }
        }

        Slice<mpy::handle> flat_it = flat_args;
        mpy::vector_args uargs = unflatten_args(A, flat_it);

        mpy::object result = orig.call_vector(uargs);

        // fast wrap for the normal case where the operator just returns a tensor.
        if (THPVariable_Check(result.ptr())) {
            return Tensor::from_positional(A, THPVariable_Unpack(result.ptr()), result_levels, device_holding_tensor);
        }
        auto wrap = [&](mpy::handle h) {
            if (THPVariable_Check(h.ptr())){
                return A.autorelease(Tensor::from_positional(A, THPVariable_Unpack(h.ptr()), result_levels, device_holding_tensor));
            }
            return h;
        };
        return tree_map(A, wrap, result);
    } else {
        // std::cout << orig << " calling functorch...\n";
        // std::cout << "rl: " << result_levels << "\n";
        EnableAllLayers guard(A, result_levels);
        for (auto i : flat_args.enumerate()) {
            if (infos[i]) {
                TensorRef batched = infos[i].batchedtensor;
                if (device_holding_tensor && !infos[i].has_device) {
                    batched = A.autorelease(batched->to(device_holding_tensor->device()));
                }
                guard.inplace_update_layers(batched, infos[i].levels);
                flat_args[i] = handle_from_tensor(A, batched);
            }
        }
        Slice<mpy::handle> flat_it = flat_args;
        mpy::vector_args uargs = unflatten_args(A, flat_it);
        AT_ASSERT(flat_it.size() == 0);
        mpy::object result = orig.call_vector(uargs);
        auto wrap = [&](mpy::handle h) {
            if (THPVariable_Check(h.ptr())) {
                return A.autorelease(guard.from_batched(A, THPVariable_Unpack(h.ptr()), device_holding_tensor));
            }
            return h;
        };
        if (THPVariable_Check(result.ptr())) {
            return guard.from_batched(A, THPVariable_Unpack(result.ptr()), device_holding_tensor);
        }
        return tree_map(A, wrap, result);
    }
}

namespace{

mpy::object __torch_function__(Arena &A, mpy::handle orig, mpy::vector_args args, bool is_pointwise) {
    if (orig == torch_Tensor___mul__) {
        AT_ASSERT(args.nargs == 2 && !args.has_keywords());
        auto lhs = args[0];
        auto rhs = args[1];
        if (mpy::isinstance(lhs, _Tensor) && mpy::isinstance(rhs, _Tensor) && _Tensor_ndim(lhs) == 0 && _Tensor_ndim(rhs) == 0) {
            bool has_device = false;
            Slice<DimEntry> levels;
            for (auto i : args.enumerate_positional()) {
                auto t = TensorInfo::create(A, args[i], false);
                // something like a mask * rhs, which matrix multiplies don't correctly promote
                if (!t.tensor->is_floating_point()) {
                    return run_torch_function(A, orig, args, is_pointwise);
                }
                has_device = has_device || t.has_device;
                for (auto l : t.levels) {
                    if (!levels.contains(l)) {
                        levels.append(A, l);
                    }
                }
            }
            // std::cout << "__torch_function__ " << "delay" << " " << orig << "\n";
            return Tensor::create_delayed(mpy::object::borrow(orig), args, levels, has_device);
        }
    }
    return run_torch_function(A, orig, args, is_pointwise);
}
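
// Illustrative sketch (not part of the original file): why the __mul__ case
// above is special. Multiplying two 0-d first-class tensors is deferred via
// create_delayed, so a later consumer (e.g. a sum over a dim) can evaluate
// the product lazily rather than materializing the full broadcasted
// intermediate. Roughly, in Python:
//
//   # (x * w).sum(i)  -- the product is recorded as a DelayedOperator and
//   # only forced when .sum (or another op) needs the actual tensor.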

mpy::vector_args as_vector_args(Arena& A, mpy::handle args, mpy::handle kwargs) {
    auto pos_args = (mpy::handle*) &PyTuple_GET_ITEM(args.ptr(), 0);
    auto pos_n = PyTuple_GET_SIZE(args.ptr());
    if (!kwargs.ptr()) {
        return mpy::vector_args(pos_args, pos_n, nullptr);
    }
    Slice<mpy::handle> all_args;
    Slice<mpy::handle> kwnames;
    all_args.extend(A, pos_args, pos_args + pos_n);
    mpy::dict_view dv(kwargs);
    Py_ssize_t pos = 0;
    mpy::handle key, value;
    while (dv.next(&pos, &key, &value)) {
        all_args.append(A, value);
        kwnames.append(A, key);
    }
    return mpy::vector_args(all_args.begin(), pos_n, A.autorelease(slice_to_tuple(kwnames)));
}

PyObject* py___torch_function__(PyObject *self,
                      PyObject *const *args,
                      Py_ssize_t nargs,
                      PyObject *kwnames) {
    Arena A;
    PY_BEGIN
    maybeInitializeGlobals();
    AT_ASSERT(nargs == 4 || nargs == 5);
    auto va = as_vector_args(A, args[3], nargs == 5 ? args[4] : nullptr);
    bool is_pointwise = pointwise.contains(args[1]);
    return __torch_function__(A, args[1], std::move(va), is_pointwise).release();
    PY_END(nullptr)
}

mpy::object levels_to_tuple(Slice<DimEntry> slice) {
    mpy::tuple t(slice.size());
    for (auto i : slice.enumerate()) {
        t.set(i, slice[i].is_positional() ?  mpy::from_int(slice[i].position()) : mpy::object::borrow(slice[i].dim()));
    }
    mpy::object r = std::move(t);
    return r;
}

PyObject* Tensor_ndim(Tensor* self, void*) {
    Py_ssize_t i = 0;
    for (auto l : self->levels()) {
        if (l.is_positional()) {
            ++i;
        }
    }
    return mpy::from_int(i).release();
}

PyGetSetDef Tensor_getsetters[] = {
   {"_has_device", (getter) [](PyObject* self, void*) -> PyObject* { return mpy::from_bool(((Tensor*)self)->has_device()).release(); }, NULL},
   {"_tensor", (getter) [](PyObject* self, void*) -> PyObject* {
       Arena A;
       return THPVariable_Wrap(((Tensor*)self)->tensor(A)); }, NULL},
   {"_batchtensor", (getter) [](PyObject* self, void*) -> PyObject* {
       Arena A;
       return THPVariable_Wrap(((Tensor*)self)->batchtensor(A)); }, NULL},
   {"_levels", (getter) [](PyObject* self, void*) -> PyObject* {
       PY_BEGIN
       return levels_to_tuple(((Tensor*)self)->levels()).release();
       PY_END(nullptr)
   }},
    {"ndim", (getter) Tensor_ndim, NULL, "ndim", NULL},
    {NULL}  /* Sentinel */
};

PyMethodDef Tensor_methods[] = {
    {NULL, NULL, 0, NULL}        /* Sentinel */
};
}


PyTypeObject Tensor::Type = {
    PyVarObject_HEAD_INIT(NULL, 0)
    "_C.Tensor",                    /* tp_name */
    sizeof(Tensor),                 /* tp_basicsize */
    0,                              /* tp_itemsize */
    Tensor::dealloc_stub,           /* tp_dealloc */
    0,                              /* tp_vectorcall_offset */
    0,                              /* tp_getattr */
    0,                              /* tp_setattr */
    0,                              /* tp_as_async */
    0,                              /* tp_repr */
    0,                              /* tp_as_number */
    0,                              /* tp_as_sequence */
    0,                              /* tp_as_mapping */
    0,                              /* tp_hash */
    0,                              /* tp_call */
    0,                              /* tp_str */
    0,                              /* tp_getattro */
    0,                              /* tp_setattro */
    0,                              /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,  /* tp_flags */
    "Tensor Object",                /* tp_doc */
    0,                              /* tp_traverse */
    0,                              /* tp_clear */
    0,                              /* tp_richcompare */
    0,                              /* tp_weaklistoffset */
    0,                              /* tp_iter */
    0,                              /* tp_iternext */
    Tensor_methods,                 /* tp_methods */
    0,                              /* tp_members */
    Tensor_getsetters,              /* tp_getset */
    0,                              /* tp_base */
    0,                              /* tp_dict */
    0,                              /* tp_descr_get */
    0,                              /* tp_descr_set */
    0,                              /* tp_dictoffset */
    0,                              /* tp_init */
    0,                              /* tp_alloc */
    Tensor::new_stub,               /* tp_new */
};

1457 
1458 // dim() --------------------
1459 
1460 static bool relevant_op(_Py_CODEUNIT c) {
1461     switch(c) {
1462         case STORE_NAME:
1463         case STORE_GLOBAL:
1464         case STORE_FAST:
1465         case STORE_DEREF:
1466             return true;
1467         default:
1468             return false;
1469     }
1470 }
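// A hedged illustration of why these opcodes matter: a statement such as
//     x = dims()
// compiles to a call instruction followed by STORE_FAST ('x'), so inspecting
// the store instruction that follows the call lets dims() recover the
// variable name "x" for the new dimension.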
1471 
1472 static mpy::object create_dim(mpy::object name, mpy::handle size) {
1473     auto d = Dim::create(std::move(name));
1474     if (!mpy::is_none(size)) {
1475         d->set_size(mpy::to_int(size));
1476     }
1477     return std::move(d);
1478 }
1479 
1480 static mpy::object create_dimlist(mpy::object name, mpy::handle size) {
1481     auto d = DimList::create(std::move(name));
1482     if (!mpy::is_none(size)) {
1483         if (mpy::is_int(size)) {
1484             d->bind_len(mpy::to_int(size));
1485         } else {
1486             mpy::sequence_view s(size);
1487             d->bind_len(s.size());
1488             for (auto i : irange(d->size())) {
1489                 d->dims_[i]->set_size(mpy::to_int(s[i]));
1490             }
1491         }
1492     }
1493     return std::move(d);
1494 }
1495 
1496 
1497 
1498 // Python wrappers that make new reflection primitives available for older runtimes
1499 #if !(IS_PYTHON_3_11_PLUS)
1500 #define _PyCode_CODE(CO) ((_Py_CODEUNIT*)PyBytes_AS_STRING((CO)->co_code))
1501 #endif
1502 
1503 namespace{
1504 struct PyInstDecoder {
1505     PyInstDecoder(PyCodeObject* code_object, int lasti)
1506     : code_object_(code_object), code_(_PyCode_CODE(code_object)), offset_(lasti / sizeof(_Py_CODEUNIT))  {}
1507     // On Windows, _PyOpcode_Caches and _PyOpcode_Deopt are private symbols
1508     // See https://github.com/pytorch/pytorch/issues/93854
1509     void next() {
1510     #if IS_PYTHON_3_11_PLUS
1511         offset_ += _PyOpcode_Caches[opcode()];
1512     #endif
1513         offset_ += 1;
1514     }
1515     int opcode() {
1516         auto r = _Py_OPCODE(code_[offset_]);
1517     #if IS_PYTHON_3_11_PLUS
1518         r = _PyOpcode_Deopt[r];
1519     #endif
1520         return r;
1521     }
1522     int oparg() {
1523         return _Py_OPARG(code_[offset_]);
1524     }
1525 
1526     mpy::object name() {
1527         mpy::object names;
1528         switch(opcode()) {
1529             case STORE_NAME:
1530             case STORE_GLOBAL:
1531                 names = mpy::object::borrow(code_object_->co_names);
1532                 break;
1533             case STORE_FAST:
1534                 names = mpy::object::steal(PyCode_GetVarnames(code_object_));
1535                 break;
1536             case STORE_DEREF:
1537                 names = mpy::object::steal(PyCode_GetCellvars(code_object_));
1538                 break;
1539             default:
1540                 return mpy::object();
1541         }
1542         return mpy::object::steal(PySequence_GetItem(names.ptr(), oparg()));
1543     }
1544 private:
1545     PyCodeObject* code_object_;
1546     _Py_CODEUNIT* code_;
1547     int offset_;
1548 };
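// Sketch of how the decoder walks 3.11+ bytecode (based on the CPython
// internals included above): each instruction may be followed by
// _PyOpcode_Caches[op] inline-cache units, so next() skips the caches before
// stepping to the following real instruction, and opcode() maps any
// specialized (adaptive) opcode back to its base form via _PyOpcode_Deopt.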
1549 
1550 template<mpy::object (*create_object)(mpy::object, mpy::handle)>
1551 static PyObject* _dims(PyObject *self,
1552                       PyObject *const *args,
1553                       Py_ssize_t nargs,
1554                       PyObject *kwnames) {
1555     PY_BEGIN
1556     Py_ssize_t specified_ndims = -1;
1557     Py_ssize_t found_ndims = 0;
1558     Py_ssize_t sizes = -1;
1559     mpy::handle n = Py_None;
1560     mpy::handle py_sizes = Py_None;
1561 
1562     if (nargs || kwnames) {
1563         mpy::vector_args va(args, nargs, kwnames);
1564         va.parse("dims", {"n", "sizes"}, {&n, &py_sizes}, 0);
1565         if (!mpy::is_none(py_sizes)) {
1566             sizes = mpy::sequence_view(py_sizes).size();
1567             specified_ndims = sizes;
1568         }
1569         if (!mpy::is_none(n)) {
1570             specified_ndims = mpy::to_int(n);
1571         }
1572     }
1573 
1574     PyThreadState* state = PyThreadState_GET();
1575     auto f = mpy::obj<PyFrameObject>::steal(PyThreadState_GetFrame(state));
1576     auto c = mpy::obj<PyCodeObject>::steal(PyFrame_GetCode(f.ptr()));
1577     auto lasti = PyFrame_GetLasti(f.ptr());
1578     auto decoder = PyInstDecoder(c.ptr(), lasti);
1579     #if IS_PYTHON_3_11_PLUS
1580     // When py3.11's adaptive interpreter has specialized the bytecode, lasti
1581     // points to the PRECALL instruction rather than the CALL after it
1582     if (decoder.opcode() == PRECALL) {
1583         decoder.next();
1584     }
1585     #endif
1586     decoder.next();
1587 
1588     if (relevant_op(decoder.opcode())) {
1589         found_ndims = 1;
1590     } else if (decoder.opcode() == UNPACK_SEQUENCE) {
1591         found_ndims = decoder.oparg();
1592         decoder.next();
1593     }
1594 
1595     if (specified_ndims == -1) {
1596         if (found_ndims == 0) {
1597             mpy::raise_error(PyExc_SyntaxError, "dims() must be assigned to a sequence of variable names or have argument n specified");
1598         }
1599         specified_ndims = found_ndims;
1600     }
1601     if (found_ndims != specified_ndims) {
1602         found_ndims = 0; // avoid taking the wrong names for dimensions
1603     }
1604 
1605     auto genobject = [&](int i) -> mpy::object {
1606         mpy::object name;
1607         if (i < found_ndims) {
1608             name = decoder.name();
1609         }
1610         if (!name.ptr()) {
1611             name = mpy::unicode_from_format("d%d", i);
1612             found_ndims = 0; // once we fail to find a name, we can't find any more
1613         } else {
1614             decoder.next();
1615         }
1616         return create_object(std::move(name), sizes != -1 ? mpy::sequence_view(py_sizes)[i] : mpy::handle(Py_None));
1617     };
1618     if (sizes != -1 && sizes != specified_ndims) {
1619         mpy::raise_error(PyExc_ValueError, "expected %d sizes but found %d", int(specified_ndims), int(sizes));
1620     }
1621     if (specified_ndims == 1) {
1622         return genobject(0).release();
1623     }
1624     mpy::tuple result(specified_ndims);
1625     for (int i = 0; i < specified_ndims; ++i) {
1626         result.set(i, genobject(i));
1627     }
1628     return result.release();
1629     PY_END(nullptr)
1630 }
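// Hypothetical Python-level usage of the machinery above (names are recovered
// from the STORE/UNPACK_SEQUENCE instructions at the call site):
//     from functorch.dim import dims
//     batch = dims()            # one dim, named "batch" via STORE_FAST
//     i, j = dims(2)            # two dims, names read after UNPACK_SEQUENCE
//     k = dims(sizes=[3])       # one dim bound to size 3 at creation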
1631 
1632 struct DotPart {
1633     Slice<DimEntry> dims;
1634     size_t total_size = 1;
1635     void append(Arena& A, mpy::hdl<Dim> d) {
1636         total_size *= d->size();
1637         dims.append(A, d);
1638     }
1639 };
1640 
1641 template<typename T>
1642 static at::ArrayRef<T> as_array_ref(Slice<T> t) {
1643     return at::ArrayRef<T>(t.begin(), t.end());
1644 }
1645 
1646 static TensorRef dot_prepare(Arena& A, std::initializer_list<DotPart> parts, const TensorInfo& t) {
1647     Slice<DimEntry> new_levels;
1648     bool needs_reshape = false;
1649     for (auto p : parts) {
1650         if (p.dims.size() != 1) {
1651             needs_reshape = true;
1652         }
1653         new_levels.extend(A, p.dims);
1654     }
1655     auto r = _match_levels(A, t.tensor, t.levels, new_levels, true);
1656     if (!needs_reshape) {
1657         return r;
1658     }
1659     Slice<int64_t> view;
1660     for (auto p : parts) {
1661         view.append(A, p.total_size);
1662     }
1663     return A.autorelease(r->reshape(at::IntArrayRef(view.begin(), view.end())));
1664 }
1665 
1666 static mpy::object dot_finish(Arena& A, std::initializer_list<DotPart> parts, at::Tensor r) {
1667     Slice<DimEntry> result_levels;
1668     bool needs_reshape = false;
1669     for (auto p : parts) {
1670         if (p.dims.size() != 1) {
1671             needs_reshape = true;
1672         }
1673         result_levels.extend(A, p.dims);
1674     }
1675     if (needs_reshape) {
1676         Slice<int64_t> new_size;
1677         for (auto l : result_levels) {
1678             new_size.append(A, l.dim()->size());
1679         }
1680         r = r.reshape(at::IntArrayRef(new_size.begin(), new_size.end()));
1681     }
1682     return Tensor::from_positional(A, std::move(r), result_levels, true);
1683 }
1684 
1685 
1686 
1687 static mpy::object dot(Arena& A, TensorInfo lhs, TensorInfo rhs, Slice<DimEntry> sum) {
1688     auto lhs_strides = lhs.tensor->strides();
1689     auto rhs_strides = rhs.tensor->strides();
1690 
1691     DotPart lro_dims;
1692     DotPart lo_dims;
1693     DotPart ro_dims;
1694     DotPart lr_dims;
1695 
1696     auto insert_dim = [&] (mpy::hdl<Dim> d, std::optional<int> lhs_idx, std::optional<int> rhs_idx) {
1697         bool reduced = sum.contains(d);
1698         int64_t lhs_stride = lhs_idx ? lhs_strides[*lhs_idx] : 0;
1699         int64_t rhs_stride = rhs_idx ? rhs_strides[*rhs_idx] : 0;
1700         if (reduced) {
1701             // lr
1702             lr_dims.append(A, d);
1703         } else {
1704             if ((lhs_stride == 0) == (rhs_stride == 0)) {
1705                 // lro
1706                 lro_dims.append(A, d);
1707             } else if (lhs_stride != 0) {
1708                 // lo
1709                 lo_dims.append(A, d);
1710             } else {
1711                 AT_ASSERT(rhs_stride != 0);
1712                 ro_dims.append(A, d);
1713             }
1714         }
1715     };
1716 
1717 
1718     auto rhs_seen = A.allocate<bool>(rhs.levels.size());
1719     std::fill(rhs_seen, rhs_seen + rhs.levels.size(), false);
1720 
1721     for (auto i : lhs.levels.enumerate()) {
1722         auto d = lhs.levels[i];
1723         auto rhs_idx = rhs.levels.index(d);
1724         if (rhs_idx) {
1725             rhs_seen[*rhs_idx] = true;
1726         }
1727         insert_dim(d.dim(), i, rhs_idx);
1728     }
1729 
1730     for (auto i : rhs.levels.enumerate()) {
1731         if (rhs_seen[i]) {
1732             continue;
1733         }
1734         auto d = rhs.levels[i];
1735         insert_dim(d.dim(), std::nullopt, i);
1736     }
1737 
1738     if (lr_dims.dims.size() != sum.size()) {
1739         for (auto & d : sum) {
1740             if (!lhs.levels.contains(d) && !rhs.levels.contains(d)) {
1741                 mpy::raise_error(DimensionBindError(), "summing over non-existent dimension %S", d.dim().ptr());
1742             }
1743         }
1744     }
1745 
1746     // std::cout << lhs.levels << " " << rhs.levels << " " << sum << "\n";
1747     // std::cout << lro_dims.dims << " " << lo_dims.dims << " " << ro_dims.dims << " " << lr_dims.dims << "\n";
1748 
1749     // with shared batch (lro) dims use bmm, otherwise plain mm
1750     if (lro_dims.dims.size() != 0) {
1751         auto lhs_ = dot_prepare(A, {lro_dims, lo_dims, lr_dims}, lhs);
1752         auto rhs_ = dot_prepare(A, {lro_dims, lr_dims, ro_dims}, rhs);
1753         return dot_finish(A, {lro_dims, lo_dims, ro_dims}, at::bmm(*lhs_, *rhs_));
1754     } else {
1755         auto lhs_ = dot_prepare(A, {lo_dims, lr_dims}, lhs);
1756         auto rhs_ = dot_prepare(A, {lr_dims, ro_dims}, rhs);
1757         return dot_finish(A, {lo_dims, ro_dims}, at::mm(*lhs_, *rhs_));
1758     }
1759 
1760 }
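// A worked sketch of the classification above: for C[b, i, k] =
// sum_j A[b, i, j] * B[b, j, k], the shared batch dim gives lro={b}, the
// lhs-only dim lo={i}, the rhs-only dim ro={k}, and the reduced dim lr={j}.
// dot_prepare lays lhs out as [b, i, j] and rhs as [b, j, k], at::bmm yields
// [b, i, k], and dot_finish relabels that result with levels (b, i, k).
// With no shared batch dims, the same pattern degenerates to at::mm.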
1761 
1762 static PyObject* test_c(PyObject *self,
1763                       PyObject *const *args,
1764                       Py_ssize_t nargs,
1765                       PyObject *kwnames) {
1766     PY_BEGIN
1767 
1768     Arena A;
1769     Slice<int> s(A, 3, 4, 5);
1770     AT_ASSERT(s.size() == 3 && s.capacity() == 8);
1771     AT_ASSERT(s[0] == 3 && s[1] == 4 && s[2] == 5);
1772     s.append(A, 6);
1773     AT_ASSERT(s[3] == 6);
1774     for(int i : irange(10)) {
1775         s.append(A, i);
1776     }
1777     AT_ASSERT(s[0] == 3 && s.back() == 9 && s.size() == 14 && s.capacity() == 16);
1778 
1779     Slice<int> s2(A, -1, -2, -3);
1780     AT_ASSERT(s2[1] == -2 && s[0] == 3);
1781 
1782     auto ss = s.slice(1,2);
1783     AT_ASSERT(ss.size() == 1);
1784     AT_ASSERT(ss[0] == 4);
1785     AT_ASSERT(ss.capacity() == 1);
1786     ss.append(A, -4);
1787     AT_ASSERT(ss.size() == 2 && ss[1] == -4);
1788     ss[0] = 3;
1789     AT_ASSERT(s[1] == 4);
1790 
1791     s.insert(A, s.slice(1, 4), ss);
1792     AT_ASSERT(s[1] == 3  && s[2] == -4 && s[3] == 0);
1793 
1794     auto sz = s.size();
1795     s.insert(A, s.slice(1, 1), 4);
1796     AT_ASSERT(s[1] == 4 && sz + 1 == s.size());
1797 
1798 
1799     Slice<int> d(A, 0, 1, 2, 3, 4);
1800 
1801     Slice<int> b(A, 0, 1, 2, 3, 4);
1802     b.insert(A, b.slice(1,1), d);
1803     AT_ASSERT(b.size() == 10);
1804     AT_ASSERT(b[1] == 0);
1805     AT_ASSERT(b[5] == 4);
1806     AT_ASSERT(b.back() == 4);
1807 
1808     Py_RETURN_NONE;
1809 
1810     PY_END(nullptr);
1811 }
1812 
1813 
1814 static PyObject* order(PyObject *_,
1815                       PyObject *const *args,
1816                       Py_ssize_t nargs,
1817                       PyObject *kwnames) {
1818     Arena A;
1819     PY_BEGIN
1820     if (kwnames) {
1821         mpy::raise_error(PyExc_TypeError, "unexpected keyword arguments %S", kwnames);
1822     }
1823     AT_ASSERT(nargs-- > 0);
1824     Slice<DimEntry> orig_levels;
1825     Slice<DimEntry> levels;
1826     TensorRef data;
1827     mpy::handle self = args++[0];
1828     bool has_device;
1829     if (Tensor::check_exact(self)) {
1830         auto t = Tensor::unchecked_wrap(self);
1831         orig_levels = t->levels();
1832         data = t->tensor(A);
1833         has_device = t->has_device();
1834     } else {
1835         auto d = Dim::unchecked_wrap(self);
1836         orig_levels.append(A, d);
1837         data = d->range();
1838         has_device = false;
1839     }
1840 
1841     Slice<DimEntry> flat_positional_dims;
1842     Slice<std::pair<int, int>> to_flatten;
1843     levels.extend(A, orig_levels);
1844 
1845     int orig_ndim = ndim_of_levels(levels);
1846     auto append = [&](DimEntry d) {
1847         auto midx = levels.index(d);
1848         if (!midx) {
1849             if (d.is_positional()) {
1850                 mpy::raise_error(PyExc_ValueError, "tensor has %d positional dimensions, but %d specified, or it was specified twice", int(orig_ndim), int(d.position() + orig_ndim));
1851             } else {
1852                 mpy::raise_error(PyExc_ValueError, "tensor of dimensions %R does not contain dim %R or it was specified twice", levels_to_tuple(orig_levels).ptr(), d.dim().ptr());
1853             }
1854         }
1855         levels[*midx] = DimEntry();
1856         flat_positional_dims.append(A, d);
1857     };
1858 
1859     int n_new_positional = 0;
1860     for (auto i : irange(nargs)) {
1861         mpy::handle arg = args[i];
1862         DimEntry entry = _wrap_dim(arg, orig_ndim, false);
1863         if (!entry.is_none()) {
1864             append(entry);
1865             ++n_new_positional;
1866         } else if (DimList::check(arg)) {
1867             auto dl = DimList::unchecked_wrap(arg);
1868             for (mpy::obj<Dim> & d : dl->dims_) {
1869                 append(mpy::hdl<Dim>(d));
1870                 ++n_new_positional;
1871             }
1872         } else {
1873             ++n_new_positional;
1874             if (!mpy::is_sequence(arg)) {
1875                 mpy::raise_error(PyExc_ValueError, "expected a Dim, List[Dim], or Sequence[Dim]");
1876             }
1877             mpy::sequence_view sq(arg);
1878             auto N = sq.size();
1879             to_flatten.append(A, std::make_pair(flat_positional_dims.size(), N));
1880             for (auto j : irange(N)) {
1881                 DimEntry e = _wrap_dim(A.autorelease(sq[j]), orig_ndim, false);
1882                 if (e.is_none()) {
1883                     mpy::raise_error(PyExc_ValueError, "expected a Dim, or int");
1884                 }
1885                 append(e);
1886             }
1887         }
1888     }
1889 
1890     int ndim = 0;
1891     int insert_point = -1;
1892     Slice<DimEntry> new_levels;
1893     for (auto l : levels) {
1894         if (l.is_none()) {
1895             continue;
1896         }
1897         if (l.is_positional()) {
1898             ndim++;
1899             if (insert_point == -1) {
1900                 insert_point = new_levels.size();
1901                 new_levels.extend(A, flat_positional_dims);
1902             }
1903         }
1904         new_levels.append(A, l);
1905     }
1906     if (insert_point == -1) {
1907         insert_point = new_levels.size();
1908         new_levels.extend(A, flat_positional_dims);
1909     }
1910 
1911     at::Tensor ndata = *_match_levels(A, data, orig_levels, new_levels);
1912 
1913     if (to_flatten.size()) {
1914         Slice<int64_t> view;
1915         auto sz = ndata.sizes();
1916         // before the new positional dims
1917         for (auto i : irange(0, insert_point)) {
1918             view.append(A, sz[i]);
1919         }
1920         int i = 0;
1921         for (auto to_flat : to_flatten) {
1922             for (;i < to_flat.first; ++i) {
1923                 view.append(A, sz[insert_point + i]);
1924             }
1925             int64_t new_size = 1;
1926             int last = i + to_flat.second;
1927             for (; i < last; ++i) {
1928                 new_size *= sz[insert_point + i];
1929             }
1930             view.append(A, new_size);
1931         }
1932         for (; i < flat_positional_dims.size(); ++i) {
1933             view.append(A, sz[insert_point + i]);
1934         }
1935         // after the new positional dims
1936         for (auto i : irange(insert_point + flat_positional_dims.size(), levels.size())) {
1937             view.append(A, sz[i]);
1938         }
1939         // we shortened the number of dimensions, so remove the extras from new_levels;
1940         // we will renumber them later
1941         auto n_to_remove = flat_positional_dims.size() - n_new_positional;
1942         new_levels.insert(A, new_levels.slice(insert_point, insert_point + n_to_remove), Slice<DimEntry>());
1943         ndata = std::move(ndata).reshape(at::IntArrayRef(view.begin(), view.end()));
1944     }
1945 
1946     // renumber the positional dimensions
1947     int seen = 0;
1948     for (auto i : new_levels.reversed_enumerate()) {
1949         if (new_levels[i].is_positional() || (i >= insert_point && i < insert_point + n_new_positional)) {
1950             new_levels[i] = --seen;
1951         }
1952     }
1953     return Tensor::from_positional(A, std::move(ndata), new_levels, has_device).release();
1954 
1955     PY_END(nullptr)
1956 }
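// Hypothetical usage of order() from Python (per the functorch.dim API this
// implements), assuming dims i and j are bound on t:
//     t.order(i, j)      # i and j become the leading positional dimensions
//     t.order((i, j))    # i and j are flattened into one positional dimension
// The second, grouped form is what drives the to_flatten/reshape path above.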
1957 
1958 static PyObject* expand(PyObject *_,
1959                       PyObject *const *args,
1960                       Py_ssize_t nargs,
1961                       PyObject *kwnames) {
1962     Arena A;
1963     PY_BEGIN
1964     AT_ASSERT(nargs-- > 0);
1965     auto info = TensorInfo::create(A, args++[0], false);
1966     for (auto i : irange(nargs)) {
1967         if (!Dim::check(args[i])) {
1968             maybeInitializeGlobals();
1969             mpy::vector_args vargs(args - 1, nargs + 1, kwnames);
1970             if (THPVariable_Check(args[-1])) {
1971                 return torch_Tensor_expand.call_vector(vargs).release();
1972             } else {
1973                 return __torch_function__(A, torch_Tensor_expand, vargs, false).release();
1974             }
1975         }
1976     }
1977     const at::Tensor& data = *info.tensor;
1978     auto levels = info.levels;
1979     Slice<DimEntry> new_levels;
1980     Slice<int64_t> sz;
1981     Slice<int64_t> sd;
1982     for (auto i : irange(nargs)) {
1983         auto d = Dim::unchecked_wrap(args[i]);
1984         if (levels.contains(d) || new_levels.contains(d)) {
1985             mpy::raise_error(DimensionBindError(), "expanding dimension %R already exists in tensor with dims", d.ptr());
1986         }
1987         new_levels.append(A, d);
1988         sz.append(A, d->size());
1989         sd.append(A, 0);
1990     }
1991     new_levels.extend(A, levels);
1992     at::IntArrayRef osz = data.sizes();
1993     at::IntArrayRef osd = data.strides();
1994     sz.extend(A, osz.begin(), osz.end());
1995     sd.extend(A, osd.begin(), osd.end());
1996     at::Tensor ndata = data.as_strided(at::IntArrayRef(sz.begin(), sz.end()), at::IntArrayRef(sd.begin(), sd.end()), data.storage_offset());
1997     return Tensor::from_positional(A, std::move(ndata), new_levels, info.has_device).release();
1998     PY_END(nullptr)
1999 }
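// Sketch of the zero-stride expansion above: t.expand(b) for a first-class
// dim b of size 4 prepends a new level for b whose stride is 0, so no data is
// copied. If any argument is not a Dim, the call instead falls back to the
// ordinary torch.Tensor.expand (possibly through __torch_function__).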
2000 
2001 
2002 static void _bind_dims_to_size(Arena & A, int64_t sz, int64_t sd,
2003                         Slice<mpy::hdl<Dim>> dims, Slice<int64_t>& nsz, Slice<int64_t>& nsd) {
2004     int64_t rhs_prod = 1;
2005     for (auto i : dims.enumerate()) {
2006         if (!dims[i]->is_bound()) {
2007             for (auto j : irange(i + 1, dims.size())) {
2008                 if (!dims[j]->is_bound()) {
2009                     mpy::raise_error(DimensionBindError(), "cannot infer the sizes of two dimensions at once %R and %R", dims[i].ptr(), dims[j].ptr());
2010                 }
2011                 rhs_prod *= dims[j]->size();
2012             }
2013             if (sz % rhs_prod != 0) {
2014                 mpy::tuple tup(dims.size());
2015                 for (auto j : dims.enumerate()) {
2016                     tup.set(j, dims[j]->is_bound() ? mpy::from_int(dims[j]->size()) : mpy::unicode_from_string("?"));
2017                 }
2018                 mpy::raise_error(DimensionBindError(), "inferred dimension does not evenly fit into larger dimension: %d vs %R", (int) sz, tup.ptr());
2019             }
2020             int64_t inferred_size = sz / rhs_prod;
2021             dims[i]->set_size(inferred_size);
2022             rhs_prod = sz;
2023             break;
2024         }
2025         rhs_prod *= dims[i]->size();
2026     }
2027     if (rhs_prod != sz) {
2028         mpy::tuple tup(dims.size());
2029         for (auto j : dims.enumerate()) {
2030             tup.set(j, mpy::object::borrow(dims[j]));
2031         }
2032         mpy::raise_error(DimensionBindError(), "Dimension sizes do not match (%d != %d) when matching dimension pack %R", (int) sz, (int) rhs_prod, tup.ptr());
2033     }
2034     auto new_strides = A.allocate<int64_t>(dims.size());
2035     auto prev_stride = sd;
2036     for (auto i : dims.reversed_enumerate()) {
2037         new_strides[i] = prev_stride;
2038         prev_stride = dims[i]->size()*prev_stride;
2039     }
2040     for (auto i : dims.enumerate()) {
2041         nsd.append(A, new_strides[i]);
2042         nsz.append(A, dims[i]->size());
2043     }
2044 }
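// Example of the inference rule above: binding the dim pack (a, b) against a
// dimension of size 6 when b is already bound to 3 infers a.size() == 2.
// Two unbound dims in one pack, or a size that the bound product does not
// divide evenly, raises a DimensionBindError instead.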
2045 
2046 static bool has_dims(mpy::handle d) {
2047     return Dim::check_exact(d) || Tensor::check_exact(d);
2048 }
2049 
2050 struct IndexingInfo {
2051     bool can_call_original; // if true, it is safe to just call getitem or setitem; these objects do not need special handling
2052     bool advanced_indexing; // requires actual lookup
2053     TensorRef self;
2054     Slice<mpy::handle> flat_inputs;
2055     Slice<DimEntry> result_levels;
2056     bool has_device;
2057 };
2058 }
2059 
2060 IndexingInfo getsetitem_flat(Arena& A, TensorInfo self_info, Slice<mpy::handle> input, Slice<DimEntry> keys, Slice<mpy::handle> values, bool has_dimpacks_or_none);
2061 namespace{
2062 Slice<mpy::handle> as_slice(mpy::tuple_view tv) {
2063     PyObject** begin = &PyTuple_GET_ITEM(tv.ptr(),0);
2064     return Slice<mpy::handle>((mpy::handle*)begin, (mpy::handle*) (begin + tv.size()));
2065 }
2066 
2067 Slice<mpy::handle> as_slice(mpy::list_view tv) {
2068     PyObject** begin = &PyList_GET_ITEM(tv.ptr(),0);
2069     return Slice<mpy::handle>((mpy::handle*)begin, (mpy::handle*) (begin + tv.size()));
2070 }
2071 
2072 
2073 bool maybe_dimpack(Slice<mpy::handle>& elements, mpy::handle s, bool check_first=true) {
2074     // can we avoid rechecking?
2075     if (mpy::list_view::check(s)) {
2076         mpy::list_view tv(s);
2077         if (!check_first || (tv.size() && Dim::check_exact(tv[0]))) {
2078             elements = as_slice(tv);
2079             return true;
2080         }
2081     }
2082     // can we avoid rechecking?
2083     if (mpy::tuple_view::check(s)) {
2084         mpy::tuple_view tv(s);
2085         if (!check_first || (tv.size() && Dim::check_exact(tv[0]))) {
2086             elements = as_slice(tv);
2087             return true;
2088         }
2089     }
2090     return false;
2091 };
2092 
2093 bool is_dimpack(mpy::handle s) {
2094     Slice<mpy::handle> e;
2095     return maybe_dimpack(e, s);
2096 }
2097 
2098 mpy::object invoke_getitem(Arena& A, const IndexingInfo& iinfo) {
2099     at::Tensor rtensor;
2100     if (iinfo.advanced_indexing) {
2101         auto self_hdl = handle_from_tensor(A, iinfo.self);
2102         auto tup = slice_to_tuple(iinfo.flat_inputs);
2103         // std::cout << "calling original getindex " << self_hdl << " " << tup << "\n";
2104         auto pytensor = mpy::object::checked_steal(THPVariable_getitem(self_hdl.ptr(), tup.ptr()));
2105         rtensor = THPVariable_Unpack(pytensor.ptr());
2106     } else {
2107         // std::cout << "skipping original getindex\n";
2108         rtensor = *iinfo.self;
2109     }
2110     // std::cout << "returning (from_positional)\n";
2111     return Tensor::from_positional(A, std::move(rtensor), iinfo.result_levels, iinfo.has_device);
2112 }
2113 
2114 mpy::object index(Arena& A, mpy::handle self, mpy::handle dims, mpy::handle indices) {
2115     maybeInitializeGlobals();
2116     Slice<mpy::handle> dims_list;
2117     Slice<mpy::handle> indices_list;
2118     // we allow for matching single dims to multiple dims,
2119     // so we first normalize everything into the case where there is a list on both the lhs and the rhs
2120     bool lhs_list = mpy::tuple_view::check(dims) || mpy::list_view::check(dims);
2121     bool rhs_list = mpy::tuple_view::check(indices) || mpy::list_view::check(indices);
2122     if (lhs_list && rhs_list) {
2123         mpy::sequence_view dv(dims);
2124         mpy::sequence_view ind(indices);
2125         Py_ssize_t N = dv.size();
2126         if (N != ind.size()) {
2127             mpy::raise_error(PyExc_TypeError, "dims (%d) and indices (%d) must have the same length", int(N), int(ind.size()));
2128         }
2129         for (auto i : irange(N)) {
2130             dims_list.append(A, A.autorelease(dv[i]));
2131             indices_list.append(A, A.autorelease(ind[i]));
2132         }
2133     } else {
2134         dims_list.append(A, dims);
2135         indices_list.append(A, indices);
2136     }
2137 
2138     // dims being indexed can be grouped together into a single index space, and we have to
2139     // flatten them into a single dimension before we can index them...
2140     auto self_info = TensorInfo::create(A, self, false);
2141     auto ndim = self_info.ndim();
2142     Slice<DimEntry> new_levels;
2143     Slice<DimEntry> to_flatten;
2144     Slice<DimEntry> dims_list_flat;
2145     auto parse_dim_entry = [&](mpy::handle s) -> DimEntry {
2146         auto d = _wrap_dim(s, ndim, false);
2147         if (d.is_none()) {
2148             mpy::raise_error(PyExc_TypeError, "expected a dimension specifier but found %R", s.ptr());
2149         }
2150         return d;
2151     };
2152     auto dim_not_present = [&](DimEntry d) {
2153         if (d.is_positional()) {
2154             mpy::raise_error(PyExc_TypeError, "dimension %d not in tensor of %d dimensions", d.position() + ndim , ndim);
2155         } else {
2156             mpy::raise_error(PyExc_TypeError, "dimension %R not in tensor", d.dim()->ptr());
2157         }
2158     };
2159 
2160     for (auto i : dims_list.enumerate()) {
2161         Slice<mpy::handle> m;
2162         if (maybe_dimpack(m, dims_list[i], /*check_first=*/false)) {
2163             if (m.size() == 0) {
2164                 // plausible semantics exist for a pack with 0 elements (e.g. the index would always be 0)
2165                 dims_list_flat.append(A, DimEntry()); continue; // value is just dropped; without the continue, m[0] below would read out of bounds
2166             }
2167             auto first = parse_dim_entry(m[0]);
2168             dims_list_flat.append(A, first);
2169             if (m.size() == 1) {
2170                 continue;
2171             }
2172             if (to_flatten.size() == 0) {
2173                 new_levels.extend(A, self_info.levels);
2174             }
2175             Slice<DimEntry> rest;
2176             for (auto i : irange(1, m.size())) {
2177                 auto d = parse_dim_entry(m[i]);
2178                 if (!new_levels.remove(A, d)) {
2179                     dim_not_present(d);
2180                 }
2181                 rest.append(A, d);
2182             }
2183 
2184             auto first_idx = new_levels.index(first);
2185             if (!first_idx) {
2186                 dim_not_present(first);
2187             }
2188             new_levels.insert(A, new_levels.slice(*first_idx + 1, *first_idx + 1), rest);
2189             to_flatten.extend(A, rest);
2190         } else {
2191             dims_list_flat.append(A, parse_dim_entry(dims_list[i]));
2192         }
2193     }
2194     if (to_flatten.size() > 0) {
2195         TensorRef rearranged = _match_levels(A, self_info.tensor, self_info.levels, new_levels);
2196         at::IntArrayRef sizes = rearranged->sizes();
2197         Slice<int64_t> new_sizes;
2198         Slice<DimEntry> reshape_levels;
2199         for (auto i : new_levels.enumerate()) {
2200             if (to_flatten.contains(new_levels[i])) {
2201                 new_sizes.back() *= sizes[i];
2202             } else {
2203                 new_sizes.append(A, sizes[i]);
2204                 reshape_levels.append(A, new_levels[i]);
2205             }
2206         }
2207         self_info.tensor = A.autorelease(rearranged->reshape(at::IntArrayRef(new_sizes.begin(), new_sizes.end())));
2208 
2209         self_info.levels = reshape_levels; // note: we are using the first level in a flattened group to represent the group for the rest of the op
2210                                            // we must not rely on that dimension's size, because it doesn't match the size of the whole group
2211     }
2212     bool has_dimpacks = false;
2213     for (auto idx : indices_list) {
2214         if (mpy::tuple_view::check(idx) || mpy::list_view::check(idx)) {
2215             has_dimpacks = true;
2216             break;
2217         }
2218     }
2219     IndexingInfo info = getsetitem_flat(A, self_info, Slice<mpy::handle>(), dims_list_flat, indices_list, has_dimpacks);
2220     return invoke_getitem(A, info);
2221 }
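// Hypothetical Python-level example of the flattening performed by index()
// (using the functorch.dim method this implements):
//     t.index((i, j), k)    # flatten dims i and j into one axis, index it with k
// Internally j is moved next to i, the pair is reshaped into a single level
// represented by i, and the result is routed through getsetitem_flat.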
2222 
2223 // true -- the indices were flattened out of a tuple, list or sequence...
2224 
2225 Slice<mpy::handle> slice_from_sequence(Arena& A, mpy::handle value) {
2226     if (mpy::tuple_view::check(value)) {
2227         return as_slice(mpy::tuple_view(value));
2228     } else if (mpy::list_view::check(value)) {
2229         return as_slice(mpy::list_view(value));
2230     } else {
2231         mpy::sequence_view sv(value);
2232         Slice<mpy::handle> r;
2233         for (auto i : sv.enumerate()) {
2234             r.append(A, A.autorelease(sv[i]));
2235         }
2236         return r;
2237     }
2238 }
2239 
2240 bool extractIndices(Arena& A, mpy::handle index, Slice<mpy::handle>& indices) {
2241     if (mpy::tuple_view::check(index)) {
2242         indices.extend(A, as_slice(mpy::tuple_view(index)));
2243         return true;
2244     } else if (THPVariable_Check(index.ptr())) {
2245         indices.append(A, index);
2246         return false;
2247     } else if (!mpy::is_sequence(index)) {
2248         indices.append(A, index);
2249         return false;
2250     }
2251     // a copy of treatSequenceAsTuple, modified to also recognize Dim and our wrapped tensors
2252     mpy::sequence_view sv(index);
2253     if (sv.size() >= 32) {
2254         indices.extend(A, slice_from_sequence(A, index));
2255         return true;
2256     }
2257     for (auto i : sv.enumerate()) {
2258         mpy::handle item;
2259         try {
2260             item = sv[i];
2261         } catch (mpy::exception_set & e) {
2262             PyErr_Clear();
2263             indices.append(A, index);
2264             return false;
2265         }
2266         if (THPVariable_Check(item.ptr()) || mpy::is_sequence(item) || PySlice_Check(item.ptr()) || item.ptr() == Py_Ellipsis || mpy::is_none(item) || has_dims(item)) {
2267             indices.extend(A, slice_from_sequence(A, index));
2268             return true;
2269         }
2270     }
2271     indices.append(A, index);
2272     return false;
2273 }
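// The heuristic above follows treatSequenceAsTuple: sequences of 32 or more
// elements are always unpacked as a tuple of indices, while shorter ones are
// unpacked only if they contain a tensor, slice, Ellipsis, None, nested
// sequence, or first-class dim. For example, t[[0, 1]] stays a single
// (advanced) index, while t[[i, 0]] for a Dim i is unpacked into two indices.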
2274 
2275 IndexingInfo getsetitem(Arena & A, mpy::handle self, mpy::handle index, bool tensors_have_dims) {
2276     bool can_call_original_getitem = !tensors_have_dims;
2277 
2278     Slice<mpy::handle> input;
2279     if (has_dims(index)) {
2280         input.append(A, index);
2281     } else {
2282         bool is_sequence = extractIndices(A, index, input);
2283         // nothing about first-class dims here, fall back to getitem
2284         if (can_call_original_getitem && !is_sequence) {
2285             return { true };
2286         }
2287     }
2288 
2289     int64_t dims_indexed = 0;
2290     int64_t expanding_object = -1;
2291     DimList* unbound_dim_list = nullptr;
2292     auto check_expanding = [&](int64_t i) {
2293         if (expanding_object != -1) {
2294             mpy::raise_error(DimensionBindError(), "at most one ... or unbound dimension list can exist in an indexing list but found 2 at offsets %d and %d", (int) expanding_object, (int) i);
2295         }
2296         expanding_object = i;
2297     };
2298     Slice<int64_t> dimlists;
2299 
2300     // calculate how many dimensions have been indexed in order to compute the size of ...
2301     // or expand a potentially unbound dimension list.
2302 
2303     bool has_dimpacks_or_none = false;
2304     for (auto i : input.enumerate()) {
2305         mpy::handle s = input[i];
2306         if (Dim::check_exact(s) || Tensor::check_exact(s)) {
2307             can_call_original_getitem = false;
2308             ++dims_indexed;
2309         } else if (s.ptr() == Py_Ellipsis) {
2310             check_expanding(i);
2311         } else if (DimList::check(s)) {
2312             can_call_original_getitem = false;
2313             auto dl = DimList::unchecked_wrap(s);
2314             if (!dl->is_bound()) {
2315                 check_expanding(i);
2316                 unbound_dim_list = dl.ptr();
2317             } else {
2318                 dims_indexed += dl->dims_.size();
2319             }
2320             dimlists.append(A, i);
2321         } else if (mpy::is_none(s)) {
2322             has_dimpacks_or_none = true;
2323         } else if (is_dimpack(s)) {
2324             can_call_original_getitem = false;
2325             has_dimpacks_or_none = true;
2326             ++dims_indexed;
2327         } else {
2328             ++dims_indexed;
2329         }
2330     }
2331 
2332     // at this point, if we haven't seen any Dim objects we can also fall back to the original getitem.
2333     if (can_call_original_getitem) {
2334         return {true};
2335     }
2336 
2337     // std::cout << "__getitem__ " << self << " " << index << "\n";
2338 
2339     TensorInfo self_info = TensorInfo::create(A, self, false, true);
2340     auto ndim = self_info.ndim();
2341     if (dims_indexed > ndim) {
2342         mpy::raise_error(PyExc_ValueError, "at least %d indices were supplied but the tensor only has %d dimensions", (int) dims_indexed, (int) ndim);
2343     }
2344     // expand any unbound dimension list, or expand ... into individual : slices.
2345     auto expanding_dims = ndim - dims_indexed;
2346     if (expanding_object != -1) {
2347         if (unbound_dim_list) {
2348             unbound_dim_list->bind_len(expanding_dims);
2349         } else {
2350             // ...
2351             Slice<mpy::handle> no_slices;
2352             for (auto i : irange(expanding_dims)) {
2353                 (void) i;
2354                 no_slices.append(A, no_slice);
2355             }
2356             input.insert(A, input.slice(expanding_object, expanding_object + 1), no_slices);
2357         }
2358     }
2359 
2360     // flatten out any dimensions stored in dimlist elements directly into the inputs
2361     // std::cout << dimlists << " <- dim lists!\n";
2362     for (int64_t i = dimlists.size() - 1; i >=0; --i) {
2363         auto idx = dimlists[i];
2364         // we added more elements to input because of ...
2365         // so we need to also adjust the index to get back to where the
2366         // dimlist existed
2367         if (!unbound_dim_list && expanding_object != -1 && idx > expanding_object) {
2368             idx += expanding_dims;
2369         }
2370         auto dl = DimList::unchecked_wrap(input[idx]);
2371         // XXX would be better if we used an OwnedSlice in DimList
2372         Slice<mpy::handle> more_dims((mpy::handle*) &*dl->dims_.begin(), (mpy::handle*) &*dl->dims_.end());
2373         input.insert(A, input.slice(idx, idx + 1), more_dims);
2374     }
2375 
2376     return getsetitem_flat(A, self_info, input, Slice<DimEntry>(), Slice<mpy::handle>(), has_dimpacks_or_none);
2377 }
2378 }
2379 IndexingInfo getsetitem_flat(Arena& A, TensorInfo self_info, Slice<mpy::handle> input, Slice<DimEntry> keys, Slice<mpy::handle> values, bool has_dimpacks_or_none) {
2380     // At this point:
2381     // ..., DimList have been eliminated
2382     // Dim, Tensor, Tuple[Dim,...], int, slice still remain
2383 
2384 
2385     // we have to count how many times we see a dimension.
2386     // A[i,j] is a simple binding operation, but A[i, i+j] or A[i, i] requires advanced indexing.
2387     Slice<mpy::hdl<Dim>> seen_dims;
2388     Slice<int64_t> seen_dims_nuses;
2389     auto add_dim = [&](mpy::hdl<Dim> entry) {
2390         auto midx = seen_dims.index(entry);
2391         if (!midx) {
2392             seen_dims.append(A, entry);
2393             seen_dims_nuses.append(A, 1);
2394         } else {
2395             ++seen_dims_nuses[*midx];
2396         }
2397     };
2398 
2399     Slice<mpy::handle> input_it = input;
2400 
2401     Slice<mpy::handle> flat_inputs;
2402     // an entry of flat_inputs holds an empty mpy::handle when the actual
2403     // value lives in the corresponding TensorInfo entry of tensor_inputs
2404     Slice<TensorInfo> tensor_inputs;
2405 
2406     auto append_flat_handle = [&](mpy::handle h) {
2407         flat_inputs.append(A, h);
2408         tensor_inputs.append(A, TensorInfo());
2409     };
2410     TensorRef device_holding_tensor;
2411     auto append_tensor_input = [&](TensorInfo ti) {
2412         flat_inputs.append(A, mpy::handle());
2413         tensor_inputs.append(A, ti);
2414         if (ti.has_device && !device_holding_tensor) {
2415             device_holding_tensor = ti.tensor;
2416         }
2417     };
2418 
2419     Slice<int64_t> nsz;
2420     Slice<int64_t> nsd;
2421     at::IntArrayRef sz = self_info.tensor->sizes();
2422     at::IntArrayRef sd = self_info.tensor->strides();
2423 
2424     auto append_size = [&](int i) {
2425         if (has_dimpacks_or_none) {
2426             nsz.append(A, sz[i]);
2427             nsd.append(A, sd[i]);
2428         }
2429     };
2430     // std::cout << "self levels: " << self_info.levels << "\n";
2431 
2432     auto parse_nones = [&]() {
2433         while (input_it.size() && mpy::is_none(input_it[0])) {
2434             append_flat_handle(no_slice);
2435             nsz.append(A, 1);
2436             nsd.append(A, 0);
2437             input_it = input_it.slice(1);
2438         }
2439     };
2440 
2441 
2442     auto append_item = [&](int i, mpy::handle arg) {
2443         if (Dim::check_exact(arg)) {
2444             auto d = Dim::unchecked_wrap(arg);
2445             d->set_size(sz[i]);
2446             add_dim(d);
2447             append_size(i);
2448             append_flat_handle(arg);
2449             return;
2450         }
2451         auto info = TensorInfo::create(A, arg, false, false);
2452         if (info) {
2453             append_size(i);
2454             append_tensor_input(info);
2455             for (auto il : info.levels) {
2456                 if (!il.is_positional()) {
2457                     add_dim(il.dim());
2458                 }
2459             }
2460             return;
2461         }
2462 
2463         if (has_dimpacks_or_none) {
2464             Slice<mpy::handle> mp;
2465             if (maybe_dimpack(mp, arg)) {
2466                 // dim pack
2467                 Slice<mpy::hdl<Dim>> dim_pack;
2468                 for (auto d : mp) {
2469                     dim_pack.append(A, Dim::wrap(d));
2470                     add_dim(dim_pack.back());
2471                     append_flat_handle(dim_pack.back());
2472                 }
2473                 _bind_dims_to_size(A, sz[i], sd[i], dim_pack, nsz, nsd);
2474                 return;
2475             }
2476         }
2477 
2478         append_size(i);
2479         append_flat_handle(arg);
2480     };
2481 
2482     // pair up each indexing expression with the dimension of self it indexes;
2483     // self may have first-class dims, which do not participate in the indexing.
2484     for (auto i : self_info.levels.enumerate()) {
2485         auto l = self_info.levels[i];
2486         auto idx = keys.index(l);
2487         if (idx) {
2488             append_item(i, values[*idx]);
2489         } else if (l.is_positional()) {
2490             // grab and index from the positional list
2491             parse_nones();
2492             if (!input_it.size()) {
2493                 // we might have fewer indices than tensor dimensions,
2494                 // which implicitly indexes the remaining dimensions with :
2495                 append_flat_handle(no_slice);
2496                 append_size(i);
2497             } else {
2498                 mpy::handle arg = input_it[0];
2499                 input_it = input_it.slice(1);
2500                 append_item(i, arg);
2501             }
2502         } else {
2503             add_dim(l.dim());
2504             append_flat_handle(l.dim());
2505             append_size(i);
2506         }
2507     }
2508     // any trailing Nones may have no existing dimension associated with them in self.
2509     parse_nones();
2510 
2511     // we have to restride the tensor to collapse dimension packs and introduce our none dimensions.
2512     if (has_dimpacks_or_none) {
2513         self_info.tensor = A.autorelease(self_info.tensor->as_strided(at::IntArrayRef(nsz.begin(), nsz.end()),at::IntArrayRef(nsd.begin(), nsd.end()), self_info.tensor->storage_offset()));
2514     }
2515 
2516 
2517     // figure out what the shape of the indexing tensors will be
2518     // and what the shape of the resulting tensor will be
2519     Slice<DimEntry> result_levels;
2520     Slice<DimEntry> index_levels;
2521     int64_t tensor_insert_point = -1;
2522     bool requires_getindex = false;
2523     auto mark_tensor_index = [&] {
2524         if (tensor_insert_point == -1) {
2525             tensor_insert_point = result_levels.size();
2526         } else if (tensor_insert_point != result_levels.size()) {
2527             tensor_insert_point = 0;
2528         }
2529     };
2530     for (auto i : flat_inputs.enumerate()) {
2531         auto inp = flat_inputs[i];
2532         if (tensor_inputs[i]) {
2533              requires_getindex = true;
2534              mark_tensor_index();
2535              for (auto l : tensor_inputs[i].levels) {
2536                  // std::cout << "Consider to add " << l << "\n";
2537                  if (!index_levels.contains(l)) {
2538                      index_levels.append(A, l);
2539                  }
2540              }
2541         } else if (Dim::check_exact(inp)) {
2542             auto d = Dim::unchecked_wrap(inp);
2543             // dimensions used once are just binding operations
2544             if (1 == seen_dims_nuses[*seen_dims.index(d)]) {
2545                 flat_inputs[i] = no_slice;
2546                 result_levels.append(A, d);
2547             } else {
2548                 requires_getindex = true;
2549                 flat_inputs[i] = mpy::handle();
2550                 tensor_inputs[i] = TensorInfo {d->range(), Slice<DimEntry>(A, DimEntry(d)), false, TensorRef()};
2551                 if (!index_levels.contains(d)) {
2552                      index_levels.append(A, d);
2553                 }
2554                 mark_tensor_index();
2555             }
2556          } else {
2557             if (inp.ptr() != no_slice.ptr()) {
2558                 requires_getindex = true;
2559             }
2560             if (!mpy::is_int(inp)) {
2561                 // note: actual positional indexes are accurately computed later
2562                 result_levels.append(A, -1);
2563             }
2564          }
2565     }
2566 
2567     // indexing dimensions appear in the tensor at the _first use of a tensor_ in the indexing. So insert
2568     // the indexing levels into the result levels at this spot
2569     if (tensor_insert_point != -1) {
2570         result_levels.insert(A, result_levels.slice(tensor_insert_point, tensor_insert_point), index_levels);
2571     }
2572 
2573     // std::cout << "flat inputs: " << flat_inputs << "\n";
2574     // std::cout << "result_levels: " << result_levels << "\n";
2575     // std::cout << "index_levels: " << index_levels << "\n";
2576 
2577     // get all the tensors to be the right shape for indexing
2578     if (requires_getindex) {
2579         for (auto i : flat_inputs.enumerate()) {
2580             if (tensor_inputs[i]) {
2581                 AT_ASSERT(!flat_inputs[i].ptr());
2582                 // std::cout << "tensor " << i << " " << tensor_inputs[i].levels << "\n";
2583                 TensorRef t = tensor_inputs[i].tensor;
2584                 if (!tensor_inputs[i].has_device && device_holding_tensor) {
2585                     t = A.autorelease(t->to(device_holding_tensor->device()));
2586                 }
2587                 flat_inputs[i] = handle_from_tensor(A, _match_levels(A, t, tensor_inputs[i].levels, index_levels));
2588             }
2589         }
2590     }
2591 
2592     // previously we didn't know how many positional dimensions there would be, so we couldn't number them correctly;
2593     // fill them in now.
2594     auto seen_positionals = 0;
2595     for (auto i : result_levels.reversed_enumerate()) {
2596         if (result_levels[i].is_positional()) {
2597             result_levels[i] = -(++seen_positionals);
2598         }
2599     }
2600 
2601     return IndexingInfo {false, requires_getindex, self_info.tensor, flat_inputs, result_levels, self_info.has_device};
2602 }
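// A hedged summary of the fast path above: in t[i, j] where dims i and j are
// each used exactly once, every flat input becomes no_slice, requires_getindex
// stays false, and the result is just t relabeled with levels (i, j), a pure
// binding with no indexing kernel. t[i, i], by contrast, replaces the repeated
// dim with its range() tensor and takes the advanced-indexing path.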
2603 namespace{
2604 mpy::object __getitem__(Arena & A, mpy::handle self, mpy::handle index) {
2605     maybeInitializeGlobals();
2606     auto iinfo = getsetitem(A, self, index, has_dims(self));
2607     if (iinfo.can_call_original) {
2608         return mpy::object::checked_steal(THPVariable_getitem(self.ptr(), index.ptr()));
2609     }
2610 
2611     return invoke_getitem(A, iinfo);
2612 }
2613 
2614 
2615 
2616 void __setitem__(Arena & A, mpy::handle self, mpy::handle index, mpy::handle rhs) {
2617     maybeInitializeGlobals();
2618     auto iinfo = getsetitem(A, self, index, has_dims(self) || has_dims(rhs));
2619     if (iinfo.can_call_original) {
2620         if (-1 == THPVariable_setitem(self.ptr(), index.ptr(), rhs.ptr())) {
2621             throw mpy::exception_set();
2622         }
2623         return;
2624     }
2625 
2626     auto rhs_info = TensorInfo::create(A, rhs, false, false);
2627     if (rhs_info) { // otherwise rhs can be a scalar...
2628         for (auto l : rhs_info.levels) {
2629             if (!iinfo.result_levels.contains(l)) {
2630                 if (l.is_positional()) {
2631                     mpy::raise_error(DimensionBindError(), "rhs contains too many dimensions (%d) compared to indexed value (%d)", ndim_of_levels(iinfo.result_levels), rhs_info.ndim());
2632                 } else {
2633                     auto tup = levels_to_tuple(iinfo.result_levels);
2634                     mpy::raise_error(DimensionBindError(), "rhs of setitem contains dimension %R which is not in the dimension on the left (%R)", l.dim().ptr(), tup.ptr());
2635                 }
2636             }
2637         }
2638         auto rhs_matched = _match_levels(A, rhs_info.tensor, rhs_info.levels, iinfo.result_levels);
2639         rhs = handle_from_tensor(A, rhs_matched);
2640     }
2641     self = handle_from_tensor(A, iinfo.self);
2642 
2643     if (iinfo.advanced_indexing) {
2644         auto tup = slice_to_tuple(iinfo.flat_inputs);
2645         if (-1 == THPVariable_setitem(self.ptr(), tup.ptr(), rhs.ptr())) {
2646             throw mpy::exception_set();
2647         }
2648     } else {
2649         torch_Tensor_copy_.call(self, rhs);
2650     }
2651 }
2652 }
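// Sketch of the two assignment paths above: t[i] = r either goes through the
// original THPVariable_setitem with the level-matched tensor inputs (advanced
// indexing), or, when the index is a pure binding, lowers to Tensor.copy_ on
// the restrided view.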
2653 
2654 PyObject* Tensor_getitem(PyObject* self, PyObject* index) {
2655     Arena A;
2656     PY_BEGIN
2657     return __getitem__(A, self, index).release();
2658     PY_END(nullptr);
2659 }
2660 
2661 int Tensor_setitem(PyObject* self, PyObject* index, PyObject* value) {
2662     Arena A;
2663     PY_BEGIN
2664     __setitem__(A, self, index, value);
2665     return 0;
2666     PY_END(-1);
2667 }
2668 
2669 namespace{
2670 PyObject* py___getitem__(PyObject *_,
2671                       PyObject *const *args,
2672                       Py_ssize_t nargs,
2673                       PyObject *kwnames) {
2674     Arena A;
2675     PY_BEGIN
2676     AT_ASSERT(nargs == 2);
2677     return __getitem__(A, args[0], args[1]).release();
2678     PY_END(nullptr)
2679 }
2680 
2681 PyObject* py___setitem__(PyObject *_,
2682                       PyObject *const *args,
2683                       Py_ssize_t nargs,
2684                       PyObject *kwnames) {
2685     Arena A;
2686     PY_BEGIN
2687     AT_ASSERT(nargs == 3);
2688     __setitem__(A, args[0], args[1], args[2]);
2689     Py_RETURN_NONE;
2690     PY_END(nullptr)
2691 }
2692 
2693 
2694 PyObject* py_index(PyObject *_,
2695                       PyObject *const *args,
2696                       Py_ssize_t nargs,
2697                       PyObject *kwnames) {
2698     Arena A;
2699     PY_BEGIN
2700     mpy::vector_args va(args, nargs, kwnames);
2701     mpy::handle self, dims, indices;
2702     va.parse("index", {"self", "dims", "indices"}, {&self, &dims, &indices}, 3);
2703     return index(A, self, dims, indices).release();
2704     PY_END(nullptr)
2705 }
2706 
2707 
2708 PyObject* py_stack(PyObject *_,
2709                       PyObject *const *args,
2710                       Py_ssize_t nargs,
2711                       PyObject *kwnames) {
2712     Arena A;
2713     PY_BEGIN
2714     mpy::vector_args va(args, nargs, kwnames);
2715     mpy::handle tensors, new_dim, dim;
2716     va.parse("stack", {"tensors", "new_dim", "dim"}, {&tensors, &new_dim, &dim}, 2);
2717 
2718     Slice<DimEntry> result_levels;
2719     Slice<TensorInfo> infos;
2720     mpy::sequence_view sv(tensors);
2721     auto new_dim_d = Dim::wrap(new_dim);
2722     for (auto i : sv.enumerate()) {
2723         infos.append(A, TensorInfo::create(A, A.autorelease(sv[i]), false));
2724         for (auto l : infos.back().levels) {
2725             if (!result_levels.contains(l)) {
2726                 result_levels.append(A, l);
2727             }
2728         }
2729     }
2730     new_dim_d->set_size(infos.size());
2731     std::vector<at::Tensor> inputs;
2732     inputs.reserve(infos.size());
2733     for (auto in : infos) {
2734         inputs.emplace_back(*_match_levels(A, in.tensor, in.levels, result_levels));
2735     }
2736     auto ndim = ndim_of_levels(result_levels);
2737     int64_t rawdim = 0;
2738     if (dim.ptr()) {
2739         auto d = _wrap_dim(dim, ndim, false);
2740         auto idx = result_levels.index(d);
2741         if (!idx) {
2742             mpy::raise_error(PyExc_TypeError, "Dimension %R does not exist in inputs", dim.ptr());
2743         }
2744         rawdim = *idx;
2745     }
2746     auto result = at::stack(inputs, rawdim);
2747     result_levels.insert(A, rawdim, new_dim_d);
2748     return Tensor::from_positional(A, std::move(result), result_levels, true).release();
2749     PY_END(nullptr)
2750 }
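// Hypothetical usage of stack() (per the functorch.dim API this implements):
// given a sequence ts of tensors that each carry dim i, stack(ts, new_dim=s)
// binds s to len(ts) and returns a tensor whose levels start with s; passing
// dim=i instead inserts the new level at i's position in the result.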
2751 
2752 static PyObject* py_split(PyObject *_,
2753                       PyObject *const *args,
2754                       Py_ssize_t nargs,
2755                       PyObject *kwnames) {
2756     Arena A;
2757     PY_BEGIN
2758     maybeInitializeGlobals();
2759     mpy::vector_args va(args, nargs, kwnames);
2760     mpy::handle self, split_size_or_sections, dim;
2761     va.parse("split", {"self", "split_size_or_sections", "dim"}, {&self, &split_size_or_sections, &dim}, 2);
2762     bool dim_is_object = dim.ptr() && Dim::check_exact(dim);
2763     Slice<mpy::handle> sizes;
2764 
2765     bool all_dims = true;
2766     bool all_ints = true;
2767 
2768     if (!mpy::is_int(split_size_or_sections)) {
2769         mpy::sequence_view sv(split_size_or_sections);
2770         for (auto i : sv.enumerate()) {
2771             sizes.append(A, A.autorelease(sv[i]));
2772             if (Dim::check_exact(sizes.back())) {
2773                 all_ints = false;
2774             } else {
2775                 all_dims = false;
2776             }
2777         }
2778     }
2779     if (all_ints) {
2780         if (dim_is_object) {
2781             mpy::raise_error(PyExc_TypeError, "when dim is specified as a Dim object, split sizes must also be dimensions.");
2782         }
2783         // call original split (if self has dimensions this will use torch function to do the split)
2784         return torch_Tensor_split.call_vector(mpy::vector_args(args, nargs, kwnames)).release();
2785     }
2786     if (!all_dims) {
2787         mpy::raise_error(PyExc_TypeError, "split list must be ints or dims but got a mix");
2788     }
2789 
2790     auto self_info = TensorInfo::create(A, self, false);
2791     auto ndim = self_info.ndim();
2792     if (!dim_is_object&& ndim == 0) {
2793         mpy::raise_error(PyExc_TypeError, "split expects at least a 1-dimension tensor");
2794     }
2795     DimEntry dim_l = dim.ptr() ? _wrap_dim(dim, ndim, false) : -ndim;
2796 
2797     auto idx = self_info.levels.index(dim_l);
2798     if (!idx) {
2799         if (!dim.ptr()) {
2800             dim = A.autorelease(mpy::from_int(0));
2801         }
2802         mpy::raise_error(PyExc_TypeError, "tensor does not comtain dimension %R", dim.ptr());
2803     }
2804     Slice<int64_t> indices;
2805 
2806     int64_t total_size = 0;
2807     Slice<int64_t> unbound;
2808     for (auto i : sizes.enumerate()) {
2809         auto d = Dim::unchecked_wrap(sizes[i]);
2810         if (d->is_bound()) {
2811             indices.append(A, d->size());
2812             total_size += indices.back();
2813         } else {
2814             indices.append(A, 0);
2815             unbound.append(A, i);
2816         }
2817     }
2818     auto tensor_size = self_info.tensor->sizes()[*idx];
2819 
2820     if (unbound.size()) {
2821         if (total_size > tensor_size) {
2822            mpy::raise_error(PyExc_TypeError, "sizes of target dimensions add up to more (%d) than source dim (%d)", int(total_size), int(tensor_size));
2823         }
2824         auto remaining_size = tensor_size - total_size;
2825         auto chunk_size = (remaining_size + unbound.size() - 1) / unbound.size();
2826         for (auto u : unbound) {
2827             auto sz = std::min(chunk_size, remaining_size);
2828             Dim::unchecked_wrap(sizes[u])->set_size(sz);
2829             indices[u] = sz;
2830             remaining_size -= sz;
2831         }
2832     } else if (tensor_size != total_size) {
2833         mpy::raise_error(PyExc_TypeError, "sum of sizes of target dimensions (%d) do not match the than source dim (%d)", int(total_size), int(tensor_size));
2834     }
2835 
2836     auto result_tensors = self_info.tensor->split_with_sizes(at::IntArrayRef(indices.begin(), indices.end()), *idx);
2837     mpy::tuple result(result_tensors.size());
2838     Slice<DimEntry> new_levels;
2839     new_levels.extend(A, self_info.levels);
2840     for (auto i : sizes.enumerate()) {
2841         new_levels[*idx] = Dim::unchecked_wrap(sizes[i]);
2842         result.set(i, Tensor::from_positional(A, std::move(result_tensors[i]), new_levels, true));
2843     }
2844 
2845     return result.release();
2846 
2847     PY_END(nullptr)
2848 }
2849 
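// _wrap_dims generalizes _wrap_dim to `dim=` arguments that may be either a
// single entry (an int or a Dim) or a sequence of them, producing one
// DimEntry per requested dimension.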
Slice<DimEntry> _wrap_dims(Arena& A, mpy::handle d, size_t N, bool keepdim) {
    auto de = _wrap_dim(d, N, keepdim);
    Slice<DimEntry> r;
    if (!de.is_none()) {
        r.append(A, de);
    } else {
        mpy::sequence_view sq(d);
        for (auto i : sq.enumerate()) {
            r.append(A, _wrap_dim(A.autorelease(sq[i]), N, keepdim));
        }
    }
    return r;
}

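// WrappedOperator pairs an original callable (e.g. torch.sum) with metadata
// describing where its dim/keepdim arguments live and whether it reduces.
// function() manufactures a PyCFunction whose self pointer is the
// WrappedOperator itself, so the wrapper implementation can recover this
// metadata on every call.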
struct WrappedOperator : public mpy::base<WrappedOperator> {
    mpy::object orig;
    PyMethodDef method_def;
    mpy::object name, doc;

    bool is_pointwise = false;
    int64_t dim_offset = 0;
    int64_t keepdim_offset = 1;
    std::string dim_name;
    bool single_dim = false;
    bool reduce = true;

    static PyTypeObject Type;

    void init(mpy::object orig_, PyCFunction wrapper_implementation, std::string dim_name_="") {
        orig = std::move(orig_);
        method_def.ml_meth = wrapper_implementation;
        name = orig.attr("__name__");
        doc = orig.attr("__doc__");
        dim_name = std::move(dim_name_);
        if (!mpy::is_none(doc) && !dim_name.empty()) {
            doc = mpy::unicode_from_format("%S\nArgument '%s' can be either an integer or a torchdim.Dim object.\n", doc.ptr(), dim_name.c_str());
        }
        method_def.ml_name = mpy::is_none(name) ? "" : PyUnicode_AsUTF8(name.ptr());
        method_def.ml_doc = mpy::is_none(doc) ? "" : PyUnicode_AsUTF8(doc.ptr());
        method_def.ml_flags = METH_FASTCALL | METH_KEYWORDS;
    }

    mpy::object function() {
        return mpy::object::checked_steal(PyCFunction_New(&method_def, ptr()));
    }

};
}

PyTypeObject WrappedOperator::Type = {
    PyVarObject_HEAD_INIT(NULL, 0)
    "_C.WrappedOperator",           /* tp_name */
    sizeof(WrappedOperator),        /* tp_basicsize */
    0,                              /* tp_itemsize */
    WrappedOperator::dealloc_stub,  /* tp_dealloc */
    0,                              /* tp_vectorcall_offset */
    0,                              /* tp_getattr */
    0,                              /* tp_setattr */
    0,                              /* tp_as_async */
    0,                              /* tp_repr */
    0,                              /* tp_as_number */
    0,                              /* tp_as_sequence */
    0,                              /* tp_as_mapping */
    0,                              /* tp_hash */
    0,                              /* tp_call */
    0,                              /* tp_str */
    0,                              /* tp_getattro */
    0,                              /* tp_setattro */
    0,                              /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT,             /* tp_flags */
    "Wrapped Object Holder",        /* tp_doc */
    0,                              /* tp_traverse */
    0,                              /* tp_clear */
    0,                              /* tp_richcompare */
    0,                              /* tp_weaklistoffset */
    0,                              /* tp_iter */
    0,                              /* tp_iternext */
    0,                              /* tp_methods */
    0,                              /* tp_members */
    0,                              /* tp_getset */
    0,                              /* tp_base */
    0,                              /* tp_dict */
    0,                              /* tp_descr_get */
    0,                              /* tp_descr_set */
    0,                              /* tp_dictoffset */
    0,                              /* tp_init */
    0,                              /* tp_alloc */
    WrappedOperator::new_stub,      /* tp_new */
};

namespace{
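// patched_dim_method is the wrapper installed by _wrap below. It has two
// paths:
//  * no dim argument passed: run the original operator under functorch
//    batching (EnableAllLayers), so first-class dims are handled implicitly;
//  * dim argument(s) passed: translate each first-class dim to its positional
//    index, call the original operator, and re-wrap tensor results with the
//    surviving levels (reduced levels are dropped unless keepdim=True).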
PyObject* patched_dim_method(PyObject * self_,
                      PyObject *const *args,
                      Py_ssize_t nargs,
                      PyObject *kwnames) {
    Arena A;
    auto self = WrappedOperator::unchecked_wrap(self_);
    PY_BEGIN

    mpy::vector_args va(args, nargs, kwnames);

    auto _getarg = [&](const char* name, int64_t offset_) -> mpy::handle {
        auto offset = offset_ + 1; // do not include self
        auto idx = va.index(name, offset);
        return idx == -1 ? mpy::handle() : va[idx];
    };
    Slice<mpy::handle> patched_args;
    patched_args.extend(A, va.begin(), va.end());
    auto _patcharg = [&](const char* name, int64_t offset_, mpy::handle value) {
        auto offset = offset_ + 1; // do not include self
        auto idx = va.index(name, offset);
        if (idx == -1) {
            mpy::raise_error(PyExc_ValueError, "Missing argument %s", name);
        }
        patched_args[idx] = value;
    };

    auto dim = _getarg(self->dim_name.c_str(), self->dim_offset);
    if (!dim.ptr()) {
        auto info = TensorInfo::create(A, args[0], true);
        EnableAllLayers l(A, info.levels);
        l.inplace_update_layers(info.batchedtensor, info.levels);
        patched_args[0] = handle_from_tensor(A, info.batchedtensor);
        auto r = self->orig.call_vector(patched_args.begin(), nargs, kwnames);
        return l.from_batched(A, THPVariable_Unpack(r.ptr()), info.has_device).release();
    }

    auto info = TensorInfo::create(A, args[0]);
    auto keepdim = false;
    if (self->reduce) {
        auto py_keepdim = _getarg("keepdim", self->keepdim_offset);
        if (py_keepdim.ptr()) {
            keepdim = mpy::to_bool(py_keepdim);
        }
    }

    auto ndim = info.ndim();
    auto dims = _wrap_dims(A, dim, ndim, keepdim);
    Slice<int64_t> dim_indices;
    auto seen = A.allocate<bool>(info.levels.size());
    std::fill(seen, seen + info.levels.size(), false);

    for (auto d : dims) {
        auto midx = info.levels.index(d);
        if (!midx) {
            auto tup = levels_to_tuple(info.levels);
            mpy::raise_error(PyExc_ValueError, "Tensor with dimensions %R does not contain one of %R\n", tup.ptr(), dim.ptr());
        }
        seen[*midx] = true;
        dim_indices.append(A, *midx);
    }
    Slice<DimEntry> new_levels;
    if (self->reduce && !keepdim) {
        for (auto i : info.levels.enumerate()) {
            if (!seen[i]) {
                new_levels.append(A, info.levels[i]);
            }
        }
    } else {
        new_levels = info.levels;
    }
    mpy::object py_indices;
    if (dim_indices.size() == 1) {
        py_indices = mpy::from_int(dim_indices[0]);
    } else {
        mpy::tuple tup(dim_indices.size());
        for (auto i : dim_indices.enumerate()) {
            tup.set(i, mpy::from_int(dim_indices[i]));
        }
        py_indices = std::move(tup);
    }
    _patcharg(self->dim_name.c_str(), self->dim_offset, py_indices);
    patched_args[0] = handle_from_tensor(A, info.tensor);
    auto r = self->orig.call_vector(patched_args.begin(), nargs, kwnames);
    auto wrap = [&](mpy::handle h) {
        if (THPVariable_Check(h.ptr())) {
            return A.autorelease(Tensor::from_positional(A, THPVariable_Unpack(h.ptr()), new_levels, info.has_device));
        }
        return h;
    };
    return tree_map(A, wrap, r).release();
    PY_END(nullptr)
}

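// _wrap is exported to Python so the functorch.dim package can produce
// dim-aware versions of reduction-style operators. A minimal sketch of the
// intended call shape (argument values here are hypothetical, not the
// package's actual registrations):
//
//     # _C._wrap(torch.sum)                              # defaults: dim_name="dim"
//     # _C._wrap(torch.softmax, single_dim=True, reduce=False)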
PyObject* _wrap(PyObject * self_,
                      PyObject *const *args,
                      Py_ssize_t nargs,
                      PyObject *kwnames) {
    Arena A;
    PY_BEGIN

    #define ARGS(_) _(mpy::handle, orig) _(mpy::handle, dim_offset) _(mpy::handle, keepdim_offset) \
                    _(mpy::handle, dim_name) _(mpy::handle, single_dim) _(mpy::handle, reduce)
    MPY_PARSE_ARGS_KWNAMES("O|OOOOO", ARGS)

    std::string dim_name_str;
    if (dim_name.ptr()) {
        dim_name_str = PyUnicode_AsUTF8(dim_name.ptr());
    } else {
        dim_name_str = "dim";
    }
    auto info = WrappedOperator::create(mpy::object::borrow(orig), (PyCFunction)(void*) patched_dim_method, std::move(dim_name_str));
    if (dim_offset.ptr()) {
        info->dim_offset = mpy::to_int(dim_offset);
    }
    if (keepdim_offset.ptr()) {
        info->keepdim_offset = mpy::to_int(keepdim_offset);
    }

    if (single_dim.ptr()) {
        info->single_dim = mpy::to_bool(single_dim);
    }
    if (reduce.ptr()) {
        info->reduce = mpy::to_bool(reduce);
    }
    return info->function().release();
    #undef ARGS

    PY_END(nullptr)
}

PyObject* call_torch_function(PyObject *self,
                      PyObject *const *args,
                      Py_ssize_t nargs,
                      PyObject *kwnames) {
    PY_BEGIN
    Arena A;
    maybeInitializeGlobals();
    auto info = WrappedOperator::unchecked_wrap(self);
    return __torch_function__(A, info->orig, mpy::vector_args(args, nargs, kwnames), info->is_pointwise).release();
    PY_END(nullptr)
}

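// _wrap_method(orig, fn) wraps a Tensor method so every call is dispatched
// straight through our __torch_function__ implementation; membership in the
// functorch.dim `pointwise` set (looked up lazily here, since globals may not
// be initialized yet) enables the pointwise fast path. The wrapper is
// returned as an instancemethod so it binds like a normal method.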
PyObject* _wrap_method(PyObject *self,
                      PyObject *const *args,
                      Py_ssize_t nargs,
                      PyObject *kwnames) {
    PY_BEGIN
    AT_ASSERT(nargs == 2);
    // XXX - ignore python function wrapped, we will call torch function directly
    mpy::handle orig = args[0];
    if (!pointwise.ptr()) {
        auto dim = mpy::import("functorch.dim");
        pointwise = dim.attr("pointwise");
    }
    auto info = WrappedOperator::create(mpy::object::borrow(orig), (PyCFunction)(void*) call_torch_function);
    info->is_pointwise = pointwise.contains(orig);
    return PyInstanceMethod_New(info->function().release());
    PY_END(nullptr)
}


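// Tensor_sum intercepts Tensor.sum so the pattern (x * y).sum(dims), where
// the receiver is a delayed pointwise multiply, can be fused into a single
// `dot` contraction (a matmul-style kernel) instead of materializing the
// product. Calls with dtype= or keepdim=True fall back to the generic
// _Tensor.sum.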
PyObject* Tensor_sum(PyObject * self_,
                      PyObject *const *args,
                      Py_ssize_t nargs,
                      PyObject *kwnames) {
    Arena A;
    PY_BEGIN
    maybeInitializeGlobals();
    mpy::vector_args va(args, nargs, kwnames);
    auto self_ = Tensor::unchecked_wrap(args[0]);
    auto d = self_->delayed();
    if (!d) {
        return _Tensor_sum.call_vector(va).release();
    }
    mpy::handle self, dim, keepdim, dtype;
    va.parse("sum", {"self", "dim", "keepdim", "dtype"}, {&self, &dim, &keepdim, &dtype}, 1, 1);

    if (dtype.ptr() || (keepdim.ptr() && mpy::to_bool(keepdim))) {
        // std::cout << "SKIPPING fusion because dtype or keepdim=True specified\n";
        return _Tensor_sum.call_vector(va).release();
    }
    auto levels = self_->levels();

    auto N = ndim_of_levels(levels);
    auto reduced_dims = _wrap_dims(A, dim, N, false);

    return dot(A, TensorInfo::create(A, d->args[0], false), TensorInfo::create(A, d->args[1], false), reduced_dims).release();
    PY_END(nullptr)
}

PyObject* _parse_test(PyObject * self_,
                      PyObject *const *args,
                      Py_ssize_t nargs,
                      PyObject *kwnames) {
    PY_BEGIN
    maybeInitializeGlobals();

    int required = mpy::to_int(args[0]);
    int kwonly = mpy::to_int(args[1]);

    mpy::vector_args va(args + 2, nargs - 2, kwnames);

    mpy::handle a, b, c, d;
    va.parse("_parse_test", {"a", "b", "c", "d"}, {&a, &b, &c, &d}, required, kwonly);
    mpy::tuple r(4);
    r.set(0, mpy::object::borrow(a.ptr() ? a : Py_None));
    r.set(1, mpy::object::borrow(b.ptr() ? b : Py_None));
    r.set(2, mpy::object::borrow(c.ptr() ? c : Py_None));
    r.set(3, mpy::object::borrow(d.ptr() ? d : Py_None));
    return r.release();

    PY_END(nullptr)
}

PyObject* _set_pointwise_optimize(PyObject * self_,
                      PyObject *const *args,
                      Py_ssize_t nargs,
                      PyObject *kwnames) {
    PY_BEGIN
    mpy::handle value;
    mpy::vector_args va(args, nargs, kwnames);
    va.parse("_set_pointwise_optimize", {"value"}, {&value}, 1);
    pointwise_optimize = mpy::to_bool(value);
    Py_RETURN_NONE;
    PY_END(nullptr)
}

PyObject* _patch_tensor_class(PyObject * self_,
                      PyObject *const *args,
                      Py_ssize_t nargs,
                      PyObject *kwnames) {
    PY_BEGIN

    auto torch = mpy::import("torch");
    auto py_TensorBase = torch.attr("_C").attr("TensorBase");
    replaceMappingIfMatches(py_TensorBase);

    Py_RETURN_NONE;
    PY_END(nullptr)
}


const char* dims_doc = R"""(
dims(n=None, sizes=None) -> torchdim.Dim or Tuple[torchdim.Dim, ...]

Creates and returns one or more Dim objects.

Args:
    n (int, optional): The number of dimensions to create. Can be omitted if sizes is specified.
    sizes (List[Optional[int]], optional): A list the same size as the number of dimensions to be
      created, specifying each dimension's size, or None to leave the size unset.

Example::
    >>> batch, channel, width, height = dims(4)
    >>> batch, channel, width, height = dims(sizes=[None, 3, 224, 224])
)""";

PyMethodDef methods[] = {
    {"dims", (PyCFunction)(void*) _dims<create_dim>, METH_FASTCALL | METH_KEYWORDS, dims_doc},
    {"dimlists", (PyCFunction)(void*) _dims<create_dimlist>, METH_FASTCALL | METH_KEYWORDS},
    {"_test_c", (PyCFunction)(void*) test_c, METH_FASTCALL | METH_KEYWORDS},
    {"_wrap_method", (PyCFunction)(void*) _wrap_method, METH_FASTCALL | METH_KEYWORDS},
    {"Tensor_from_positional", (PyCFunction)(void*) py_Tensor_from_positional, METH_FASTCALL | METH_KEYWORDS},
    {"__torch_function__", (PyCFunction)(void*) py___torch_function__, METH_FASTCALL | METH_KEYWORDS},
    {"tree_flatten", (PyCFunction)(void*) py_tree_flatten, METH_FASTCALL | METH_KEYWORDS},
    {"order", (PyCFunction)(void*) order, METH_FASTCALL | METH_KEYWORDS},
    {"index", (PyCFunction)(void*) py_index, METH_FASTCALL | METH_KEYWORDS},
    {"stack", (PyCFunction)(void*) py_stack, METH_FASTCALL | METH_KEYWORDS},
    {"split", (PyCFunction)(void*) py_split, METH_FASTCALL | METH_KEYWORDS},
    {"expand", (PyCFunction)(void*) expand, METH_FASTCALL | METH_KEYWORDS},
    {"__getitem__", (PyCFunction)(void*) py___getitem__, METH_FASTCALL | METH_KEYWORDS},
    {"__setitem__", (PyCFunction)(void*) py___setitem__, METH_FASTCALL | METH_KEYWORDS},
    {"_wrap", (PyCFunction)(void*) _wrap, METH_FASTCALL | METH_KEYWORDS},
    {"Tensor_sum", (PyCFunction)(void*) Tensor_sum, METH_FASTCALL | METH_KEYWORDS},
    {"_parse_test", (PyCFunction)(void*) _parse_test, METH_FASTCALL | METH_KEYWORDS},
    {"_set_pointwise_optimize", (PyCFunction)(void*) _set_pointwise_optimize, METH_FASTCALL | METH_KEYWORDS},
    {"_patch_tensor_class", (PyCFunction)(void*) _patch_tensor_class, METH_FASTCALL | METH_KEYWORDS},
    {NULL, NULL, 0, NULL}        /* Sentinel */
};

struct PyModuleDef module_def = {
    PyModuleDef_HEAD_INIT,
    "_C",     /* name of module */
    NULL,     /* module documentation, may be NULL */
    -1,       /* size of per-interpreter state of the module,
                 or -1 if the module keeps state in global variables. */
    methods
};
}

PyObject* Dim_init() {
    Arena A;
    try {
        mpy::object mod = mpy::object::checked_steal(PyModule_Create(&module_def));
        Dim::ready(mod, "Dim");
        DimList::ready(mod, "DimList");
        Tensor::ready(mod, "Tensor");
        WrappedOperator::ready(mod, "_WrappedOperator");
        Py_INCREF(&PyInstanceMethod_Type);
        PyModule_AddObject(mod.ptr(), "_instancemethod", (PyObject *)&PyInstanceMethod_Type);

        initializeGlobals(A);
        return mod.release();
    } catch(mpy::exception_set& err) {
        return nullptr;
    }
}

#endif