1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // This source code is licensed under the BSD-style license found in the
5 // LICENSE file in the root directory of this source tree.
6
7 #include <torch/csrc/utils/python_compat.h>
8
9
10 // Many APIs have changed/don't exist anymore
11 #if IS_PYTHON_3_12_PLUS
12
13 #include "dim.h"
14
15 // Re-enable this some day
PyObject* Dim_init() {
  PyErr_SetString(PyExc_RuntimeError, "First-class dims don't work with Python 3.12");
18 return nullptr;
19 }
20
21 #else
22
23 #include "minpybind.h"
24 #include <frameobject.h>
25 #include <opcode.h>
26 #include <utility>
27 #include <new>
28 #include <iostream>
29 #include <vector>
30 //#include <torch/csrc/autograd/python_variable.h>
31 #include <torch/csrc/Export.h>
32 #include <ATen/functorch/BatchedTensorImpl.h>
33 #include <ATen/functorch/DynamicLayer.h>
34 #include <ATen/ATen.h>
35 #include <memory>
36 #include "arena.h"
37 #include "dim.h"
38 #include "python_variable_simple.h"
39
40 #if IS_PYTHON_3_11_PLUS
41 #define Py_BUILD_CORE
42 #include "internal/pycore_opcode.h"
43 #undef Py_BUILD_CORE
44 #endif
45
46 // C++ API functions for objects to
47 // * construct the object, returning a ref-counted handle
48 // * The actual API, with methods that take/return C-typed values
49
50 // extend minpybind.h to include
51 // * typed handles so that -> can get to their raw API
52 // * object/handle distinction for the typed handles
53
54 // class Dim: ---------------
55 mpy::handle torch_Tensor___mul__;
56 mpy::handle _Tensor;
57 mpy::handle _Tensor_sum;
58 mpy::handle NamedTuple;
59 mpy::dict_view pointwise;
60 mpy::handle torch_Tensor_expand;
61 binaryfunc THPVariable_getitem;
62 objobjargproc THPVariable_setitem;
63 mpy::handle no_slice;
64 PyTypeObject* torch_Tensor;
65 mpy::handle torch_Tensor_copy_;
66 mpy::handle torch_Tensor_split;
67 bool pointwise_optimize = true;
68 PyTypeObject* DimType = nullptr;
69
70 PyObject* Tensor_getitem(PyObject* self, PyObject* index);
71 int Tensor_setitem(PyObject* self, PyObject* index, PyObject* value);
72
73 namespace{
void maybeInitializeGlobals() {
  // Globals that depend on the Python dim library,
  // which we can't look up until we finish initializing the _C module.
77 if (_Tensor.ptr()) {
78 return;
79 }
80 auto dim = mpy::import("functorch.dim");
81 _Tensor = dim.attr("_Tensor");
82 pointwise = dim.attr("pointwise");
83 _Tensor_sum = _Tensor.attr("sum");
84 DimType = (PyTypeObject*) mpy::import("functorch.dim").attr("Dim").ptr();
85 }
86
void replaceMappingIfMatches(mpy::handle tp) {
88 auto T = (PyTypeObject*) tp.ptr();
89 bool recurse = false;
90 if (T->tp_as_mapping->mp_subscript == THPVariable_getitem) {
91 T->tp_as_mapping->mp_subscript = Tensor_getitem;
92 recurse = true;
93 }
94 if (T->tp_as_mapping->mp_ass_subscript == THPVariable_setitem) {
95 T->tp_as_mapping->mp_ass_subscript = Tensor_setitem;
96 recurse = true;
97 }
98 if (recurse) {
99 auto result = tp.attr("__subclasses__").call();
100 mpy::list_view lv(result);
101 for (auto i : lv.enumerate()) {
102 replaceMappingIfMatches(lv[i]);
103 }
104 }
105 }
106
void initializeGlobals(Arena & A) {
108 auto torch = mpy::import("torch");
109 torch_Tensor = (PyTypeObject*) torch.attr("Tensor").ptr();
110 torch_Tensor___mul__ = torch.attr("Tensor").attr("__mul__");
111
112 torch_Tensor_expand = torch.attr("_C").attr("TensorBase").attr("expand");
113 torch_Tensor_split = torch.attr("_C").attr("TensorBase").attr("split");
114 torch_Tensor_copy_ = torch.attr("Tensor").attr("copy_");
115 auto py_TensorBase = torch.attr("_C").attr("TensorBase");
116 auto TensorBase = (PyTypeObject*) py_TensorBase.ptr();
117 THPVariable_getitem = TensorBase->tp_as_mapping->mp_subscript;
118 THPVariable_setitem = TensorBase->tp_as_mapping->mp_ass_subscript;
119 NamedTuple = mpy::import("typing").attr("NamedTuple");
120 no_slice = PySlice_New(NULL, NULL, NULL);
121
122 }
123
124 mpy::handle DimensionBindError_;
mpy::handle DimensionBindError() {
126 if(!DimensionBindError_.ptr()) {
127 DimensionBindError_ = mpy::import("functorch.dim").attr("DimensionBindError");
128 }
129 return DimensionBindError_;
130 }
131
132 static int64_t n_dims_created = 65;
133
134 struct Dim : public mpy::base<Dim> {
135 int64_t level_; // for stable comparisons in prototype
136 mpy::object name_;
Dim()
  : level_(n_dims_created++) {}
void init(mpy::object name, int64_t s = -1) {
140 name_ = std::move(name);
141 size_ = s;
142 }
143
static bool check_exact(mpy::handle v) {
145 return Py_TYPE(v.ptr()) == DimType;
146 }
147
int64_t size() const {
149 if (size_ == -1) {
150 mpy::raise_error(PyExc_ValueError, "dimension %S is unbound", name_.ptr());
151 }
152 return size_;
153 }
void set_size(int64_t v) {
155 if (size_ == -1) {
156 size_ = v;
157 } else if(size_ != v) {
158 mpy::raise_error(DimensionBindError(), "Dim '%R' previously bound to a dimension of size %lld cannot bind to a dimension of size %lld", this, this->size_, v);
159 }
160 }
bool is_bound() const {
162 return size_ != -1;
163 }
static mpy::obj<Dim> create(mpy::object name, int64_t s = -1) {
165 if (!DimType) {
166 maybeInitializeGlobals();
167 }
168 auto r = Dim::alloc(DimType);
169 r->init(std::move(name), s);
170 return r;
171 }
172 static PyTypeObject Type;
const at::Tensor& range() {
174 if (!range_.defined()) {
175 range_ = at::arange(size());
176 }
177 return range_;
178 }
const at::Tensor& batchtensor() {
180 if (!batchtensor_.defined()) {
181 batchtensor_ = at::functorch::addBatchDim(range(), 0, level_);
182 }
183 return batchtensor_;
184 }
185 private:
186 int64_t size_{-1};
187 at::Tensor range_;
188 at::Tensor batchtensor_;
189 };
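
// Editor's sketch (illustrative, not part of the original file): creating and
// binding a dim through this struct looks like
//   mpy::obj<Dim> d = Dim::create(mpy::unicode_from_string("i"), /*s=*/3);
//   d->size();   // 3
//   Dim::create(mpy::unicode_from_string("j"))->size();   // raises ValueError: unbound
// set_size() binds an unbound dim once and errors if rebound to a different size.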
190
191
192 struct DimEntry {
// Holds either a negative number indicating which positional dimension this is
// (counting from the right-hand side), or a pointer to a first-class dimension.
// Pointers never have their sign bit set, so a negative value tells us the
// entry is a positional index rather than a dim.
bool is_positional() const {
198 return data_ < 0;
199 }
bool is_none() const {
201 return data_ == 0;
202 }
int64_t position() const {
204 return data_;
205 }
mpy::hdl<Dim> dim() const {
207 Dim* result;
208 std::memcpy(&result, &data_, sizeof(Dim*));
209 return mpy::hdl<Dim>(result);
210 }
211
DimEntry()
213 : data_(0) {}
214
DimEntry(int64_t pos)
216 : data_(pos) {
217 AT_ASSERT(pos < 0);
218 }
DimEntry(mpy::hdl<Dim> d) {
220 std::memcpy(&data_, &d, sizeof(int64_t));
221 }
bool operator==(const DimEntry& rhs) const {
223 return data_ == rhs.data_;
224 }
225 private:
226 int64_t data_;
227 };
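
// Encoding sketch (illustrative): positional entries store their negative
// offset directly, first-class dims store the Dim* bit pattern, and 0 means
// "none". For example:
//   DimEntry e;               // e.is_none()       == true
//   DimEntry p(-1);           // p.is_positional() == true, p.position() == -1
//   DimEntry d(some_dim_hdl); // d.is_positional() == false, d.dim() == some_dim_hdl
// (some_dim_hdl is a hypothetical mpy::hdl<Dim> used only for illustration.)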
228
229 // Dim wrapper methods
DimEntry _wrap_dim(mpy::handle d, size_t N, bool keepdim) {
231 if (Dim::check(d)) {
232 if (keepdim) {
233 mpy::raise_error(PyExc_ValueError, "cannot preserve first-class dimensions with keepdim=True");
234 }
235 return Dim::unchecked_wrap(d);
236 } else if (mpy::is_int(d)) {
237 auto i = mpy::to_int(d);
238 while (i >= 0) {
239 i -= N;
240 }
241 return i;
242 } else {
243 return DimEntry();
244 }
245 }
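
// Usage sketch (illustrative, assuming N == 3 positional dims):
//   _wrap_dim(py_one, /*N=*/3, /*keepdim=*/false)  -> DimEntry(-2)   // 1 - 3
//   _wrap_dim(dim_obj, 3, false)                   -> DimEntry(dim)  // first-class dim
//   _wrap_dim(py_none, 3, false)                   -> DimEntry()     // is_none()
// (py_one, dim_obj, py_none are hypothetical handles; keepdim=true with a
// first-class dim raises ValueError as above.)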
246
247
int Dim_init(mpy::hdl<Dim> self, PyObject *args, PyObject *kwds) {
249 PY_BEGIN
250 static constexpr const char* kwlist[] = {"name", "size", nullptr};
251 mpy::handle name;
252 mpy::handle size = nullptr;
253 if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O", const_cast<char **>(kwlist), &name, &size)) {
254 return -1;
255 }
256 self->init(mpy::object::borrow(name), (size.ptr() && !mpy::is_none(size)) ? mpy::to_int(size) : -1);
257 return 0;
258 PY_END(-1)
259 }
260
PyObject* Dim_repr(Dim* self) {
262 PY_BEGIN
263 mpy::object name = (self->name_.ptr()) ? self->name_ : mpy::unicode_from_string("<uninitialized dim>");
264 return name.release();
265 PY_END(nullptr)
266 }
267
268
PyObject* Dim_getsize(Dim* self, void*) {
270 PY_BEGIN
271 return mpy::from_int(self->size()).release();
272 PY_END(nullptr)
273 }
274
int Dim_setsize(Dim* self, PyObject* size, void*) {
276 PY_BEGIN
277 self->set_size(mpy::to_int(size));
278 return 0;
279 PY_END(-1)
280 }
281
PyObject* Dim_getis_bound(Dim* self, void*) {
283 return PyBool_FromLong(self->is_bound());
284 }
285
PyObject* Dim_getlevel(Dim* self, void*) {
287 return PyLong_FromLong(self->level_);
288 }
289
PyObject* Dim_get_levels(Dim* self, void*) {
291 mpy::tuple t(1);
292 t.set(0, mpy::object::borrow(self->ptr()));
293 return t.release();
294 }
295
PyObject* Dim_get_has_device(Dim* self, void*) {
297 Py_RETURN_FALSE;
298 }
299
PyObject* Dim_get_tensor(Dim* self, void*) {
301 return THPVariable_Wrap(self->range());
302 }
303
PyObject* Dim_get_batchtensor(Dim* self, void*) {
305 return THPVariable_Wrap(self->batchtensor());
306 }
307
308
309 PyGetSetDef Dim_getsetters[] = {
310 {"size", (getter) Dim_getsize, (setter) Dim_setsize,
311 "Dimension size", NULL},
312 {"is_bound", (getter) Dim_getis_bound, NULL, "is_bound", NULL},
313 {"_level", (getter) Dim_getlevel, NULL, "_level", NULL},
314 {"_levels", (getter) Dim_get_levels, NULL, "_levels", NULL},
315 {"_has_device", (getter) Dim_get_has_device, NULL, "_has_device", NULL},
316 {"_tensor", (getter) Dim_get_tensor, NULL, "_tensor", NULL},
317 {"_batchtensor", (getter) Dim_get_batchtensor, NULL, "_batchtensor", NULL},
{"ndim", (getter) [](PyObject* self, void*) -> PyObject* { return mpy::from_int(1).release(); }, NULL, "ndim", NULL},
319 {NULL} /* Sentinel */
320 };
321 }
322 PyTypeObject Dim::Type = {
323 PyVarObject_HEAD_INIT(NULL, 0)
324 "_C.Dim", /* tp_name */
325 sizeof(Dim), /* tp_basicsize */
326 0, /* tp_itemsize */
327 Dim::dealloc_stub, /* tp_dealloc */
328 0, /* tp_vectorcall_offset */
329 0, /* tp_getattr */
330 0, /* tp_setattr */
331 0, /* tp_as_async */
332 (reprfunc)Dim_repr, /* tp_repr */
333 0, /* tp_as_number */
334 0, /* tp_as_sequence */
335 0, /* tp_as_mapping */
336 0, /* tp_hash */
337 0, /* tp_call */
338 0, /* tp_str */
339 0, /* tp_getattro */
340 0, /* tp_setattro */
341 0, /* tp_as_buffer */
342 Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
343 "Dim Object", /* tp_doc */
344 0, /* tp_traverse */
345 0, /* tp_clear */
346 0, /* tp_richcompare */
347 0, /* tp_weaklistoffset */
348 0, /* tp_iter */
349 0, /* tp_iternext */
350 0, /* tp_methods */
351 0, /* tp_members */
352 Dim_getsetters, /* tp_getset */
353 0, /* tp_base */
354 0, /* tp_dict */
355 0, /* tp_descr_get */
356 0, /* tp_descr_set */
357 0, /* tp_dictoffset */
358 (initproc)(void*)static_cast<int(*)(mpy::hdl<Dim>,PyObject*,PyObject*)>(Dim_init), /* tp_init */
359 0, /* tp_alloc */
360 Dim::new_stub, /* tp_new */
361 };
362
363 // class DimList ------------
364
365 struct DimList : public mpy::base<DimList> {
366 mpy::object name_;
367 std::vector<mpy::obj<Dim>> dims_;
368 static PyTypeObject Type;
void init(mpy::object name) {
370 name_ = std::move(name);
371 }
void set_dims(std::vector<mpy::obj<Dim>> dims) {
373 bound_ = true;
374 dims_ = std::move(dims);
375 }
bool is_bound() {
377 return bound_;
378 }
void bind_len(int64_t size) {
380 if (bound_) {
381 int64_t b_size = dims_.size();
382 if (b_size != size) {
383 mpy::raise_error(DimensionBindError(), "Dimlist has size %lld but it is being bound to size %d", b_size, size);
384 }
385 } else {
386 bound_ = true;
387 dims_.resize(size);
388 for (Py_ssize_t i = 0; i < size; ++i) {
389 dims_[i] = Dim::create(mpy::unicode_from_format("%S%i", name_.ptr(), (int)i));
390 }
391 }
392 }
int64_t size() const {
394 if (!bound_) {
395 mpy::raise_error(DimensionBindError(), "DimList not bound");
396 }
397 return dims_.size();
398 }
void set_bound(bool b) {
400 bound_ = b;
401 }
402 private:
403 bool bound_ = false;
404 };
405
406
407 static int DimList_init(DimList *self, PyObject *args, PyObject *kwds);
408
static PyObject* DimList_repr(DimList* self) {
410 PY_BEGIN
411 if (self->is_bound()) {
412 size_t size = self->dims_.size();
413 mpy::tuple t(size);
414 for(size_t i = 0; i < size; ++i) {
415 t.set(i, self->dims_[i]);
416 }
417 return mpy::repr(t).release();
418 } else if(!mpy::is_none(self->name_)) {
419 return mpy::unicode_from_format("*%S", self->name_.ptr()).release();
420 } else {
421 return mpy::unicode_from_string("<unbound_dimlist>").release();
422 }
423 PY_END(nullptr)
424 }
425
static PyObject* DimList_bind(DimList *self,
427 PyObject *const *args,
428 Py_ssize_t nargs,
429 PyObject *kwnames) {
430 PY_BEGIN
431 mpy::handle sizes;
432 static const char * const _keywords[] = {"sizes", nullptr};
433 static _PyArg_Parser parser = {"O", _keywords, 0};
434 if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &sizes)) {
435 return nullptr;
436 }
437 if (!mpy::is_sequence(sizes)) {
438 mpy::raise_error(PyExc_ValueError, "expected a sequence");
439 }
440 mpy::sequence_view seq = sizes;
441 auto size = seq.size();
442 self->bind_len(size);
443 for (Py_ssize_t i = 0; i < size; ++i) {
444 self->dims_[i]->set_size(mpy::to_int(seq[i]));
445 }
446 Py_RETURN_NONE;
447 PY_END(nullptr)
448 }
449
static PyObject* DimList_bind_len(DimList *self,
451 PyObject *const *args,
452 Py_ssize_t nargs,
453 PyObject *kwnames) {
454 PY_BEGIN
455 int size;
456 static const char * const _keywords[] = {"N", nullptr};
457 static _PyArg_Parser parser = {"i", _keywords, 0};
458 if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &size)) {
459 return nullptr;
460 }
461 self->bind_len(size);
462 Py_RETURN_NONE;
463 PY_END(nullptr)
464 }
465
466 static PyMethodDef DimList_methods[] = {
467 {"bind", (PyCFunction)(void*) DimList_bind, METH_FASTCALL | METH_KEYWORDS},
468 {"bind_len", (PyCFunction)(void*) DimList_bind_len, METH_FASTCALL | METH_KEYWORDS},
469 {NULL, NULL, 0, NULL} /* Sentinel */
470 };
471
472
static Py_ssize_t DimList_len(DimList* self) {
474 PY_BEGIN
475 return self->size();
476 PY_END(-1)
477 }
478
static PyObject * DimList_item(DimList* self, Py_ssize_t idx) {
480 PY_BEGIN
481 if (!self->is_bound()) {
482 mpy::raise_error(DimensionBindError(), "DimList not bound");
483 }
484 if (idx < 0 || (size_t) idx >= self->dims_.size()) {
485 mpy::raise_error(PyExc_IndexError, "index out of bounds");
486 }
487 mpy::object r = self->dims_[idx];
488 return r.release();
489 PY_END(nullptr)
490 }
491
492 PySequenceMethods DimList_seq {
493 (lenfunc) DimList_len, //lenfunc sq_length;
494 0, //binaryfunc sq_concat;
495 0, //ssizeargfunc sq_repeat;
496 (ssizeargfunc) DimList_item, //ssizeargfunc sq_item;
497 0, //void *was_sq_slice;
498 0, //ssizeobjargproc sq_ass_item;
499 0, //void *was_sq_ass_slice;
500 0, //objobjproc sq_contains;
501
502 0, //binaryfunc sq_inplace_concat;
503 0, //ssizeargfunc sq_inplace_repeat;
504 };
505
static PyObject* DimList_getis_bound(DimList* self, void*) {
507 return PyBool_FromLong(self->is_bound());
508 }
509
510 static PyGetSetDef DimList_getsetters[] = {
511 {"is_bound", (getter) DimList_getis_bound, NULL, "is_bound", NULL},
512 {NULL} /* Sentinel */
513 };
514
515
static PyObject* DimList_subscript(DimList* self, mpy::handle idx) {
517 PY_BEGIN
518 if (mpy::is_int(idx)) {
519 return DimList_item(self, mpy::to_int(idx));
520 } else if (mpy::is_slice(idx)) {
521 if (!self->is_bound()) {
522 mpy::raise_error(DimensionBindError(), "DimList not bound");
523 }
524 mpy::slice_view s(idx, self->dims_.size());
525 mpy::tuple r(s.slicelength);
526 for (Py_ssize_t i = s.start, j = 0; i < s.stop; i += s.step) {
527 r.set(j++, self->dims_[i]);
528 }
529 return r.release();
530 } else {
531 mpy::raise_error(PyExc_ValueError, "expected an int or a slice");
532 return nullptr;
533 }
534 PY_END(nullptr)
535 }
536
537 PyMappingMethods DimList_mapping = {
538 0, //lenfunc mp_length;
539 (binaryfunc)(void*) DimList_subscript, //binaryfunc mp_subscript;
540 0, //objobjargproc mp_ass_subscript;
541 };
542
543
544
545 PyTypeObject DimList::Type = {
546 PyVarObject_HEAD_INIT(NULL, 0)
547 "_C.DimList", /* tp_name */
548 sizeof(DimList), /* tp_basicsize */
549 0, /* tp_itemsize */
550 DimList::dealloc_stub, /* tp_dealloc */
551 0, /* tp_vectorcall_offset */
552 0, /* tp_getattr */
553 0, /* tp_setattr */
554 0, /* tp_as_async */
555 (reprfunc)DimList_repr, /* tp_repr */
556 0, /* tp_as_number */
557 &DimList_seq, /* tp_as_sequence */
558 &DimList_mapping, /* tp_as_mapping */
559 0, /* tp_hash */
560 0, /* tp_call */
561 0, /* tp_str */
562 0, /* tp_getattro */
563 0, /* tp_setattro */
564 0, /* tp_as_buffer */
565 0, /* tp_flags */
566 "DimList Object", /* tp_doc */
567 0, /* tp_traverse */
568 0, /* tp_clear */
569 0, /* tp_richcompare */
570 0, /* tp_weaklistoffset */
571 0, /* tp_iter */
572 0, /* tp_iternext */
573 DimList_methods, /* tp_methods */
574 0, /* tp_members */
575 DimList_getsetters, /* tp_getset */
576 0, /* tp_base */
577 0, /* tp_dict */
578 0, /* tp_descr_get */
579 0, /* tp_descr_set */
580 0, /* tp_dictoffset */
581 (initproc) DimList_init, /* tp_init */
582 0, /* tp_alloc */
583 DimList::new_stub, /* tp_new */
584 };
585
static int DimList_init(DimList *self, PyObject *args, PyObject *kwds) {
587 PY_BEGIN
588 static constexpr const char* kwlist[] = {"len_or_dims", "name", nullptr};
589 mpy::handle len_or_dims = nullptr;
590 PyObject* name = nullptr;
591 if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OO", const_cast<char**>(kwlist), &len_or_dims, &name)) {
592 return -1;
593 }
594 self->init(mpy::object::borrow(name ? name : Py_None));
595 if (len_or_dims.ptr()) {
596 if(mpy::is_int(len_or_dims)) {
597 self->bind_len(mpy::to_int(len_or_dims));
598 } else if (mpy::is_sequence(len_or_dims)) {
599 mpy::sequence_view s(len_or_dims);
600 std::vector<mpy::obj<Dim>> dims;
601 size_t size = s.size();
602 dims.reserve(size);
603 for (size_t i = 0; i < size; ++i) {
604 auto r = s[i];
605 if (mpy::is_int(r)) {
606 dims.emplace_back(Dim::create(mpy::unicode_from_format("%S%i", self->name_.ptr(), (int)i), mpy::to_int(r)));
607 } else {
608 dims.emplace_back(Dim::wrap(r));
609 }
610 }
611 self->set_dims(std::move(dims));
612 } else {
613 PyErr_Format(PyExc_ValueError, "expected a length or a sequence of dimensions");
614 return -1;
615 }
616 return 0;
617 }
618 return 0;
619 PY_END(-1);
620 }
621
622 // Tensor -----------------------------
623
624 PyTypeObject* TensorType = nullptr; // the python wrapper type.
625 mpy::object run_torch_function(Arena &A, mpy::handle orig, mpy::vector_args args, bool is_pointwise);
626
627 namespace{
628
at::Tensor _add_batch_dims(Arena& A, at::Tensor t, Slice<DimEntry> levels_) {
630 auto levels = Slice<DimEntry>();
631 levels.extend(A, levels_);
632 while (true) {
633 int64_t min_real_index = -1;
634 int64_t min_index = -1;
635 int64_t min_value = INT_MAX;
636 int64_t i = 0;
637 int64_t r = 0;
638 for (auto l : levels) {
639 if (!l.is_none()) {
640 if (!l.is_positional() && l.dim()->level_ < min_value) {
641 min_value = l.dim()->level_;
642 min_index = i;
643 min_real_index = r;
644 }
645 ++i;
646 }
647 ++r;
648 }
649 if (min_index == -1) {
650 return t;
651 }
652 auto t2 = at::functorch::addBatchDim(std::move(t), min_index, min_value);
653 t = std::move(t2);
654 levels[min_real_index] = DimEntry();
655 }
656 }
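
// Sketch of the loop above (editor's note): dims are consumed in ascending
// level_ order, so for levels {i (level 65), j (level 66)} the result is
// roughly addBatchDim(addBatchDim(t, index_of_i, 65), index_of_j, 66), with
// each consumed entry cleared to DimEntry() so it is not wrapped twice.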
657
658
659
660 struct DelayedOperator {
DelayedOperator(mpy::object o, mpy::vector_args a)
662 : orig(std::move(o)), args(a) {
663 auto all = a.size();
664 // this will outlive the call so
665 // take ownership of temporaries
666 // in vector args
667 auto buf = new mpy::handle[all];
668 memcpy(buf, args.args, sizeof(mpy::handle)*all);
669 args.args = buf;
670 for (auto i : args.enumerate_all()) {
671 Py_INCREF(args.args[i].ptr());
672 }
673 Py_XINCREF(args.kwnames.ptr());
674 }
~DelayedOperator() {
676 for (auto i : args.enumerate_all()) {
677 Py_DECREF(args[i].ptr());
678 }
679 if (args.has_keywords()) {
680 Py_XDECREF(args.kwnames.ptr());
681 }
682 delete [] args.args;
683 }
684 mpy::object orig;
685 mpy::vector_args args;
686 };
687
void free_levels_dims(Slice<DimEntry> levels) {
689 for(auto e : levels) {
690 if (!e.is_positional()) {
691 mpy::object::steal(e.dim());
692 }
693 }
694 }
695 }
696
697 struct Tensor : public mpy::base<Tensor> {
698 private:
699 at::Tensor tensor_;
700 at::Tensor batchtensor_;
701 OwnedSlice<DimEntry> levels_;
702 bool has_device_;
703 std::unique_ptr<DelayedOperator> delayed_;
704 public:
705
at::Tensor& tensor(Arena& A) {
707 if (C10_UNLIKELY(!tensor_.defined())) {
708 AT_ASSERT(delayed_);
709 auto t = Tensor::wrap(run_torch_function(A, delayed_->orig, delayed_->args, true));
710 tensor_ = t->tensor(A);
711 delayed_.reset();
// don't force creation of the batch tensor if it wasn't already provided.
713 batchtensor_ = t->batchtensor_;
714 AT_ASSERT(levels() == t->levels());
715 }
716 return tensor_;
717 }
at::Tensor& batchtensor(Arena& A) {
719 if (C10_UNLIKELY(!batchtensor_.defined())) {
720 batchtensor_ = _add_batch_dims(A, tensor(A), levels_.slice());
721 }
722 return batchtensor_;
723 }
Slice<DimEntry> levels() {
725 return levels_.slice();
726 }
bool has_device() {
728 return has_device_;
729 }
DelayedOperator* delayed() {
731 return delayed_.get();
732 }
733 static PyTypeObject Type;
734
static bool check_exact(mpy::handle v) {
736 return Py_TYPE(v.ptr()) == TensorType;
737 }
738
739
static mpy::obj<Tensor> create() {
741 if (!TensorType) {
742 TensorType = (PyTypeObject*) mpy::import("functorch.dim").attr("Tensor").ptr();
743 }
744 return Tensor::alloc(TensorType);
745 }
void capture_levels(Slice<DimEntry> levels) {
747 // grab ownership of the dims inside levels
748 for (auto l : levels) {
749 if (!l.is_positional()) {
750 mpy::object::borrow(l.dim()).release();
751 }
752 }
753 levels_.set(levels, free_levels_dims);
754 }
755 static mpy::object from_positional(Arena & A, at::Tensor tensor, Slice<DimEntry> levels, bool has_device);
756 static mpy::obj<Tensor> create_delayed(mpy::object op, mpy::vector_args args, Slice<DimEntry> levels, bool has_device);
757 friend struct EnableAllLayers;
758 };
759
760 namespace{
// the version in the header does an unnecessary refcount +/-
at::functorch::BatchedTensorImpl* maybeGetBatchedImpl(const at::Tensor& tensor) {
763 if (at::functorch::isBatchedTensor(tensor)) {
764 return static_cast<at::functorch::BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
765 }
766 return nullptr;
767 }
768
TensorRef unchecked_tensor_from(mpy::handle p) {
770 auto v = (THPVariable*) p.ptr();
771 return TensorRef(*v->cdata);
772 }
773
static int64_t ndim_of_levels(Slice<DimEntry> levels) {
775 int64_t r = 0;
776 for (auto l : levels) {
777 if (l.is_positional()) {
778 ++r;
779 }
780 }
781 return r;
782 }
783
784 struct TensorInfo {
785 TensorRef tensor;
786 Slice<DimEntry> levels;
787 bool has_device;
788 TensorRef batchedtensor;
int64_t ndim() const {
790 return ndim_of_levels(levels);
791 }
operator bool() const {
793 return tensor;
794 }
795
static TensorInfo create(Arena& A, mpy::handle h, bool ensure_batched=true, bool ensure_present=true) {
797 if (Tensor::check_exact(h)) {
798 auto t = Tensor::unchecked_wrap(h);
799 return TensorInfo {t->tensor(A), t->levels(), t->has_device(), ensure_batched ? t->batchtensor(A) : TensorRef()};
800 } else if (Dim::check_exact(h)) {
801 auto d = Dim::unchecked_wrap(h);
802 return TensorInfo {d->range(), Slice<DimEntry>(A, DimEntry(d)), false, ensure_batched ? d->batchtensor() : TensorRef()};
803 } else if (THPVariable_Check(h.ptr())) {
804 TensorRef t = unchecked_tensor_from(h);
805 Slice<DimEntry> levels;
806 for (auto i : irange(-t->dim(), 0)) {
807 levels.append(A, i);
808 }
809 return TensorInfo {t, levels, true, t};
810 } else {
811 if (ensure_present) {
812 mpy::raise_error(PyExc_ValueError, "expected a tensor object");
813 }
814 return TensorInfo {};
815 }
816 }
817
818
819 };
820
static PyObject* py_Tensor_from_positional(PyObject *self,
822 PyObject *const *args,
823 Py_ssize_t nargs,
824 PyObject *kwnames) {
825 Arena A;
826 PY_BEGIN
827 #define ARGS(_) _(mpy::handle, tensor) _(mpy::handle, py_levels) _(int, has_device)
828 MPY_PARSE_ARGS_KWNAMES("OOp", ARGS)
829 #undef ARGS
830
831 if (!THPVariable_Check(tensor.ptr())) {
832 mpy::raise_error(PyExc_ValueError, "_tensor is not a Tensor?");
833 }
834
835 Slice<DimEntry> levels;
836 mpy::sequence_view sq(py_levels);
837 for (auto i : sq.enumerate()) {
838 mpy::object v = sq[i];
839 if (mpy::is_int(v)) {
840 auto vi = mpy::to_int(v);
841 levels.append(A, vi);
842 } else {
843 auto dim = Dim::wrap(std::move(v));
844 mpy::hdl<Dim> hdim = dim;
845 levels.append(A, hdim);
846 }
847 }
848 return Tensor::from_positional(A, THPVariable_Unpack(tensor.ptr()), levels, has_device != 0).release();
849 PY_END(nullptr)
850 }
851 }
852
mpy::object Tensor::from_positional(Arena & A, at::Tensor tensor, Slice<DimEntry> levels, bool has_device) {
854 size_t seen_dims = 0;
855 int last = 0;
856 //auto sz = tensor.sizes();
857 for (auto i : levels.enumerate()) {
858 auto l = levels[i];
859 if (l.is_positional()) {
860 AT_ASSERT(last == 0 || last + 1 == l.position());
861 last = l.position();
862 } else {
863 mpy::object::borrow(l.dim()).release();
864 //AT_ASSERT(sz[i] == l.dim()->size());
865 ++seen_dims;
866 }
867 }
868 AT_ASSERT(last == 0 || last == -1);
869 if (!seen_dims) {
870 return mpy::object::steal(THPVariable_Wrap(std::move(tensor)));
871 }
872
873 mpy::obj<Tensor> self = Tensor::create();
874 self->tensor_ = std::move(tensor);
875 AT_ASSERT(self->tensor_.dim() == levels.size());
876 self->levels_.set(levels, free_levels_dims);
877 self->has_device_ = has_device;
878 mpy::object r = std::move(self);
879 return r;
880 }
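
// Example of the invariant asserted above (illustrative): a tensor of shape
// (3, 4, 5) whose middle dimension is the first-class dim i has
//   levels = {-2, i, -1}
// positional entries must appear in increasing order and end at -1; when no
// first-class dims are present the plain THPVariable is returned directly.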
881
882
mpy::obj<Tensor> Tensor::create_delayed(mpy::object op, mpy::vector_args args, Slice<DimEntry> levels, bool has_device) {
884 mpy::obj<Tensor> self = Tensor::create();
885 self->capture_levels(levels);
886 self->has_device_ = has_device;
887 self->delayed_ = std::make_unique<DelayedOperator>(std::move(op), args);
888 return self;
889 }
890
891 namespace{
mpy::list slice_to_list(Slice<mpy::handle> h) {
893 mpy::list lst(h.size());
894 for (auto i : h.enumerate()) {
895 lst.set(i, mpy::object::borrow(h[i]));
896 }
897 return lst;
898 }
899
mpy::tuple slice_to_tuple(Slice<mpy::handle> h) {
901 mpy::tuple lst(h.size());
902 for (auto i : h.enumerate()) {
903 lst.set(i, mpy::object::borrow(h[i]));
904 }
905 return lst;
906 }
907
908 enum UType {
909 U_ELEM,
910 U_TUPLE_LIKE,
911 U_DICT,
912 };
913
914 struct Unflatten {
mpy::object operator()(Slice<mpy::handle>& elements) {
916 mpy::object r;
917 switch (type) {
918 case U_ELEM: {
919 r = mpy::object::borrow(elements[0]);
920 elements = elements.slice(1);
921 } break;
922 case U_TUPLE_LIKE: {
923 mpy::tuple tup(children.size());
924 for (auto i : children.enumerate()) {
925 tup.set(i, children[i](elements));
926 }
927 r = obj.call(tup);
928 } break;
929 case U_DICT: {
930 r = mpy::object::checked_steal(PyDict_New());
931 mpy::dict_view rv(r);
932 mpy::dict_view d(obj);
933 Py_ssize_t pos = 0;
934 mpy::handle k, v;
935 for (int i = 0; d.next(&pos, &k, &v); ++i) {
936 rv.set(k, children[i](elements));
937 }
938 } break;
939 }
940 return r;
941 }
942 UType type;
943 mpy::handle obj;
944 Slice<Unflatten> children;
945 };
946
Unflatten tree_flatten(Arena& A, mpy::handle agg, Slice<mpy::handle>& flat_elements) {
948 Slice<Unflatten> c;
949 UType utype;
950 mpy::handle obj;
951 if (mpy::list_view::check(agg)) {
952 obj = agg.type();
953 utype = U_TUPLE_LIKE;
954 mpy::list_view l(agg);
955 for (auto i : l.enumerate()) {
956 c.append(A, tree_flatten(A, l[i], flat_elements));
957 }
958 } else if (mpy::tuple_view::check(agg)) {
959 obj = agg.type();
960 utype = U_TUPLE_LIKE;
961 // includes named tuples
962 mpy::tuple_view l(agg);
963 for (auto i : l.enumerate()) {
964 c.append(A, tree_flatten(A, l[i], flat_elements));
965 }
966 } else if (mpy::dict_view::check(agg)) {
967 utype = U_DICT;
968 mpy::dict_view d(agg);
969 obj = agg;
970 Py_ssize_t pos = 0;
971 mpy::handle k, v;
972 while (d.next(&pos, &k, &v)) {
973 c.append(A, tree_flatten(A, v, flat_elements));
974 }
975 } else {
976 utype = U_ELEM;
977 flat_elements.append(A, agg);
978 }
979 return Unflatten {utype, obj, c};
980 }
981
982 struct UnflattenVectorArgs {
mpy::vector_args operator()(Arena& A, Slice<mpy::handle>& elements) {
984 if (!had_nested) {
985 auto args = elements.begin();
986 elements = Slice<mpy::handle>();
987 return mpy::vector_args(args, nargs, kwnames);
988 }
989 Slice<mpy::handle> args;
990 for (auto u : children) {
991 args.append(A, A.autorelease(u(elements)));
992 }
993 return mpy::vector_args(args.begin(), nargs, kwnames);
994 }
995 Slice<Unflatten> children;
996 Py_ssize_t nargs;
997 mpy::handle kwnames;
998 bool had_nested;
999 };
1000
UnflattenVectorArgs tree_flatten(Arena& A, mpy::vector_args args, Slice<mpy::handle>& flat_elements) {
1002 UnflattenVectorArgs r;
1003 r.kwnames = args.kwnames;
1004 r.nargs = args.nargs;
1005 r.had_nested = false;
1006 auto N = args.size();
1007 for(auto i : irange(N)) {
1008 auto typ = Py_TYPE(args[i].ptr());
// fast check that this argument isn't a nested container.
1010 bool is_element = !typ->tp_as_sequence || typ == torch_Tensor || typ == TensorType || typ == DimType;
1011 if (!is_element) {
1012 flat_elements.extend(A, args.args, args.args + i);
1013 for (auto j : irange(i)) {
1014 (void)j;
1015 r.children.append(A, Unflatten {U_ELEM});
1016 }
1017 for (auto j : irange(i, N)) {
1018 r.children.append(A, tree_flatten(A, args[j], flat_elements));
1019 if (r.children.back().type != U_ELEM) {
1020 r.had_nested = true;
1021 }
1022 }
1023 return r;
1024 }
1025 }
1026 flat_elements.extend(A, args.args, args.args + N);
1027 return r;
1028 }
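
// Illustrative behavior (editor's sketch): for positional args (x, [y, z]) the
// flattened elements are {x, y, z} and had_nested is true; calling the returned
// UnflattenVectorArgs with (possibly replaced) elements rebuilds a vector_args
// whose second slot is a freshly constructed list around the y/z replacements.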
1029
1030
1031 struct UnflattenArena {
1032 Arena A;
1033 Unflatten unflatten;
1034 };
1035
PyObject* py_unflatten(PyObject *self,
1037 PyObject *const *args,
1038 Py_ssize_t nargs,
1039 PyObject *kwnames) {
1040 PY_BEGIN
1041 #define ARGS(_) _(mpy::handle, ns)
1042 MPY_PARSE_ARGS_KWNAMES("O", ARGS)
1043 #undef ARGS
1044 mpy::sequence_view sv(ns);
// because we do not have an autorelease pool yet...
1046 Arena A;
1047 Slice<mpy::handle> slice;
1048 mpy::handle Tuple = (PyObject*) &PyTuple_Type;
1049 auto inputs = Tuple.call(ns);
1050 mpy::tuple_view tv(inputs);
1051 for (auto i : tv.enumerate()) {
1052 slice.append(A, tv[i]);
1053 }
1054 auto AA = (UnflattenArena*) PyCapsule_GetPointer(self, "arena");
1055 auto r = AA->unflatten(slice).release();
1056 AT_ASSERT(r != nullptr);
1057 return r;
1058 PY_END(nullptr)
1059 }
1060
1061 PyMethodDef py_unflatten_def = {"unflatten", (PyCFunction)(void*) py_unflatten, METH_FASTCALL | METH_KEYWORDS};
1062
void free_unflatten_arena(PyObject * pc) {
1064 delete (UnflattenArena*) PyCapsule_GetPointer(pc, "arena");
1065 }
1066
PyObject* py_tree_flatten(PyObject *self,
1068 PyObject *const *args,
1069 Py_ssize_t nargs,
1070 PyObject *kwnames) {
1071 PY_BEGIN
1072 #define ARGS(_) _(mpy::handle, tree)
1073 MPY_PARSE_ARGS_KWNAMES("O", ARGS)
1074 #undef ARGS
1075 auto A = new UnflattenArena;
1076 Slice<mpy::handle> elements;
1077 A->unflatten = tree_flatten(A->A, tree, elements);
1078 auto cap = mpy::object::checked_steal(PyCapsule_New(A, "arena", free_unflatten_arena));
1079 auto unflatten = mpy::object::checked_steal(PyCFunction_New(&py_unflatten_def, cap.release()));
1080 mpy::tuple r(2);
1081 r.set(0, slice_to_list(elements));
1082 r.set(1, std::move(unflatten));
1083 return r.release();
1084 PY_END(nullptr)
1085 }
1086
1087
1088
mpy::object tree_map(Arena& A, const std::function<mpy::handle(mpy::handle)>& fn, mpy::handle agg) {
1090 Slice<mpy::handle> elements;
1091 auto unflatten = tree_flatten(A, agg, elements);
1092 for (auto i : elements.enumerate()) {
1093 elements[i] = fn(elements[i]);
1094 }
1095 return unflatten(elements);
1096 }
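
// Usage sketch (illustrative): tree_map(A, wrap, result) applies `wrap` to every
// leaf of a possibly nested tuple/list/dict and reassembles the same container
// structure around the mapped leaves; run_torch_function below uses it to
// rewrap plain tensors inside multi-output results.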
1097
1098 // prereq: isinstance(h, _Tensor)
int64_t _Tensor_ndim(mpy::handle h) {
1100 if (Tensor::check(h)) {
1101 int64_t r = 0;
1102 for (auto l : Tensor::unchecked_wrap(h)->levels()) {
1103 if (l.is_positional()) {
1104 ++r;
1105 }
1106 }
1107 return r;
1108 }
1109 // Dim or DelayedMulTensor
1110 return 0;
1111 }
1112
mpy::handle handle_from_tensor(Arena& A, TensorRef t) {
1114 // fast case: tensor is live in python
1115 std::optional<PyObject*> mb_obj =
1116 t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter(), /*ignore_hermetic_tls=*/false);
1117 if (mb_obj.has_value() && !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) {
1118 return *mb_obj;
1119 }
1120 return A.autorelease(mpy::object::checked_steal(THPVariable_Wrap(*t)));
1121 }
1122 }
1123 struct EnableAllLayers {
EnableAllLayers(Arena& A, Slice<DimEntry> levels) {
1125 std::vector<std::pair<int64_t, int64_t>> layers;
1126 layers.reserve(levels.size());
1127 for (auto l : levels) {
1128 if (!l.is_positional()) {
1129 auto d = l.dim();
1130 levels_to_dim_.append(A, d);
1131 }
1132 }
1133 std::sort(levels_to_dim_.begin(), levels_to_dim_.end(), [](mpy::hdl<Dim> lhs, mpy::hdl<Dim> rhs) { return lhs->level_ < rhs->level_;});
1134
1135 for (auto i : levels_to_dim_.enumerate()) {
1136 auto batch_size = levels_to_dim_[i]->size();
1137 auto level = at::functorch::initAndPushDynamicLayer(at::functorch::TransformType::Vmap, batch_size, at::functorch::RandomnessType::Different);
1138 if (i == 0) {
1139 levels_start_ = level;
1140 }
1141 }
1142 }
1143
~EnableAllLayers() {
1145 auto to_remove = levels_start_ + levels_to_dim_.size() - 1;
1146 for (auto i : levels_to_dim_.enumerate()) {
1147 AT_ASSERT(at::functorch::popDynamicLayerAndDeleteMetadata().layerId() == to_remove - i);
1148 }
1149 }
1150
mpy::obj<Tensor> from_batched(Arena& A, at::Tensor batchedtensor, bool has_device) {
1152 Slice<DimEntry> levels;
1153 for (auto i : irange(-batchedtensor.dim(), 0)) {
1154 levels.append(A, i);
1155 }
1156 TensorRef tensor;
1157 at::functorch::BatchedTensorImpl * impl = maybeGetBatchedImpl(batchedtensor);
1158 while(true) {
1159 auto level = impl->level();
1160 AT_ASSERT(level >= levels_start_ && level < levels_start_ + levels_to_dim_.size());
1161 mpy::hdl<Dim> dim = levels_to_dim_[level - levels_start_].ptr();
1162 levels.insert(A, impl->bdim(), dim);
1163 at::functorch::BatchedTensorImpl * nimpl = maybeGetBatchedImpl(impl->value());
1164 if (!nimpl) {
1165 tensor = impl->value();
1166 break;
1167 }
1168 impl = nimpl;
1169 }
1170
1171 mpy::obj<Tensor> self = Tensor::create();
1172 // grab ownership of the tensors
1173 self->tensor_ = *tensor;
1174 self->batchtensor_ = std::move(batchedtensor);
1175 self->has_device_ = has_device;
1176 self->capture_levels(levels);
1177 return self;
1178 }
void inplace_update_layers(TensorRef batchtensor, Slice<DimEntry> levels) {
  // XXX - requires a patch to functorch to add set_level
1181 auto impl = maybeGetBatchedImpl(*batchtensor);
1182 for (auto i : levels_to_dim_.reversed_enumerate()) {
1183 if (!impl) {
1184 break;
1185 }
1186 if (levels.contains(levels_to_dim_[i])) {
1187 impl->_unsafe_set_level(levels_start_ + i);
1188 impl = maybeGetBatchedImpl(impl->value());
1189
1190 }
1191 }
1192 }
1193 private:
1194 int64_t levels_start_{};
1195 Slice<mpy::hdl<Dim>> levels_to_dim_;
1196 };
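
// Sketch of the guard (editor's note): the constructor pushes one functorch
// vmap DynamicLayer per first-class dim in `levels` (ordered by level_),
// from_batched() converts a BatchedTensor result back into a functorch.dim
// Tensor by walking its BatchedTensorImpl level/bdim pairs, and the destructor
// pops the layers again in reverse order.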
1197
1198 namespace{
TensorRef _match_levels(Arena& A, TensorRef v, Slice<DimEntry> from_levels, Slice<DimEntry> to_levels, bool drop_levels=false) {
1200 if (from_levels == to_levels) {
1201 return v;
1202 }
1203 // drop_levels -> if a dim appears in from_levels but not to_levels, it is assumed it has stride 0.
1204 at::IntArrayRef sz = v->sizes();
1205 at::IntArrayRef sd = v->strides();
1206 AT_ASSERT(drop_levels || from_levels.size() <= to_levels.size());
1207 Slice<int64_t> nsz;
1208 Slice<int64_t> nsd;
1209 for (auto l : to_levels) {
1210 auto oidx = from_levels.index(l);
1211 if (!oidx) {
1212 nsz.append(A, l.is_positional() ? 1 : l.dim()->size());
1213 nsd.append(A, 0);
1214 } else {
1215 auto idx = *oidx;
1216 nsz.append(A, sz[idx]);
1217 nsd.append(A, sd[idx]);
1218 }
1219 }
1220 return A.autorelease(v->as_strided(at::IntArrayRef(nsz.begin(), nsz.end()), at::IntArrayRef(nsd.begin(), nsd.end()), v->storage_offset()));
1221 }
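
// Illustrative example: matching from_levels = {i, -1} against
// to_levels = {j, i, -1} returns an as_strided view of shape
// (j->size(), i->size(), original_last_size) where the stride for the missing
// dim j is 0, i.e. the data is broadcast along j without copying.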
1222 }
mpy::object run_torch_function(Arena &A, mpy::handle orig, mpy::vector_args args, bool is_pointwise) {
1224 if (!pointwise_optimize) {
1225 is_pointwise = false;
1226 }
1227 // std::cout << "__torch_function__ " << ((is_pointwise) ? "pointwise" : "functorch") << " " << orig << "\n";
1228
1229 Slice<mpy::hdl<Dim>> all_dims;
1230 Slice<mpy::handle> flat_args;
1231 auto unflatten_args = tree_flatten(A, args, flat_args);
1232 TensorRef device_holding_tensor;
1233
1234 Slice<TensorInfo> infos;
1235 Slice<DimEntry> result_levels;
1236 for (auto f : flat_args) {
1237 infos.append(A, TensorInfo::create(A, f, !is_pointwise, false));
1238 if (infos.back()) {
1239 TensorInfo& info = infos.back();
1240 AT_ASSERT(is_pointwise || info.batchedtensor);
1241 if (!device_holding_tensor && info.has_device) {
1242 device_holding_tensor = infos.back().tensor;
1243 }
1244 for (auto l : info.levels) {
1245 if (!result_levels.contains(l)) {
1246 result_levels.append(A, l);
1247 }
1248 }
1249 }
1250 }
1251
1252 if (is_pointwise) {
1253 for (auto i : flat_args.enumerate()) {
1254 if (infos[i]) {
1255 TensorRef tensor = infos[i].tensor;
1256 if (device_holding_tensor && !infos[i].has_device) {
1257 tensor = A.autorelease(tensor->to(device_holding_tensor->device()));
1258 }
1259 auto ml = _match_levels(A, tensor, infos[i].levels, result_levels);
1260 flat_args[i] = handle_from_tensor(A, std::move(ml));
1261 }
1262 }
1263
1264 Slice<mpy::handle> flat_it = flat_args;
1265 mpy::vector_args uargs = unflatten_args(A, flat_it);
1266
1267 mpy::object result = orig.call_vector(uargs);
1268
1269 // fast wrap for normal case where operator just returns a tensor.
1270 if (THPVariable_Check(result.ptr())) {
1271 return Tensor::from_positional(A, THPVariable_Unpack(result.ptr()), result_levels, device_holding_tensor);
1272 }
1273 auto wrap = [&](mpy::handle h) {
1274 if (THPVariable_Check(h.ptr())){
1275 return A.autorelease(Tensor::from_positional(A, THPVariable_Unpack(h.ptr()), result_levels, device_holding_tensor));
1276 }
1277 return h;
1278 };
1279 return tree_map(A, wrap, result);
1280 } else {
1281 // std::cout << orig << " calling functorch...\n";
1282 // std::cout << "rl: " << result_levels << "\n";
1283 EnableAllLayers guard(A, result_levels);
1284 for (auto i : flat_args.enumerate()) {
1285 if (infos[i]) {
1286 TensorRef batched = infos[i].batchedtensor;
1287 if (device_holding_tensor && !infos[i].has_device) {
1288 batched = A.autorelease(batched->to(device_holding_tensor->device()));
1289 }
1290 guard.inplace_update_layers(batched, infos[i].levels);
1291 flat_args[i] = handle_from_tensor(A, batched);
1292 }
1293 }
1294 Slice<mpy::handle> flat_it = flat_args;
1295 mpy::vector_args uargs = unflatten_args(A, flat_it);
1296 AT_ASSERT(flat_it.size() == 0);
1297 mpy::object result = orig.call_vector(uargs);
1298 auto wrap = [&](mpy::handle h) {
1299 if (THPVariable_Check(h.ptr())) {
1300 return A.autorelease(guard.from_batched(A, THPVariable_Unpack(h.ptr()), device_holding_tensor));
1301 }
1302 return h;
1303 };
1304 if (THPVariable_Check(result.ptr())) {
1305 return guard.from_batched(A, THPVariable_Unpack(result.ptr()), device_holding_tensor);
1306 }
1307 return tree_map(A, wrap, result);
1308 }
1309 }
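
// Summary sketch (editor's note): pointwise ops are handled by broadcasting all
// inputs to a shared level layout via _match_levels (stride-0 views, no
// functorch layers), while other ops run under EnableAllLayers so functorch's
// vmap machinery handles the first-class dims; both paths rewrap plain tensor
// results using the accumulated result_levels.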
1310
1311 namespace{
1312
mpy::object __torch_function__(Arena &A, mpy::handle orig, mpy::vector_args args, bool is_pointwise) {
1314 if (orig == torch_Tensor___mul__) {
1315 AT_ASSERT(args.nargs == 2 && !args.has_keywords());
1316 auto lhs = args[0];
1317 auto rhs = args[1];
1318 if (mpy::isinstance(lhs, _Tensor) && mpy::isinstance(rhs, _Tensor) && _Tensor_ndim(lhs) == 0 && _Tensor_ndim(rhs) == 0) {
1319 bool has_device = false;
1320 Slice<DimEntry> levels;
1321 for (auto i : args.enumerate_positional()) {
1322 auto t = TensorInfo::create(A, args[i], false);
1323 // something like a mask * rhs, which matrix multiplies don't correctly promote
1324 if (!t.tensor->is_floating_point()) {
1325 return run_torch_function(A, orig, args, is_pointwise);
1326 }
1327 has_device = has_device || t.has_device;
1328 for (auto l : t.levels) {
1329 if (!levels.contains(l)) {
1330 levels.append(A, l);
1331 }
1332 }
1333 }
1334 // std::cout << "__torch_function__ " << "delay" << " " << orig << "\n";
1335 return Tensor::create_delayed(mpy::object::borrow(orig), args, levels, has_device);
1336 }
1337 }
1338 return run_torch_function(A, orig, args, is_pointwise);
1339 }
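
// Example of the special case above (illustrative): for `x * y` where both
// operands are floating-point _Tensor objects with no positional dims, nothing
// is computed here; a delayed Tensor recording the multiply is returned, which
// a later reduction can turn into a single matrix product via the dot() path
// below instead of materializing the broadcasted product.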
1340
mpy::vector_args as_vector_args(Arena& A, mpy::handle args, mpy::handle kwargs) {
1342 auto pos_args = (mpy::handle*) &PyTuple_GET_ITEM(args.ptr(), 0);
1343 auto pos_n = PyTuple_GET_SIZE(args.ptr());
1344 if (!kwargs.ptr()) {
1345 return mpy::vector_args(pos_args, pos_n, nullptr);
1346 }
1347 Slice<mpy::handle> all_args;
1348 Slice<mpy::handle> kwnames;
1349 all_args.extend(A, pos_args, pos_args + pos_n);
1350 mpy::dict_view dv(kwargs);
1351 Py_ssize_t pos = 0;
1352 mpy::handle key, value;
1353 while (dv.next(&pos, &key, &value)) {
1354 all_args.append(A, value);
1355 kwnames.append(A, key);
1356 }
1357 return mpy::vector_args(all_args.begin(), pos_n, A.autorelease(slice_to_tuple(kwnames)));
1358 }
1359
PyObject* py___torch_function__(PyObject *self,
1361 PyObject *const *args,
1362 Py_ssize_t nargs,
1363 PyObject *kwnames) {
1364 Arena A;
1365 PY_BEGIN
1366 maybeInitializeGlobals();
1367 AT_ASSERT(nargs == 4 || nargs == 5);
1368 auto va = as_vector_args(A, args[3], nargs == 5 ? args[4] : nullptr);
1369 bool is_pointwise = pointwise.contains(args[1]);
1370 return __torch_function__(A, args[1], std::move(va), is_pointwise).release();
1371 PY_END(nullptr)
1372 }
1373
mpy::object levels_to_tuple(Slice<DimEntry> slice) {
1375 mpy::tuple t(slice.size());
1376 for (auto i : slice.enumerate()) {
1377 t.set(i, slice[i].is_positional() ? mpy::from_int(slice[i].position()) : mpy::object::borrow(slice[i].dim()));
1378 }
1379 mpy::object r = std::move(t);
1380 return r;
1381 }
1382
PyObject* Tensor_ndim(Tensor* self, void*) {
1384 Py_ssize_t i = 0;
1385 for (auto l : self->levels()) {
1386 if (l.is_positional()) {
1387 ++i;
1388 }
1389 }
1390 return mpy::from_int(i).release();
1391 }
1392
1393 PyGetSetDef Tensor_getsetters[] = {
{"_has_device", (getter) [](PyObject* self, void*) -> PyObject* { return mpy::from_bool(((Tensor*)self)->has_device()).release(); }, NULL},
{"_tensor", (getter) [](PyObject* self, void*) -> PyObject* {
1396 Arena A;
1397 return THPVariable_Wrap(((Tensor*)self)->tensor(A)); }, NULL},
{"_batchtensor", (getter) [](PyObject* self, void*) -> PyObject* {
1399 Arena A;
1400 return THPVariable_Wrap(((Tensor*)self)->batchtensor(A)); }, NULL},
{"_levels", (getter) [](PyObject* self, void*) -> PyObject* {
1402 PY_BEGIN
1403 return levels_to_tuple(((Tensor*)self)->levels()).release();
1404 PY_END(nullptr)
1405 }},
1406 {"ndim", (getter) Tensor_ndim, NULL, "ndim", NULL},
1407 {NULL} /* Sentinel */
1408 };
1409
1410 PyMethodDef Tensor_methods[] = {
1411 {NULL, NULL, 0, NULL} /* Sentinel */
1412 };
1413 }
1414
1415
1416 PyTypeObject Tensor::Type = {
1417 PyVarObject_HEAD_INIT(NULL, 0)
1418 "_C.Tensor", /* tp_name */
1419 sizeof(Tensor), /* tp_basicsize */
1420 0, /* tp_itemsize */
1421 Tensor::dealloc_stub, /* tp_dealloc */
1422 0, /* tp_vectorcall_offset */
1423 0, /* tp_getattr */
1424 0, /* tp_setattr */
1425 0, /* tp_as_async */
1426 0, /* tp_repr */
1427 0, /* tp_as_number */
1428 0, /* tp_as_sequence */
1429 0, /* tp_as_mapping */
1430 0, /* tp_hash */
1431 0, /* tp_call */
1432 0, /* tp_str */
1433 0, /* tp_getattro */
1434 0, /* tp_setattro */
1435 0, /* tp_as_buffer */
1436 Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE , /* tp_flags */
1437 "Tensor Object", /* tp_doc */
1438 0, /* tp_traverse */
1439 0, /* tp_clear */
1440 0, /* tp_richcompare */
1441 0, /* tp_weaklistoffset */
1442 0, /* tp_iter */
1443 0, /* tp_iternext */
1444 Tensor_methods, /* tp_methods */
1445 0, /* tp_members */
1446 Tensor_getsetters, /* tp_getset */
1447 0, /* tp_base */
1448 0, /* tp_dict */
1449 0, /* tp_descr_get */
1450 0, /* tp_descr_set */
1451 0, /* tp_dictoffset */
1452 0, /* tp_init */
1453 0, /* tp_alloc */
1454 Tensor::new_stub, /* tp_new */
1455 };
1456
1457
1458 // dim() --------------------
1459
static bool relevant_op(_Py_CODEUNIT c) {
1461 switch(c) {
1462 case STORE_NAME:
1463 case STORE_GLOBAL:
1464 case STORE_FAST:
1465 case STORE_DEREF:
1466 return true;
1467 default:
1468 return false;
1469 }
1470 }
1471
static mpy::object create_dim(mpy::object name, mpy::handle size) {
1473 auto d = Dim::create(std::move(name));
1474 if (!mpy::is_none(size)) {
1475 d->set_size(mpy::to_int(size));
1476 }
1477 return std::move(d);
1478 }
1479
static mpy::object create_dimlist(mpy::object name, mpy::handle size) {
1481 auto d = DimList::create(std::move(name));
1482 if (!mpy::is_none(size)) {
1483 if (mpy::is_int(size)) {
1484 d->bind_len(mpy::to_int(size));
1485 } else {
1486 mpy::sequence_view s(size);
1487 d->bind_len(s.size());
1488 for (auto i : irange(d->size())) {
1489 d->dims_[i]->set_size(mpy::to_int(s[i]));
1490 }
1491 }
1492 }
1493 return std::move(d);
1494 }
1495
1496
1497
1498 // Python wrappers that make new reflection primitives available for older runtimes
1499 #if !(IS_PYTHON_3_11_PLUS)
1500 #define _PyCode_CODE(CO) ((_Py_CODEUNIT*)PyBytes_AS_STRING((CO)->co_code))
1501 #endif
1502
1503 namespace{
1504 struct PyInstDecoder {
PyInstDecoder(PyCodeObject* code_object, int lasti)
1506 : code_object_(code_object), code_(_PyCode_CODE(code_object)), offset_(lasti / sizeof(_Py_CODEUNIT)) {}
1507 // On Windows, _PyOpcode_Caches and _PyOpcode_Deopt are private symbols
1508 // See https://github.com/pytorch/pytorch/issues/93854
void next() {
1510 #if IS_PYTHON_3_11_PLUS
1511 offset_ += _PyOpcode_Caches[opcode()];
1512 #endif
1513 offset_ += 1;
1514 }
int opcode() {
1516 auto r = _Py_OPCODE(code_[offset_]);
1517 #if IS_PYTHON_3_11_PLUS
1518 r = _PyOpcode_Deopt[r];
1519 #endif
1520 return r;
1521 }
int oparg() {
1523 return _Py_OPARG(code_[offset_]);
1524 }
1525
mpy::object name() {
1527 mpy::object names;
1528 switch(opcode()) {
1529 case STORE_NAME:
1530 case STORE_GLOBAL:
1531 names = mpy::object::borrow(code_object_->co_names);
1532 break;
1533 case STORE_FAST:
1534 names = mpy::object::steal(PyCode_GetVarnames(code_object_));
1535 break;
1536 case STORE_DEREF:
1537 names = mpy::object::steal(PyCode_GetCellvars(code_object_));
1538 break;
1539 default:
1540 return mpy::object();
1541 }
1542 return mpy::object::steal(PySequence_GetItem(names.ptr(), oparg()));
1543 }
1544 private:
1545 PyCodeObject* code_object_;
1546 _Py_CODEUNIT* code_;
1547 int offset_;
1548 };
1549
1550 template<mpy::object (*create_object)(mpy::object, mpy::handle)>
static PyObject* _dims(PyObject *self,
1552 PyObject *const *args,
1553 Py_ssize_t nargs,
1554 PyObject *kwnames) {
1555 PY_BEGIN
1556 Py_ssize_t specified_ndims = -1;
1557 Py_ssize_t found_ndims = 0;
1558 Py_ssize_t sizes = -1;
1559 mpy::handle n = Py_None;
1560 mpy::handle py_sizes = Py_None;
1561
1562 if (nargs || kwnames) {
1563 mpy::vector_args va(args, nargs, kwnames);
1564 va.parse("dims", {"n", "sizes"}, {&n, &py_sizes}, 0);
1565 if (!mpy::is_none(py_sizes)) {
1566 sizes = mpy::sequence_view(py_sizes).size();
1567 specified_ndims = sizes;
1568 }
1569 if (!mpy::is_none(n)) {
1570 specified_ndims = mpy::to_int(n);
1571 }
1572 }
1573
1574 PyThreadState* state = PyThreadState_GET();
1575 auto f = mpy::obj<PyFrameObject>::steal(PyThreadState_GetFrame(state));
1576 auto c = mpy::obj<PyCodeObject>::steal(PyFrame_GetCode(f.ptr()));
1577 auto lasti = PyFrame_GetLasti(f.ptr());
1578 auto decoder = PyInstDecoder(c.ptr(), lasti);
1579 #if IS_PYTHON_3_11_PLUS
// When Python 3.11 specializes ("adapts") bytecode, lasti points to the PRECALL
// instruction rather than the CALL instruction after it.
1582 if (decoder.opcode() == PRECALL) {
1583 decoder.next();
1584 }
1585 #endif
1586 decoder.next();
1587
1588 if (relevant_op(decoder.opcode())) {
1589 found_ndims = 1;
1590 } else if (decoder.opcode() == UNPACK_SEQUENCE) {
1591 found_ndims = decoder.oparg();
1592 decoder.next();
1593 }
1594
1595 if (specified_ndims == -1) {
1596 if (found_ndims == 0) {
1597 mpy::raise_error(PyExc_SyntaxError, "dims() must be assigned to a sequence of variable names or have argument n specified");
1598 }
1599 specified_ndims = found_ndims;
1600 }
1601 if (found_ndims != specified_ndims) {
1602 found_ndims = 0; // avoid taking the wrong names for dimensions
1603 }
1604
1605 auto genobject = [&](int i) -> mpy::object {
1606 mpy::object name;
1607 if (i < found_ndims) {
1608 name = decoder.name();
1609 }
1610 if (!name.ptr()) {
1611 name = mpy::unicode_from_format("d%d", i);
found_ndims = 0; // once we fail at finding a name, we can't find any more
1613 } else {
1614 decoder.next();
1615 }
1616 return create_object(std::move(name), sizes != -1 ? mpy::sequence_view(py_sizes)[i] : mpy::handle(Py_None));
1617 };
1618 if (sizes != -1 && sizes != specified_ndims) {
1619 mpy::raise_error(PyExc_ValueError, "expected %d sizes but found %d", int(specified_ndims), int(sizes));
1620 }
1621 if (specified_ndims == 1) {
1622 return genobject(0).release();
1623 }
1624 mpy::tuple result(specified_ndims);
1625 for (int i = 0; i < specified_ndims; ++i) {
1626 result.set(i, genobject(i));
1627 }
1628 return result.release();
1629 PY_END(nullptr)
1630 }
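
// Usage sketch (illustrative, Python side): because _dims() inspects the
// calling frame's bytecode,
//   i, j = dims()   # binds two dims named "i" and "j" (UNPACK_SEQUENCE of 2)
//   d = dims(n=2)   # names can't be recovered, falls back to ("d0", "d1")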
1631
1632 struct DotPart {
1633 Slice<DimEntry> dims;
1634 size_t total_size = 1;
void append(Arena& A, mpy::hdl<Dim> d) {
1636 total_size *= d->size();
1637 dims.append(A, d);
1638 }
1639 };
1640
1641 template<typename T>
static at::ArrayRef<T> as_array_ref(Slice<T> t) {
1643 return at::ArrayRef<T>(t.begin(), t.end());
1644 }
1645
static TensorRef dot_prepare(Arena& A, std::initializer_list<DotPart> parts, const TensorInfo& t) {
1647 Slice<DimEntry> new_levels;
1648 bool needs_reshape = false;
1649 for (auto p : parts) {
1650 if (p.dims.size() != 1) {
1651 needs_reshape = true;
1652 }
1653 new_levels.extend(A, p.dims);
1654 }
1655 auto r = _match_levels(A, t.tensor, t.levels, new_levels, true);
1656 if (!needs_reshape) {
1657 return r;
1658 }
1659 Slice<int64_t> view;
1660 for (auto p : parts) {
1661 view.append(A, p.total_size);
1662 }
1663 return A.autorelease(r->reshape(at::IntArrayRef(view.begin(), view.end())));
1664 }
1665
static mpy::object dot_finish(Arena& A, std::initializer_list<DotPart> parts, at::Tensor r) {
1667 Slice<DimEntry> result_levels;
1668 bool needs_reshape = false;
1669 for (auto p : parts) {
1670 if (p.dims.size() != 1) {
1671 needs_reshape = true;
1672 }
1673 result_levels.extend(A, p.dims);
1674 }
1675 if (needs_reshape) {
1676 Slice<int64_t> new_size;
1677 for (auto l : result_levels) {
1678 new_size.append(A, l.dim()->size());
1679 }
1680 r = r.reshape(at::IntArrayRef(new_size.begin(), new_size.end()));
1681 }
1682 return Tensor::from_positional(A, std::move(r), result_levels, true);
1683 }
1684
1685
1686
static mpy::object dot(Arena& A, TensorInfo lhs, TensorInfo rhs, Slice<DimEntry> sum) {
1688 auto lhs_strides = lhs.tensor->strides();
1689 auto rhs_strides = rhs.tensor->strides();
1690
1691 DotPart lro_dims;
1692 DotPart lo_dims;
1693 DotPart ro_dims;
1694 DotPart lr_dims;
1695
1696 auto insert_dim = [&] (mpy::hdl<Dim> d, std::optional<int> lhs_idx, std::optional<int> rhs_idx) {
1697 bool reduced = sum.contains(d);
1698 int64_t lhs_stride = lhs_idx ? lhs_strides[*lhs_idx] : 0;
1699 int64_t rhs_stride = rhs_idx ? rhs_strides[*rhs_idx] : 0;
1700 if (reduced) {
1701 // lr
1702 lr_dims.append(A, d);
1703 } else {
1704 if ((lhs_stride == 0) == (rhs_stride == 0)) {
1705 // lro
1706 lro_dims.append(A, d);
1707 } else if (lhs_stride != 0) {
1708 // lo
1709 lo_dims.append(A, d);
1710 } else {
1711 AT_ASSERT(rhs_stride != 0);
1712 ro_dims.append(A, d);
1713 }
1714 }
1715 };
1716
1717
1718 auto rhs_seen = A.allocate<bool>(rhs.levels.size());
1719 std::fill(rhs_seen, rhs_seen + rhs.levels.size(), false);
1720
1721 for (auto i : lhs.levels.enumerate()) {
1722 auto d = lhs.levels[i];
1723 auto rhs_idx = rhs.levels.index(d);
1724 if (rhs_idx) {
1725 rhs_seen[*rhs_idx] = true;
1726 }
1727 insert_dim(d.dim(), i, rhs_idx);
1728 }
1729
1730 for (auto i : rhs.levels.enumerate()) {
1731 if (rhs_seen[i]) {
1732 continue;
1733 }
1734 auto d = rhs.levels[i];
1735 insert_dim(d.dim(), std::nullopt, i);
1736 }
1737
1738 if (lr_dims.dims.size() != sum.size()) {
1739 for (auto & d : sum) {
1740 if (!lhs.levels.contains(d) && !rhs.levels.contains(d)) {
mpy::raise_error(DimensionBindError(), "summing over non-existent dimension %S", d.dim().ptr());
1742 }
1743 }
1744 }
1745
1746 // std::cout << lhs.levels << " " << rhs.levels << " " << sum << "\n";
1747 // std::cout << lro_dims.dims << " " << lo_dims.dims << " " << ro_dims.dims << " " << lr_dims.dims << "\n";
1748
1749 // no batch, just call mm
1750 if (lro_dims.dims.size() != 0) {
1751 auto lhs_ = dot_prepare(A, {lro_dims, lo_dims, lr_dims}, lhs);
1752 auto rhs_ = dot_prepare(A, {lro_dims, lr_dims, ro_dims}, rhs);
1753 return dot_finish(A, {lro_dims, lo_dims, ro_dims}, at::bmm(*lhs_, *rhs_));
1754 } else {
1755 auto lhs_ = dot_prepare(A, {lo_dims, lr_dims}, lhs);
1756 auto rhs_ = dot_prepare(A, {lr_dims, ro_dims}, rhs);
1757 return dot_finish(A, {lo_dims, ro_dims}, at::mm(*lhs_, *rhs_));
1758 }
1759
1760 }
1761
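// _test_c: exercises the Arena-backed Slice container (construction, append, slicing,
// and insert) from Python; used only for testing.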
static PyObject* test_c(PyObject *self,
1763 PyObject *const *args,
1764 Py_ssize_t nargs,
1765 PyObject *kwnames) {
1766 PY_BEGIN
1767
1768 Arena A;
1769 Slice<int> s(A, 3, 4, 5);
1770 AT_ASSERT(s.size() == 3 && s.capacity() == 8);
1771 AT_ASSERT(s[0] == 3 && s[1] == 4 && s[2] == 5);
1772 s.append(A, 6);
1773 AT_ASSERT(s[3] == 6);
1774 for(int i : irange(10)) {
1775 s.append(A, i);
1776 }
1777 AT_ASSERT(s[0] == 3 && s.back() == 9 && s.size() == 14 && s.capacity() == 16);
1778
1779 Slice<int> s2(A, -1, -2, -3);
1780 AT_ASSERT(s2[1] == -2 && s[0] == 3);
1781
1782 auto ss = s.slice(1,2);
1783 AT_ASSERT(ss.size() == 1);
1784 AT_ASSERT(ss[0] == 4);
1785 AT_ASSERT(ss.capacity() == 1);
1786 ss.append(A, -4);
1787 AT_ASSERT(ss.size() == 2 && ss[1] == -4);
1788 ss[0] = 3;
1789 AT_ASSERT(s[1] == 4);
1790
1791 s.insert(A, s.slice(1, 4), ss);
1792 AT_ASSERT(s[1] == 3 && s[2] == -4 && s[3] == 0);
1793
1794 auto sz = s.size();
1795 s.insert(A, s.slice(1, 1), 4);
1796 AT_ASSERT(s[1] == 4 && sz + 1 == s.size());
1797
1798
1799 Slice<int> d(A, 0, 1, 2, 3, 4);
1800
1801 Slice<int> b(A, 0, 1, 2, 3, 4);
1802 b.insert(A, b.slice(1,1), d);
1803 AT_ASSERT(b.size() == 10);
1804 AT_ASSERT(b[1] == 0);
1805 AT_ASSERT(b[5] == 4);
1806 AT_ASSERT(b.back() == 4);
1807
1808 Py_RETURN_NONE;
1809
1810 PY_END(nullptr);
1811 }
1812
1813
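// order(tensor_or_dim, *dims): moves the given dims (Dims, positional ints, or packs of
// them, which get flattened) into positional dimensions, inserting them where the first
// existing positional dimension was.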
static PyObject* order(PyObject *_,
1815 PyObject *const *args,
1816 Py_ssize_t nargs,
1817 PyObject *kwnames) {
1818 Arena A;
1819 PY_BEGIN
1820 if (kwnames) {
1821 mpy::raise_error(PyExc_TypeError, "unexpected keyword arguments %S", kwnames);
1822 }
1823 AT_ASSERT(nargs-- > 0);
1824 Slice<DimEntry> orig_levels;
1825 Slice<DimEntry> levels;
1826 TensorRef data;
1827 mpy::handle self = args++[0];
1828 bool has_device;
1829 if (Tensor::check_exact(self)) {
1830 auto t = Tensor::unchecked_wrap(self);
1831 orig_levels = t->levels();
1832 data = t->tensor(A);
1833 has_device = t->has_device();
1834 } else {
1835 auto d = Dim::unchecked_wrap(self);
1836 orig_levels.append(A, d);
1837 data = d->range();
1838 has_device = false;
1839 }
1840
1841 Slice<DimEntry> flat_positional_dims;
1842 Slice<std::pair<int, int>> to_flatten;
1843 levels.extend(A, orig_levels);
1844
1845 int orig_ndim = ndim_of_levels(levels);
1846 auto append = [&](DimEntry d) {
1847 auto midx = levels.index(d);
1848 if (!midx) {
1849 if (d.is_positional()) {
1850 mpy::raise_error(PyExc_ValueError, "tensor has %d positional dimensions, but %d specified, or it was specified twice", int(orig_ndim), int(d.position() + orig_ndim));
1851 } else {
1852 mpy::raise_error(PyExc_ValueError, "tensor of dimensions %R does not contain dim %R or it was specified twice", levels_to_tuple(orig_levels).ptr(), d.dim().ptr());
1853 }
1854 }
1855 levels[*midx] = DimEntry();
1856 flat_positional_dims.append(A, d);
1857 };
1858
1859 int n_new_positional = 0;
1860 for (auto i :irange(nargs)) {
1861 mpy::handle arg = args[i];
1862 DimEntry entry = _wrap_dim(arg, orig_ndim, false);
1863 if (!entry.is_none()) {
1864 append(entry);
1865 ++n_new_positional;
1866 } else if (DimList::check(arg)) {
1867 auto dl = DimList::unchecked_wrap(arg);
1868 for (mpy::obj<Dim> & d : dl->dims_) {
1869 append(mpy::hdl<Dim>(d));
1870 ++n_new_positional;
1871 }
1872 } else {
1873 ++n_new_positional;
1874 if (!mpy::is_sequence(arg)) {
1875 mpy::raise_error(PyExc_ValueError, "expected a Dim, List[Dim], or Sequence[Dim]");
1876 }
1877 mpy::sequence_view sq(arg);
1878 auto N = sq.size();
1879 to_flatten.append(A, std::make_pair(flat_positional_dims.size(), N));
1880 for (auto j : irange(N)) {
1881 DimEntry e = _wrap_dim(A.autorelease(sq[j]), orig_ndim, false);
1882 if (e.is_none()) {
1883 mpy::raise_error(PyExc_ValueError, "expected a Dim, or int");
1884 }
1885 append(e);
1886 }
1887 }
1888 }
1889
1890 int ndim = 0;
1891 int insert_point = -1;
1892 Slice<DimEntry> new_levels;
1893 for (auto l : levels) {
1894 if (l.is_none()) {
1895 continue;
1896 }
1897 if (l.is_positional()) {
1898 ndim++;
1899 if (insert_point == -1) {
1900 insert_point = new_levels.size();
1901 new_levels.extend(A, flat_positional_dims);
1902 }
1903 }
1904 new_levels.append(A, l);
1905 }
1906 if (insert_point == -1) {
1907 insert_point = new_levels.size();
1908 new_levels.extend(A, flat_positional_dims);
1909 }
1910
1911 at::Tensor ndata = *_match_levels(A, data, orig_levels, new_levels);
1912
1913 if (to_flatten.size()) {
1914 Slice<int64_t> view;
1915 auto sz = ndata.sizes();
1916 // before the new positional dims
1917 for (auto i : irange(0, insert_point)) {
1918 view.append(A, sz[i]);
1919 }
1920 int i = 0;
1921 for (auto to_flat : to_flatten) {
1922 for (;i < to_flat.first; ++i) {
1923 view.append(A, sz[insert_point + i]);
1924 }
1925 int64_t new_size = 1;
1926 int last = i + to_flat.second;
1927 for (; i < last; ++i) {
1928 new_size *= sz[insert_point + i];
1929 }
1930 view.append(A, new_size);
1931 }
1932 for (; i < flat_positional_dims.size(); ++i) {
1933 view.append(A, sz[insert_point + i]);
1934 }
1935 // after the new positional dims
1936 for (auto i : irange(insert_point + flat_positional_dims.size(), levels.size())) {
1937 view.append(A, sz[i]);
1938 }
// we shortened the number of dimensions, so remove the collapsed entries from new_levels;
// we will renumber them later
1941 auto n_to_remove = flat_positional_dims.size() - n_new_positional;
1942 new_levels.insert(A, new_levels.slice(insert_point, insert_point + n_to_remove), Slice<DimEntry>());
1943 ndata = std::move(ndata).reshape(at::IntArrayRef(view.begin(), view.end()));
1944 }
1945
1946 // renumber the positional dimension
1947 int seen = 0;
1948 for (auto i : new_levels.reversed_enumerate()) {
1949 if (new_levels[i].is_positional() || (i >= insert_point && i < insert_point + n_new_positional)) {
1950 new_levels[i] = --seen;
1951 }
1952 }
1953 return Tensor::from_positional(A, std::move(ndata), new_levels, has_device).release();
1954
1955 PY_END(nullptr)
1956 }
1957
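// expand(tensor, *dims): when all the extra arguments are first-class Dims, adds them as
// new leading dimensions with stride 0 via as_strided; otherwise defers to the regular
// Tensor.expand (possibly through __torch_function__).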
static PyObject* expand(PyObject *_,
1959 PyObject *const *args,
1960 Py_ssize_t nargs,
1961 PyObject *kwnames) {
1962 Arena A;
1963 PY_BEGIN
1964 AT_ASSERT(nargs-- > 0);
1965 auto info = TensorInfo::create(A, args++[0], false);
1966 for (auto i : irange(nargs)) {
1967 if (!Dim::check(args[i])) {
1968 maybeInitializeGlobals();
1969 mpy::vector_args vargs(args - 1, nargs + 1, kwnames);
1970 if (THPVariable_Check(args[-1])) {
1971 return torch_Tensor_expand.call_vector(vargs).release();
1972 } else {
1973 return __torch_function__(A, torch_Tensor_expand, vargs, false).release();
1974 }
1975 }
1976 }
1977 const at::Tensor& data = *info.tensor;
1978 auto levels = info.levels;
1979 Slice<DimEntry> new_levels;
1980 Slice<int64_t> sz;
1981 Slice<int64_t> sd;
1982 for (auto i : irange(nargs)) {
1983 auto d = Dim::unchecked_wrap(args[i]);
1984 if (levels.contains(d) || new_levels.contains(d)) {
1985 mpy::raise_error(DimensionBindError(), "expanding dimension %R already exists in tensor with dims", d.ptr());
1986 }
1987 new_levels.append(A, d);
1988 sz.append(A, d->size());
1989 sd.append(A, 0);
1990 }
1991 new_levels.extend(A, levels);
1992 at::IntArrayRef osz = data.sizes();
1993 at::IntArrayRef osd = data.strides();
1994 sz.extend(A, osz.begin(), osz.end());
1995 sd.extend(A, osd.begin(), osd.end());
1996 at::Tensor ndata = data.as_strided(at::IntArrayRef(sz.begin(), sz.end()), at::IntArrayRef(sd.begin(), sd.end()), data.storage_offset());
1997 return Tensor::from_positional(A, std::move(ndata), new_levels, info.has_device).release();
1998 PY_END(nullptr)
1999 }
2000
2001
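// Bind a dim pack to a single positional dimension of size `sz` and stride `sd`:
// at most one unbound dim may have its size inferred, and the computed per-dim
// sizes/strides are appended to nsz/nsd.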
static void _bind_dims_to_size(Arena & A, int64_t sz, int64_t sd,
2003 Slice<mpy::hdl<Dim>> dims, Slice<int64_t>& nsz, Slice<int64_t>& nsd) {
2004 int64_t rhs_prod = 1;
2005 for (auto i : dims.enumerate()) {
2006 if (!dims[i]->is_bound()) {
2007 for (auto j : irange(i + 1, dims.size())) {
2008 if (!dims[j]->is_bound()) {
2009 mpy::raise_error(DimensionBindError(), "cannot infer the sizes of two dimensions at once %R and %R", dims[i].ptr(), dims[j].ptr());
2010 }
2011 rhs_prod *= dims[j]->size();
2012 }
2013 if (sz % rhs_prod != 0) {
2014 mpy::tuple tup(dims.size());
2015 for (auto j : dims.enumerate()) {
2016 tup.set(j, dims[j]->is_bound() ? mpy::from_int(dims[j]->size()) : mpy::unicode_from_string("?"));
2017 }
2018 mpy::raise_error(DimensionBindError(), "inferred dimension does not evenly fit into larger dimension: %d vs %R", (int) sz, tup.ptr());
2019 }
2020 int64_t inferred_size = sz / rhs_prod;
2021 dims[i]->set_size(inferred_size);
2022 rhs_prod = sz;
2023 break;
2024 }
2025 rhs_prod *= dims[i]->size();
2026 }
2027 if (rhs_prod != sz) {
2028 mpy::tuple tup(dims.size());
2029 for (auto j : dims.enumerate()) {
2030 tup.set(j, mpy::object::borrow(dims[j]));
2031 }
mpy::raise_error(DimensionBindError(), "Dimension sizes do not match (%d != %d) when matching dimension pack %R", (int) sz, (int) rhs_prod, tup.ptr());
2033 }
2034 auto new_strides = A.allocate<int64_t>(dims.size());
2035 auto prev_stride = sd;
2036 for (auto i : dims.reversed_enumerate()) {
2037 new_strides[i] = prev_stride;
2038 prev_stride = dims[i]->size()*prev_stride;
2039 }
2040 for (auto i : dims.enumerate()) {
2041 nsd.append(A, new_strides[i]);
2042 nsz.append(A, dims[i]->size());
2043 }
2044 }
2045
static bool has_dims(mpy::handle d) {
2047 return Dim::check_exact(d) || Tensor::check_exact(d);
2048 }
2049
2050 struct IndexingInfo {
bool can_call_original; // if true, then it is safe to just call getitem or setitem; these objects do not need special handling
2052 bool advanced_indexing; // requires actual lookup
2053 TensorRef self;
2054 Slice<mpy::handle> flat_inputs;
2055 Slice<DimEntry> result_levels;
2056 bool has_device;
2057 };
2058 }
2059
2060 IndexingInfo getsetitem_flat(Arena& A, TensorInfo self_info, Slice<mpy::handle> input, Slice<DimEntry> keys, Slice<mpy::handle> values, bool has_dimpacks_or_none);
2061 namespace{
Slice<mpy::handle> as_slice(mpy::tuple_view tv) {
2063 PyObject** begin = &PyTuple_GET_ITEM(tv.ptr(),0);
2064 return Slice<mpy::handle>((mpy::handle*)begin, (mpy::handle*) (begin + tv.size()));
2065 }
2066
Slice<mpy::handle> as_slice(mpy::list_view tv) {
2068 PyObject** begin = &PyList_GET_ITEM(tv.ptr(),0);
2069 return Slice<mpy::handle>((mpy::handle*)begin, (mpy::handle*) (begin + tv.size()));
2070 }
2071
2072
bool maybe_dimpack(Slice<mpy::handle>& elements, mpy::handle s, bool check_first=true) {
2074 // can we avoid rechecking?
2075 if (mpy::list_view::check(s)) {
2076 mpy::list_view tv(s);
2077 if (!check_first || (tv.size() && Dim::check_exact(tv[0]))) {
2078 elements = as_slice(tv);
2079 return true;
2080 }
2081 }
2082 // can we avoid rechecking?
2083 if (mpy::tuple_view::check(s)) {
2084 mpy::tuple_view tv(s);
2085 if (!check_first || (tv.size() && Dim::check_exact(tv[0]))) {
2086 elements = as_slice(tv);
2087 return true;
2088 }
2089 }
2090 return false;
2091 };
2092
bool is_dimpack(mpy::handle s) {
2094 Slice<mpy::handle> e;
2095 return maybe_dimpack(e, s);
2096 }
2097
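// Run the actual indexing: for advanced indexing call the original THPVariable_getitem
// with the flattened inputs, otherwise reuse self directly, then attach the computed
// result levels.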
mpy::object invoke_getitem(Arena& A, const IndexingInfo& iinfo) {
2099 at::Tensor rtensor;
2100 if (iinfo.advanced_indexing) {
2101 auto self_hdl = handle_from_tensor(A, iinfo.self);
2102 auto tup = slice_to_tuple(iinfo.flat_inputs);
2103 // std::cout << "calling original getindex " << self_hdl << " " << tup << "\n";
2104 auto pytensor = mpy::object::checked_steal(THPVariable_getitem(self_hdl.ptr(), tup.ptr()));
2105 rtensor = THPVariable_Unpack(pytensor.ptr());
2106 } else {
2107 // std::cout << "skipping original getindex\n";
2108 rtensor = *iinfo.self;
2109 }
2110 // std::cout << "returning (from_positional)\n";
2111 return Tensor::from_positional(A, std::move(rtensor), iinfo.result_levels, iinfo.has_device);
2112 }
2113
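// index(self, dims, indices): normalizes dims/indices into parallel lists, flattens any
// group of dims being indexed together into a single level, and hands off to
// getsetitem_flat.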
mpy::object index(Arena& A, mpy::handle self, mpy::handle dims, mpy::handle indices) {
2115 maybeInitializeGlobals();
2116 Slice<mpy::handle> dims_list;
2117 Slice<mpy::handle> indices_list;
2118 // we allow for matching single dims to multiple dims,
// so we first have to normalize everything into the case where both the lhs and the rhs are lists
2120 bool lhs_list = mpy::tuple_view::check(dims) || mpy::list_view::check(dims);
2121 bool rhs_list = mpy::tuple_view::check(indices) || mpy::list_view::check(indices);
2122 if (lhs_list && rhs_list) {
2123 mpy::sequence_view dv(dims);
2124 mpy::sequence_view ind(indices);
2125 Py_ssize_t N = dv.size();
2126 if (N != ind.size()) {
2127 mpy::raise_error(PyExc_TypeError, "dims (%d) and indices (%d) must have the same length", int(N), int(ind.size()));
2128 }
2129 for (auto i : irange(N)) {
2130 dims_list.append(A, A.autorelease(dv[i]));
2131 indices_list.append(A, A.autorelease(ind[i]));
2132 }
2133 } else {
2134 dims_list.append(A, dims);
2135 indices_list.append(A, indices);
2136 }
2137
2138 // dims being indexed can be grouped together into a single index space, and we have to
// flatten them into a single dimension before we can index them...
2140 auto self_info = TensorInfo::create(A, self, false);
2141 auto ndim = self_info.ndim();
2142 Slice<DimEntry> new_levels;
2143 Slice<DimEntry> to_flatten;
2144 Slice<DimEntry> dims_list_flat;
2145 auto parse_dim_entry = [&](mpy::handle s) -> DimEntry {
2146 auto d = _wrap_dim(s, ndim, false);
2147 if (d.is_none()) {
mpy::raise_error(PyExc_TypeError, "expected a dimension specifier but found %R", s.ptr());
2149 }
2150 return d;
2151 };
2152 auto dim_not_present = [&](DimEntry d) {
2153 if (d.is_positional()) {
2154 mpy::raise_error(PyExc_TypeError, "dimension %d not in tensor of %d dimensions", d.position() + ndim , ndim);
2155 } else {
2156 mpy::raise_error(PyExc_TypeError, "dimension %R not in tensor", d.dim()->ptr());
2157 }
2158 };
2159
2160 for (auto i : dims_list.enumerate()) {
2161 Slice<mpy::handle> m;
2162 if (maybe_dimpack(m, dims_list[i], /*check_first=*/false)) {
2163 if (m.size() == 0) {
2164 // plausible semantics work for this to have 0 elements (e.g. the index will always be 0)
dims_list_flat.append(A, DimEntry()); // value is just dropped
continue; // an empty dim pack has nothing further to parse
}
2167 auto first = parse_dim_entry(m[0]);
2168 dims_list_flat.append(A, first);
2169 if (m.size() == 1) {
2170 continue;
2171 }
2172 if (to_flatten.size() == 0) {
2173 new_levels.extend(A, self_info.levels);
2174 }
2175 Slice<DimEntry> rest;
2176 for (auto i : irange(1, m.size())) {
2177 auto d = parse_dim_entry(m[i]);
2178 if (!new_levels.remove(A, d)) {
2179 dim_not_present(d);
2180 }
2181 rest.append(A, d);
2182 }
2183
2184 auto first_idx = new_levels.index(first);
2185 if (!first_idx) {
2186 dim_not_present(first);
2187 }
2188 new_levels.insert(A, new_levels.slice(*first_idx + 1, *first_idx + 1), rest);
2189 to_flatten.extend(A, rest);
2190 } else {
2191 dims_list_flat.append(A, parse_dim_entry(dims_list[i]));
2192 }
2193 }
2194 if (to_flatten.size() > 0) {
2195 TensorRef rearranged = _match_levels(A, self_info.tensor, self_info.levels, new_levels);
2196 at::IntArrayRef sizes = rearranged->sizes();
2197 Slice<int64_t> new_sizes;
2198 Slice<DimEntry> reshape_levels;
2199 for (auto i : new_levels.enumerate()) {
2200 if (to_flatten.contains(new_levels[i])) {
2201 new_sizes.back() *= sizes[i];
2202 } else {
2203 new_sizes.append(A, sizes[i]);
2204 reshape_levels.append(A, new_levels[i]);
2205 }
2206 }
2207 self_info.tensor = A.autorelease(rearranged->reshape(at::IntArrayRef(new_sizes.begin(), new_sizes.end())));
2208
2209 self_info.levels = reshape_levels; // note: we are using the first level in a flattened group to represent the group for the rest of the op
// we need to be careful not to rely on the dimension's size because it doesn't match the size of the whole group
2211 }
2212 bool has_dimpacks = false;
2213 for (auto idx : indices_list) {
2214 if (mpy::tuple_view::check(idx) || mpy::list_view::check(idx)) {
2215 has_dimpacks = true;
2216 break;
2217 }
2218 }
2219 IndexingInfo info = getsetitem_flat(A, self_info, Slice<mpy::handle>(), dims_list_flat, indices_list, has_dimpacks);
2220 return invoke_getitem(A, info);
2221 }
2222
// true -- the indices were flattened out of a tuple, list or sequence...
2224
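// Materialize a tuple, list, or other sequence of indices as a Slice of handles.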
Slice<mpy::handle> slice_from_sequence(Arena& A, mpy::handle value) {
2226 if (mpy::tuple_view::check(value)) {
2227 return as_slice(mpy::tuple_view(value));
2228 } else if (mpy::list_view::check(value)) {
2229 return as_slice(mpy::list_view(value));
2230 } else {
2231 mpy::sequence_view sv(value);
2232 Slice<mpy::handle> r;
2233 for (auto i : sv.enumerate()) {
2234 r.append(A, A.autorelease(sv[i]));
2235 }
2236 return r;
2237 }
2238 }
2239
bool extractIndices(Arena& A, mpy::handle index, Slice<mpy::handle>& indices) {
2241 if (mpy::tuple_view::check(index)) {
2242 indices.extend(A, as_slice(mpy::tuple_view(index)));
2243 return true;
2244 } else if (THPVariable_Check(index.ptr())) {
2245 indices.append(A, index);
2246 return false;
2247 } else if (!mpy::is_sequence(index)) {
2248 indices.append(A, index);
2249 return false;
2250 }
2251 // a copy of treatSequenceAsTuple modified to add Dim and our wrapped tensors..
2252 mpy::sequence_view sv(index);
2253 if (sv.size() >= 32) {
2254 indices.extend(A, slice_from_sequence(A, index));
2255 return true;
2256 }
2257 for (auto i : sv.enumerate()) {
2258 mpy::handle item;
2259 try {
2260 item = sv[i];
2261 } catch (mpy::exception_set & e) {
2262 PyErr_Clear();
2263 indices.append(A, index);
2264 return false;
2265 }
2266 if (THPVariable_Check(item.ptr()) || mpy::is_sequence(item) || PySlice_Check(item.ptr()) || item.ptr() == Py_Ellipsis || mpy::is_none(item) || has_dims(item)) {
2267 indices.extend(A, slice_from_sequence(A, index));
2268 return true;
2269 }
2270 }
2271 indices.append(A, index);
2272 return false;
2273 }
2274
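// Shared front end for __getitem__/__setitem__: decides whether the original tensor
// indexing can be used, expands `...` and unbound dim lists into explicit slices/dims,
// then calls getsetitem_flat.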
IndexingInfo getsetitem(Arena & A, mpy::handle self, mpy::handle index, bool tensors_have_dims) {
2276 bool can_call_original_getitem = !tensors_have_dims;
2277
2278 Slice<mpy::handle> input;
2279 if (has_dims(index)) {
2280 input.append(A, index);
2281 } else {
2282 bool is_sequence = extractIndices(A, index, input);
2283 // nothing about first class dims here, fallback to getitem
2284 if (can_call_original_getitem && !is_sequence) {
2285 return { true };
2286 }
2287 }
2288
2289 int64_t dims_indexed = 0;
2290 int64_t expanding_object = -1;
2291 DimList* unbound_dim_list = nullptr;
2292 auto check_expanding = [&](int64_t i) {
2293 if (expanding_object != -1) {
2294 mpy::raise_error(DimensionBindError(), "at most one ... or unbound dimension list can exist in indexing list but found 2 at offsets %d and %d", (int) expanding_object, (int) i);
2295 }
2296 expanding_object = i;
2297 };
2298 Slice<int64_t> dimlists;
2299
// calculate how many dimensions have been indexed in order to compute the size of ...
2301 // or expand a potentially unbound dimension list.
2302
2303 bool has_dimpacks_or_none = false;
2304 for (auto i : input.enumerate()) {
2305 mpy::handle s = input[i];
2306 if (Dim::check_exact(s) || Tensor::check_exact(s)) {
2307 can_call_original_getitem = false;
2308 ++dims_indexed;
2309 } else if (s.ptr() == Py_Ellipsis) {
2310 check_expanding(i);
2311 } else if (DimList::check(s)) {
2312 can_call_original_getitem = false;
2313 auto dl = DimList::unchecked_wrap(s);
2314 if (!dl->is_bound()) {
2315 check_expanding(i);
2316 unbound_dim_list = dl.ptr();
2317 } else {
2318 dims_indexed += dl->dims_.size();
2319 }
2320 dimlists.append(A, i);
2321 } else if (mpy::is_none(s)) {
2322 has_dimpacks_or_none = true;
2323 } else if (is_dimpack(s)) {
2324 can_call_original_getitem = false;
2325 has_dimpacks_or_none = true;
2326 ++dims_indexed;
2327 } else {
2328 ++dims_indexed;
2329 }
2330 }
2331
2332 // at this point if we haven't seen any Dim objects, we also can fallback to the original getitem.
2333 if (can_call_original_getitem) {
2334 return {true};
2335 }
2336
2337 // std::cout << "__getitem__ " << self << " " << index << "\n";
2338
2339 TensorInfo self_info = TensorInfo::create(A, self, false, true);
2340 auto ndim = self_info.ndim();
2341 if (dims_indexed > ndim) {
2342 mpy::raise_error(PyExc_ValueError, "at least %d indices were supplied but the tensor only has %d dimensions", (int) dims_indexed, (int) ndim);
2343 }
2344 // expand any unbound dimension list, or expand ... into individual : slices.
2345 auto expanding_dims = ndim - dims_indexed;
2346 if (expanding_object != -1) {
2347 if (unbound_dim_list) {
2348 unbound_dim_list->bind_len(expanding_dims);
2349 } else {
2350 // ...
2351 Slice<mpy::handle> no_slices;
2352 for (auto i : irange(expanding_dims)) {
2353 (void) i;
2354 no_slices.append(A, no_slice);
2355 }
2356 input.insert(A, input.slice(expanding_object, expanding_object + 1), no_slices);
2357 }
2358 }
2359
2360 // flatten out any dimensions stored in dimlist elements directly into the inputs
2361 // std::cout << dimlists << " <- dim lists!\n";
2362 for (int64_t i = dimlists.size() - 1; i >=0; --i) {
2363 auto idx = dimlists[i];
2364 // we added more elements to input because of ...
2365 // so we need to also adjust the index to get back to where the
2366 // dimlist existed
2367 if (!unbound_dim_list && expanding_object != -1 && idx > expanding_object) {
2368 idx += expanding_dims;
2369 }
2370 auto dl = DimList::unchecked_wrap(input[idx]);
2371 // XXX would be better if we used an OwnedSlice in DimList
2372 Slice<mpy::handle> more_dims((mpy::handle*) &*dl->dims_.begin(), (mpy::handle*) &*dl->dims_.end());
2373 input.insert(A, input.slice(idx, idx + 1), more_dims);
2374 }
2375
2376 return getsetitem_flat(A, self_info, input, Slice<DimEntry>(), Slice<mpy::handle>(), has_dimpacks_or_none);
2377 }
2378 }
IndexingInfo getsetitem_flat(Arena& A, TensorInfo self_info, Slice<mpy::handle> input, Slice<DimEntry> keys, Slice<mpy::handle> values, bool has_dimpacks_or_none) {
2380 // At this point:
2381 // ..., DimList have been eliminated
2382 // Dim, Tensor, Tuple[Dim,...], int, slice still remain
2383
2384
2385 // we have to count how many times we see a dimension.
2386 // A[i,j] is a simple binding operation, but A[i, i+j] or A[i, i] requires advanced indexing.
2387 Slice<mpy::hdl<Dim>> seen_dims;
2388 Slice<int64_t> seen_dims_nuses;
2389 auto add_dim = [&](mpy::hdl<Dim> entry) {
2390 auto midx = seen_dims.index(entry);
2391 if (!midx) {
2392 seen_dims.append(A, entry);
2393 seen_dims_nuses.append(A, 1);
2394 } else {
2395 ++seen_dims_nuses[*midx];
2396 }
2397 };
2398
2399 Slice<mpy::handle> input_it = input;
2400
2401 Slice<mpy::handle> flat_inputs;
2402 // flat inputs will start with an empty mpy::handle if the
2403 // actual value is in the tensor-like object in the tensor info
2404 Slice<TensorInfo> tensor_inputs;
2405
2406 auto append_flat_handle = [&](mpy::handle h) {
2407 flat_inputs.append(A, h);
2408 tensor_inputs.append(A, TensorInfo());
2409 };
2410 TensorRef device_holding_tensor;
2411 auto append_tensor_input = [&](TensorInfo ti) {
2412 flat_inputs.append(A, mpy::handle());
2413 tensor_inputs.append(A, ti);
2414 if (ti.has_device && !device_holding_tensor) {
2415 device_holding_tensor = ti.tensor;
2416 }
2417 };
2418
2419 Slice<int64_t> nsz;
2420 Slice<int64_t> nsd;
2421 at::IntArrayRef sz = self_info.tensor->sizes();
2422 at::IntArrayRef sd = self_info.tensor->strides();
2423
2424 auto append_size = [&](int i) {
2425 if (has_dimpacks_or_none) {
2426 nsz.append(A, sz[i]);
2427 nsd.append(A, sd[i]);
2428 }
2429 };
2430 // std::cout << "self levels: " << self_info.levels << "\n";
2431
2432 auto parse_nones = [&]() {
2433 while (input_it.size() && mpy::is_none(input_it[0])) {
2434 append_flat_handle(no_slice);
2435 nsz.append(A, 1);
2436 nsd.append(A, 0);
2437 input_it = input_it.slice(1);
2438 }
2439 };
2440
2441
2442 auto append_item = [&](int i, mpy::handle arg) {
2443 if (Dim::check_exact(arg)) {
2444 auto d = Dim::unchecked_wrap(arg);
2445 d->set_size(sz[i]);
2446 add_dim(d);
2447 append_size(i);
2448 append_flat_handle(arg);
2449 return;
2450 }
2451 auto info = TensorInfo::create(A, arg, false, false);
2452 if (info) {
2453 append_size(i);
2454 append_tensor_input(info);
2455 for (auto il : info.levels) {
2456 if (!il.is_positional()) {
2457 add_dim(il.dim());
2458 }
2459 }
2460 return;
2461 }
2462
2463 if (has_dimpacks_or_none) {
2464 Slice<mpy::handle> mp;
2465 if (maybe_dimpack(mp, arg)) {
2466 // dim pack
2467 Slice<mpy::hdl<Dim>> dim_pack;
2468 for (auto d : mp) {
2469 dim_pack.append(A, Dim::wrap(d));
2470 add_dim(dim_pack.back());
2471 append_flat_handle(dim_pack.back());
2472 }
2473 _bind_dims_to_size(A, sz[i], sd[i], dim_pack, nsz, nsd);
2474 return;
2475 }
2476 }
2477
2478 append_size(i);
2479 append_flat_handle(arg);
2480 };
2481
2482 // pair up the indexing expressions with dimension of self it indexes
// self may have first-class dims, which do not participate in the indexing.
2484 for (auto i : self_info.levels.enumerate()) {
2485 auto l = self_info.levels[i];
2486 auto idx = keys.index(l);
2487 if (idx) {
2488 append_item(i, values[*idx]);
2489 } else if (l.is_positional()) {
2490 // grab and index from the positional list
2491 parse_nones();
2492 if (!input_it.size()) {
2493 // we might have fewer indices than tensor dimensions,
2494 // which implicitly indexes the remaining dimensions with :
2495 append_flat_handle(no_slice);
2496 append_size(i);
2497 } else {
2498 mpy::handle arg = input_it[0];
2499 input_it = input_it.slice(1);
2500 append_item(i, arg);
2501 }
2502 } else {
2503 add_dim(l.dim());
2504 append_flat_handle(l.dim());
2505 append_size(i);
2506 }
2507 }
// any trailing Nones may have no existing dimension associated with them in self.
2509 parse_nones();
2510
2511 // we have to restride the tensor to collapse dimension packs and introduce our none dimensions.
2512 if (has_dimpacks_or_none) {
2513 self_info.tensor = A.autorelease(self_info.tensor->as_strided(at::IntArrayRef(nsz.begin(), nsz.end()),at::IntArrayRef(nsd.begin(), nsd.end()), self_info.tensor->storage_offset()));
2514 }
2515
2516
2517 // figure out what the shape of the indexing tensors will be
2518 // and what the shape of the resulting tensor will be
2519 Slice<DimEntry> result_levels;
2520 Slice<DimEntry> index_levels;
2521 int64_t tensor_insert_point = -1;
2522 bool requires_getindex = false;
2523 auto mark_tensor_index = [&] {
2524 if (tensor_insert_point == -1) {
2525 tensor_insert_point = result_levels.size();
2526 } else if (tensor_insert_point != result_levels.size()) {
2527 tensor_insert_point = 0;
2528 }
2529 };
2530 for (auto i : flat_inputs.enumerate()) {
2531 auto inp = flat_inputs[i];
2532 if(tensor_inputs[i]) {
2533 requires_getindex = true;
2534 mark_tensor_index();
2535 for (auto l : tensor_inputs[i].levels) {
2536 // std::cout << "Consider to add " << l << "\n";
2537 if (!index_levels.contains(l)) {
2538 index_levels.append(A, l);
2539 }
2540 }
2541 } else if (Dim::check_exact(inp)) {
2542 auto d = Dim::unchecked_wrap(inp);
// dimensions used once are just binding operations
2544 if (1 == seen_dims_nuses[*seen_dims.index(d)]) {
2545 flat_inputs[i] = no_slice;
2546 result_levels.append(A, d);
2547 } else {
2548 requires_getindex = true;
2549 flat_inputs[i] = mpy::handle();
2550 tensor_inputs[i] = TensorInfo {d->range(), Slice<DimEntry>(A, DimEntry(d)), false, TensorRef()};
2551 if (!index_levels.contains(d)) {
2552 index_levels.append(A, d);
2553 }
2554 mark_tensor_index();
2555 }
2556 } else {
2557 if (inp.ptr() != no_slice.ptr()) {
2558 requires_getindex = true;
2559 }
2560 if (!mpy::is_int(inp)) {
2561 // note: actual positional indexes are accurately computed later
2562 result_levels.append(A, -1);
2563 }
2564 }
2565 }
2566
// indexing dimensions appear in the tensor at the _first use of a tensor_ in the indexing. So insert
// the indexing levels into the result levels at this spot
2569 if (tensor_insert_point != -1) {
2570 result_levels.insert(A, result_levels.slice(tensor_insert_point, tensor_insert_point), index_levels);
2571 }
2572
2573 // std::cout << "flat inputs: " << flat_inputs << "\n";
2574 // std::cout << "result_levels: " << result_levels << "\n";
2575 // std::cout << "index_levels: " << index_levels << "\n";
2576
2577 // get all the tensors to be the right shape for indexing
2578 if (requires_getindex) {
2579 for (auto i : flat_inputs.enumerate()) {
2580 if (tensor_inputs[i]) {
2581 AT_ASSERT(!flat_inputs[i].ptr());
2582 // std::cout << "tensor " << i << " " << tensor_inputs[i].levels << "\n";
2583 TensorRef t = tensor_inputs[i].tensor;
2584 if (!tensor_inputs[i].has_device && device_holding_tensor) {
2585 t = A.autorelease(t->to(device_holding_tensor->device()));
2586 }
2587 flat_inputs[i] = handle_from_tensor(A, _match_levels(A, t, tensor_inputs[i].levels, index_levels));
2588 }
2589 }
2590 }
2591
// previously we didn't know how many positional dimensions there would be, so we couldn't number them correctly;
// fill that in now.
2594 auto seen_positionals = 0;
2595 for (auto i : result_levels.reversed_enumerate()) {
2596 if (result_levels[i].is_positional()) {
2597 result_levels[i] = -(++seen_positionals);
2598 }
2599 }
2600
2601 return IndexingInfo {false, requires_getindex, self_info.tensor, flat_inputs, result_levels, self_info.has_device};
2602 }
2603 namespace{
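// __getitem__ with first-class dims: falls back to the original getitem when no dims are
// involved, otherwise performs the flattened indexing computed by getsetitem.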
mpy::object __getitem__(Arena & A, mpy::handle self, mpy::handle index) {
2605 maybeInitializeGlobals();
2606 auto iinfo = getsetitem(A, self, index, has_dims(self));
2607 if (iinfo.can_call_original) {
2608 return mpy::object::checked_steal(THPVariable_getitem(self.ptr(), index.ptr()));
2609 }
2610
2611 return invoke_getitem(A, iinfo);
2612 }
2613
2614
2615
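// __setitem__ with first-class dims: aligns the rhs to the levels of the indexed result,
// then either performs advanced-index assignment through THPVariable_setitem or copies
// directly via Tensor.copy_.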
void __setitem__(Arena & A, mpy::handle self, mpy::handle index, mpy::handle rhs) {
2617 maybeInitializeGlobals();
2618 auto iinfo = getsetitem(A, self, index, has_dims(self) || has_dims(rhs));
2619 if (iinfo.can_call_original) {
2620 if (-1 == THPVariable_setitem(self.ptr(), index.ptr(), rhs.ptr())) {
2621 throw mpy::exception_set();
2622 }
2623 return;
2624 }
2625
2626 auto rhs_info = TensorInfo::create(A, rhs, false, false);
2627 if (rhs_info) { // otherwise rhs can be a scalar...
2628 for (auto l : rhs_info.levels) {
2629 if (!iinfo.result_levels.contains(l)) {
2630 if (l.is_positional()) {
2631 mpy::raise_error(DimensionBindError(), "rhs contains too many dimensions (%d) compared to indexed value (%d)", ndim_of_levels(iinfo.result_levels), rhs_info.ndim());
2632 } else {
2633 auto tup = levels_to_tuple(iinfo.result_levels);
2634 mpy::raise_error(DimensionBindError(), "rhs of setitem contains dimension %R which is not in the dimension on the left (%R)", l.dim().ptr(), tup.ptr());
2635 }
2636 }
2637 }
2638 auto rhs_matched = _match_levels(A, rhs_info.tensor, rhs_info.levels, iinfo.result_levels);
2639 rhs = handle_from_tensor(A, rhs_matched);
2640 }
2641 self = handle_from_tensor(A, iinfo.self);
2642
2643 if (iinfo.advanced_indexing) {
2644 auto tup = slice_to_tuple(iinfo.flat_inputs);
2645 if (-1 == THPVariable_setitem(self.ptr(), tup.ptr(), rhs.ptr())) {
2646 throw mpy::exception_set();
2647 }
2648 } else {
2649 torch_Tensor_copy_.call(self, rhs);
2650 }
2651 }
2652 }
2653
PyObject* Tensor_getitem(PyObject* self, PyObject* index) {
2655 Arena A;
2656 PY_BEGIN
2657 return __getitem__(A, self, index).release();
2658 PY_END(nullptr);
2659 }
2660
int Tensor_setitem(PyObject* self, PyObject* index, PyObject* value) {
2662 Arena A;
2663 PY_BEGIN
2664 __setitem__(A, self, index, value);
2665 return 0;
2666 PY_END(-1);
2667 }
2668
2669 namespace{
PyObject* py___getitem__(PyObject *_,
2671 PyObject *const *args,
2672 Py_ssize_t nargs,
2673 PyObject *kwnames) {
2674 Arena A;
2675 PY_BEGIN
2676 AT_ASSERT(nargs == 2);
2677 return __getitem__(A, args[0], args[1]).release();
2678 PY_END(nullptr)
2679 }
2680
PyObject* py___setitem__(PyObject *_,
2682 PyObject *const *args,
2683 Py_ssize_t nargs,
2684 PyObject *kwnames) {
2685 Arena A;
2686 PY_BEGIN
2687 AT_ASSERT(nargs == 3);
2688 __setitem__(A, args[0], args[1], args[2]);
2689 Py_RETURN_NONE;
2690 PY_END(nullptr)
2691 }
2692
2693
PyObject* py_index(PyObject *_,
2695 PyObject *const *args,
2696 Py_ssize_t nargs,
2697 PyObject *kwnames) {
2698 Arena A;
2699 PY_BEGIN
2700 mpy::vector_args va(args, nargs, kwnames);
2701 mpy::handle self, dims, indices;
2702 va.parse("index", {"self", "dims", "indices"}, {&self, &dims, &indices}, 3);
2703 return index(A, self, dims, indices).release();
2704 PY_END(nullptr)
2705 }
2706
2707
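// stack(tensors, new_dim, dim): binds new_dim to the number of inputs, aligns every input
// to the union of their levels, then stacks with at::stack and records new_dim at the
// stacking position.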
PyObject* py_stack(PyObject *_,
2709 PyObject *const *args,
2710 Py_ssize_t nargs,
2711 PyObject *kwnames) {
2712 Arena A;
2713 PY_BEGIN
2714 mpy::vector_args va(args, nargs, kwnames);
2715 mpy::handle tensors, new_dim, dim;
2716 va.parse("stack", {"tensors", "new_dim", "dim"}, {&tensors, &new_dim, &dim}, 2);
2717
2718 Slice<DimEntry> result_levels;
2719 Slice<TensorInfo> infos;
2720 mpy::sequence_view sv(tensors);
2721 auto new_dim_d = Dim::wrap(new_dim);
2722 for (auto i : sv.enumerate()) {
2723 infos.append(A, TensorInfo::create(A, A.autorelease(sv[i]), false));
2724 for (auto l : infos.back().levels) {
2725 if (!result_levels.contains(l)) {
2726 result_levels.append(A, l);
2727 }
2728 }
2729 }
2730 new_dim_d->set_size(infos.size());
2731 std::vector<at::Tensor> inputs;
2732 inputs.reserve(infos.size());
2733 for (auto in : infos) {
2734 inputs.emplace_back(*_match_levels(A, in.tensor, in.levels, result_levels));
2735 }
2736 auto ndim = ndim_of_levels(result_levels);
2737 int64_t rawdim = 0;
2738 if (dim.ptr()) {
2739 auto d = _wrap_dim(dim, ndim, false);
2740 auto idx = result_levels.index(d);
2741 if (!idx) {
2742 mpy::raise_error(PyExc_TypeError, "Dimension %R does not exist in inputs", dim.ptr());
2743 }
2744 rawdim = *idx;
2745 }
2746 auto result = at::stack(inputs, rawdim);
2747 result_levels.insert(A, rawdim, new_dim_d);
2748 return Tensor::from_positional(A, std::move(result), result_levels, true).release();
2749 PY_END(nullptr)
2750 }
2751
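// split(self, split_size_or_sections, dim): when the sections are Dim objects, their sizes
// are taken (or inferred for unbound dims) and the tensor is split with split_with_sizes,
// binding each output to the corresponding dim; otherwise the original split is called.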
PyObject* py_split(PyObject *_,
2753 PyObject *const *args,
2754 Py_ssize_t nargs,
2755 PyObject *kwnames) {
2756 Arena A;
2757 PY_BEGIN
2758 maybeInitializeGlobals();
2759 mpy::vector_args va(args, nargs, kwnames);
2760 mpy::handle self, split_size_or_sections, dim;
2761 va.parse("split", {"self", "split_size_or_sections", "dim"}, {&self, &split_size_or_sections, &dim}, 2);
2762 bool dim_is_object = dim.ptr() && Dim::check_exact(dim);
2763 Slice<mpy::handle> sizes;
2764
2765 bool all_dims = true;
2766 bool all_ints = true;
2767
2768 if (!mpy::is_int(split_size_or_sections)) {
2769 mpy::sequence_view sv(split_size_or_sections);
2770 for (auto i : sv.enumerate()) {
2771 sizes.append(A, A.autorelease(sv[i]));
2772 if (Dim::check_exact(sizes.back())) {
2773 all_ints = false;
2774 } else {
2775 all_dims = false;
2776 }
2777 }
2778 }
2779 if (all_ints) {
2780 if (dim_is_object) {
2781 mpy::raise_error(PyExc_TypeError, "when dim is specified as a Dim object, split sizes must also be dimensions.");
2782 }
2783 // call original split (if self has dimensions this will use torch function to do the split)
2784 return torch_Tensor_split.call_vector(mpy::vector_args(args, nargs, kwnames)).release();
2785 }
2786 if (!all_dims) {
2787 mpy::raise_error(PyExc_TypeError, "split list must be ints or dims but got a mix");
2788 }
2789
2790 auto self_info = TensorInfo::create(A, self, false);
2791 auto ndim = self_info.ndim();
if (!dim_is_object && ndim == 0) {
    mpy::raise_error(PyExc_TypeError, "split expects at least a 1-dimensional tensor");
2794 }
2795 DimEntry dim_l = dim.ptr() ? _wrap_dim(dim, ndim, false) : -ndim;
2796
2797 auto idx = self_info.levels.index(dim_l);
2798 if (!idx) {
2799 if (!dim.ptr()) {
2800 dim = A.autorelease(mpy::from_int(0));
2801 }
mpy::raise_error(PyExc_TypeError, "tensor does not contain dimension %R", dim.ptr());
2803 }
2804 Slice<int64_t> indices;
2805
2806 int64_t total_size = 0;
2807 Slice<int64_t> unbound;
2808 for (auto i : sizes.enumerate()) {
2809 auto d = Dim::unchecked_wrap(sizes[i]);
2810 if (d->is_bound()) {
2811 indices.append(A, d->size());
2812 total_size += indices.back();
2813 } else {
2814 indices.append(A, 0);
2815 unbound.append(A, i);
2816 }
2817 }
2818 auto tensor_size = self_info.tensor->sizes()[*idx];
2819
2820 if (unbound.size()) {
2821 if (total_size > tensor_size) {
2822 mpy::raise_error(PyExc_TypeError, "sizes of target dimensions add up to more (%d) than source dim (%d)", int(total_size), int(tensor_size));
2823 }
2824 auto remaining_size = tensor_size - total_size;
2825 auto chunk_size = (remaining_size + unbound.size() - 1) / unbound.size();
2826 for (auto u : unbound) {
2827 auto sz = std::min(chunk_size, remaining_size);
2828 Dim::unchecked_wrap(sizes[u])->set_size(sz);
2829 indices[u] = sz;
2830 remaining_size -= sz;
2831 }
2832 } else if (tensor_size != total_size) {
mpy::raise_error(PyExc_TypeError, "sum of sizes of target dimensions (%d) does not match the source dim (%d)", int(total_size), int(tensor_size));
2834 }
2835
2836 auto result_tensors = self_info.tensor->split_with_sizes(at::IntArrayRef(indices.begin(), indices.end()), *idx);
2837 mpy::tuple result(result_tensors.size());
2838 Slice<DimEntry> new_levels;
2839 new_levels.extend(A, self_info.levels);
2840 for (auto i : sizes.enumerate()) {
2841 new_levels[*idx] = Dim::unchecked_wrap(sizes[i]);
2842 result.set(i, Tensor::from_positional(A, std::move(result_tensors[i]), new_levels, true));
2843 }
2844
2845 return result.release();
2846
2847 PY_END(nullptr)
2848 }
2849
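// Wrap a single dim or a sequence of dims into a Slice<DimEntry>.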
Slice<DimEntry> _wrap_dims(Arena& A, mpy::handle d, size_t N, bool keepdim) {
2851 auto de = _wrap_dim(d, N, keepdim);
2852 Slice<DimEntry> r;
2853 if (!de.is_none()) {
2854 r.append(A, de);
2855 } else {
2856 mpy::sequence_view sq(d);
2857 for (auto i : sq.enumerate()) {
2858 r.append(A, _wrap_dim(A.autorelease(sq[i]), N, keepdim));
2859 }
2860 }
2861 return r;
2862 }
2863
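// Holder for an original torch function plus the metadata (dim/keepdim argument offsets,
// pointwise/reduction flags) needed by the dim-aware wrappers below to dispatch to it.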
2864 struct WrappedOperator : public mpy::base<WrappedOperator> {
2865 mpy::object orig;
2866 PyMethodDef method_def;
2867 mpy::object name, doc;
2868
2869 bool is_pointwise = false;
2870 int64_t dim_offset = 0;
2871 int64_t keepdim_offset = 1;
2872 std::string dim_name;
2873 bool single_dim = false;
2874 bool reduce = true;
2875
2876 static PyTypeObject Type;
2877
void init(mpy::object orig_, PyCFunction wrapper_implementation, std::string dim_name_="") {
2879 orig = std::move(orig_);
2880 method_def.ml_meth = wrapper_implementation;
2881 name = orig.attr("__name__");
2882 doc = orig.attr("__doc__");
2883 dim_name = std::move(dim_name_);
2884 if (!mpy::is_none(doc) && !dim_name.empty()) {
2885 doc = mpy::unicode_from_format("%S\nArgument '%s' can be either an integer or a torchdim.Dim object.\n", doc.ptr(), dim_name.c_str());
2886 }
2887 method_def.ml_name = mpy::is_none(name) ? "" : PyUnicode_AsUTF8(name.ptr());
2888 method_def.ml_doc = mpy::is_none(doc) ? "" : PyUnicode_AsUTF8(doc.ptr());
2889 method_def.ml_flags = METH_FASTCALL | METH_KEYWORDS;
2890 }
2891
mpy::object function() {
2893 return mpy::object::checked_steal(PyCFunction_New(&method_def, ptr()));
2894 }
2895
2896 };
2897 }
2898
2899 PyTypeObject WrappedOperator::Type = {
2900 PyVarObject_HEAD_INIT(NULL, 0)
2901 "_C.WrappedOperator", /* tp_name */
2902 sizeof(WrappedOperator), /* tp_basicsize */
2903 0, /* tp_itemsize */
2904 WrappedOperator::dealloc_stub, /* tp_dealloc */
2905 0, /* tp_vectorcall_offset */
2906 0, /* tp_getattr */
2907 0, /* tp_setattr */
2908 0, /* tp_as_async */
2909 0, /* tp_repr */
2910 0, /* tp_as_number */
2911 0, /* tp_as_sequence */
2912 0, /* tp_as_mapping */
2913 0, /* tp_hash */
2914 0, /* tp_call */
2915 0, /* tp_str */
2916 0, /* tp_getattro */
2917 0, /* tp_setattro */
2918 0, /* tp_as_buffer */
2919 Py_TPFLAGS_DEFAULT, /* tp_flags */
2920 "Wrapped Object Holder", /* tp_doc */
2921 0, /* tp_traverse */
2922 0, /* tp_clear */
2923 0, /* tp_richcompare */
2924 0, /* tp_weaklistoffset */
2925 0, /* tp_iter */
2926 0, /* tp_iternext */
2927 0, /* tp_methods */
2928 0, /* tp_members */
2929 0, /* tp_getset */
2930 0, /* tp_base */
2931 0, /* tp_dict */
2932 0, /* tp_descr_get */
2933 0, /* tp_descr_set */
2934 0, /* tp_dictoffset */
2935 0, /* tp_init */
2936 0, /* tp_alloc */
2937 WrappedOperator::new_stub, /* tp_new */
2938 };
2939
2940 namespace{
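// Wrapper installed over dim-accepting torch methods: if no Dim is passed for the dim
// argument, run the original op under a functorch batching layer spanning all first-class
// dims; otherwise translate the Dim arguments into positional indices, call the original
// op, and rewrap the results with the surviving levels.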
PyObject* patched_dim_method(PyObject * self_,
2942 PyObject *const *args,
2943 Py_ssize_t nargs,
2944 PyObject *kwnames) {
2945 Arena A;
2946 auto self = WrappedOperator::unchecked_wrap(self_);
2947 PY_BEGIN
2948
2949 mpy::vector_args va(args, nargs, kwnames);
2950
2951 auto _getarg = [&](const char* name, int64_t offset_) -> mpy::handle {
2952 auto offset = offset_ + 1; // do not include self
2953 auto idx = va.index(name, offset);
2954 return idx == -1 ? mpy::handle() : va[idx];
2955 };
2956 Slice<mpy::handle> patched_args;
2957 patched_args.extend(A, va.begin(), va.end());
2958 auto _patcharg = [&](const char* name, int64_t offset_, mpy::handle value) {
2959 auto offset = offset_ + 1; // do not include self
2960 auto idx = va.index(name, offset);
2961 if (idx == -1) {
2962 mpy::raise_error(PyExc_ValueError, "Missing argument %s", name);
2963 }
2964 patched_args[idx] = value;
2965 };
2966
2967 auto dim = _getarg(self->dim_name.c_str(), self->dim_offset);
2968 if (!dim.ptr()) {
2969 auto info = TensorInfo::create(A, args[0], true);
2970 EnableAllLayers l(A, info.levels);
2971 l.inplace_update_layers(info.batchedtensor, info.levels);
2972 patched_args[0] = handle_from_tensor(A, info.batchedtensor);
2973 auto r = self->orig.call_vector(patched_args.begin(), nargs, kwnames);
2974 return l.from_batched(A, THPVariable_Unpack(r.ptr()), info.has_device).release();
2975 }
2976
2977 auto info = TensorInfo::create(A, args[0]);
2978 auto keepdim = false;
2979 if (self->reduce) {
2980 auto py_keepdim = _getarg("keepdim", self->keepdim_offset);
2981 if (py_keepdim.ptr()) {
2982 keepdim = mpy::to_bool(py_keepdim);
2983 }
2984 }
2985
2986 auto ndim = info.ndim();
2987 auto dims = _wrap_dims(A, dim, ndim, keepdim);
2988 Slice<int64_t> dim_indices;
2989 auto seen = A.allocate<bool>(info.levels.size());
2990 std::fill(seen, seen + info.levels.size(), false);
2991
2992 for (auto d : dims) {
2993 auto midx = info.levels.index(d);
2994 if (!midx) {
2995 auto tup = levels_to_tuple(info.levels);
2996 mpy::raise_error(PyExc_ValueError, "Tensor with dimensions %R does not contain one of %R\n", tup.ptr(), dim.ptr());
2997 }
2998 seen[*midx] = true;
2999 dim_indices.append(A, *midx);
3000 }
3001 Slice<DimEntry> new_levels;
3002 if (self->reduce && !keepdim) {
3003 for (auto i : info.levels.enumerate()) {
3004 if (!seen[i]) {
3005 new_levels.append(A, info.levels[i]);
3006 }
3007 }
3008 } else {
3009 new_levels = info.levels;
3010 }
3011 mpy::object py_indices;
3012 if (dim_indices.size() == 1) {
3013 py_indices = mpy::from_int(dim_indices[0]);
3014 } else {
3015 mpy::tuple tup(dim_indices.size());
3016 for (auto i : dim_indices.enumerate()) {
3017 tup.set(i, mpy::from_int(dim_indices[i]));
3018 }
3019 py_indices = std::move(tup);
3020 }
3021 _patcharg(self->dim_name.c_str(), self->dim_offset, py_indices);
3022 patched_args[0] = handle_from_tensor(A, info.tensor);
3023 auto r = self->orig.call_vector(patched_args.begin(), nargs, kwnames);
3024 auto wrap = [&](mpy::handle h) {
3025 if (THPVariable_Check(h.ptr())) {
3026 return A.autorelease(Tensor::from_positional(A, THPVariable_Unpack(h.ptr()), new_levels, info.has_device));
3027 }
3028 return h;
3029 };
3030 return tree_map(A, wrap, r).release();
3031 PY_END(nullptr)
3032 }
3033
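// _wrap(orig, dim_offset, keepdim_offset, dim_name, single_dim, reduce): creates a
// WrappedOperator around `orig` whose wrapper is patched_dim_method, configured from the
// optional offsets and flags.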
PyObject* _wrap(PyObject * self_,
3035 PyObject *const *args,
3036 Py_ssize_t nargs,
3037 PyObject *kwnames) {
3038 Arena A;
3039 PY_BEGIN
3040
3041 #define ARGS(_) _(mpy::handle, orig) _(mpy::handle, dim_offset) _(mpy::handle, keepdim_offset) \
3042 _(mpy::handle, dim_name) _(mpy::handle, single_dim) _(mpy::handle, reduce)
3043 MPY_PARSE_ARGS_KWNAMES("O|OOOOO", ARGS)
3044
3045 std::string dim_name_str;
3046 if (dim_name.ptr()) {
3047 dim_name_str = PyUnicode_AsUTF8(dim_name.ptr());
3048 } else {
3049 dim_name_str = "dim";
3050 }
3051 auto info = WrappedOperator::create(mpy::object::borrow(orig), (PyCFunction)(void*) patched_dim_method, std::move(dim_name_str));
3052 if (dim_offset.ptr()) {
3053 info->dim_offset = mpy::to_int(dim_offset);
3054 }
3055 if (keepdim_offset.ptr()) {
3056 info->keepdim_offset = mpy::to_int(keepdim_offset);
3057 }
3058
3059 if (single_dim.ptr()) {
3060 info->single_dim = mpy::to_bool(single_dim);
3061 }
3062 if (reduce.ptr()) {
3063 info->reduce = mpy::to_bool(reduce);
3064 }
3065 return info->function().release();
3066 #undef ARGS
3067
3068 PY_END(nullptr)
3069 }
3070
PyObject* call_torch_function(PyObject *self,
3072 PyObject *const *args,
3073 Py_ssize_t nargs,
3074 PyObject *kwnames) {
3075 PY_BEGIN
3076 Arena A;
3077 maybeInitializeGlobals();
3078 auto info = WrappedOperator::unchecked_wrap(self);
3079 return __torch_function__(A, info->orig, mpy::vector_args(args, nargs, kwnames), info->is_pointwise).release();
3080 PY_END(nullptr)
3081 }
3082
PyObject* _wrap_method(PyObject *self,
3084 PyObject *const *args,
3085 Py_ssize_t nargs,
3086 PyObject *kwnames) {
3087 PY_BEGIN
3088 AT_ASSERT(nargs == 2);
// XXX - ignore the wrapped python function; we will call the torch function directly
3090 mpy::handle orig = args[0];
3091 if (!pointwise.ptr()) {
3092 auto dim = mpy::import("functorch.dim");
3093 pointwise = dim.attr("pointwise");
3094 }
3095 auto info = WrappedOperator::create(mpy::object::borrow(orig), (PyCFunction)(void*) call_torch_function);
3096 info->is_pointwise = pointwise.contains(orig);
3097 return PyInstanceMethod_New(info->function().release());
3098 PY_END(nullptr);
3099 }
3100
3101
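// Tensor_sum: fast path for sum over dims of a delayed pointwise product -- instead of
// materializing the product, the multiply-and-reduce is lowered to dot() (mm/bmm);
// otherwise it falls back to the regular _Tensor.sum.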
PyObject* Tensor_sum(PyObject * self_,
3103 PyObject *const *args,
3104 Py_ssize_t nargs,
3105 PyObject *kwnames) {
3106 Arena A;
3107 PY_BEGIN
3108 maybeInitializeGlobals();
3109 mpy::vector_args va(args, nargs, kwnames);
3110 auto self_ = Tensor::unchecked_wrap(args[0]);
3111 auto d = self_->delayed();
3112 if (!d) {
3113 return _Tensor_sum.call_vector(va).release();
3114 }
3115 mpy::handle self, dim, keepdim, dtype;
3116 va.parse("sum", {"self", "dim", "keepdim", "dtype"}, {&self, &dim, &keepdim, &dtype}, 1, 1);
3117
3118 if (dtype.ptr() || (keepdim.ptr() && mpy::to_bool(keepdim))) {
3119 // std::cout << "SKIPPING fusion because dtype or keepdim=True specified\n";
3120 return _Tensor_sum.call_vector(va).release();
3121 }
3122 auto levels = self_->levels();
3123
3124 auto N = ndim_of_levels(levels);
3125 auto reduced_dims = _wrap_dims(A, dim, N, false);
3126
3127 return dot(A, TensorInfo::create(A, d->args[0], false), TensorInfo::create(A, d->args[1], false), reduced_dims).release();
3128 PY_END(nullptr)
3129 }
3130
PyObject* _parse_test(PyObject * self_,
3132 PyObject *const *args,
3133 Py_ssize_t nargs,
3134 PyObject *kwnames) {
3135 PY_BEGIN
3136 maybeInitializeGlobals();
3137
3138 int required = mpy::to_int(args[0]);
3139 int kwonly = mpy::to_int(args[1]);
3140
3141 mpy::vector_args va(args + 2, nargs - 2, kwnames);
3142
3143
3144 mpy::handle a, b, c, d;
3145 va.parse("_parse_test", {"a", "b", "c", "d"}, {&a, &b, &c, &d}, required, kwonly);
3146 mpy::tuple r(4);
3147 r.set(0, mpy::object::borrow(a.ptr() ? a : Py_None));
3148 r.set(1, mpy::object::borrow(b.ptr() ? b : Py_None));
3149 r.set(2, mpy::object::borrow(c.ptr() ? c : Py_None));
3150 r.set(3, mpy::object::borrow(d.ptr() ? d : Py_None));
3151 return r.release();
3152
3153 PY_END(nullptr)
3154 }
3155
PyObject* _set_pointwise_optimize(PyObject * self_,
3157 PyObject *const *args,
3158 Py_ssize_t nargs,
3159 PyObject *kwnames) {
3160 PY_BEGIN
3161 mpy::handle value;
3162 mpy::vector_args va(args, nargs, kwnames);
3163 va.parse("_set_pointwise_optimization", {"value"}, {&value}, 1);
3164 pointwise_optimize = mpy::to_bool(value);
3165 Py_RETURN_NONE;
3166 PY_END(nullptr)
3167 }
3168
PyObject* _patch_tensor_class(PyObject * self_,
3170 PyObject *const *args,
3171 Py_ssize_t nargs,
3172 PyObject *kwnames) {
3173 PY_BEGIN
3174
3175 auto torch = mpy::import("torch");
3176 auto py_TensorBase = torch.attr("_C").attr("TensorBase");
3177 replaceMappingIfMatches(py_TensorBase);
3178
3179 Py_RETURN_NONE;
3180 PY_END(nullptr)
3181 }
3182
3183
3184 const char* dims_doc = R"""(
3185 dims(n=None, sizes=None) -> torchdim.Dim or Tuple[torchdim.Dim, ...]
3186
3187 Creates and returns one or more Dim objects.
3188
Args:
    n (int, optional): The number of dimensions to create. Can be omitted if sizes is specified.
    sizes (List[Optional[int]], optional): A list the same length as the number of dimensions being
        created, specifying each dimension's size, or None to leave the size unset.
3193
3194 Example::
3195 >>> batch, channel, width, height = dims(4)
3196 >>> batch, channel, width, height = dims(sizes=[None, 3, 224, 224])
3197 )""";
3198
3199 PyMethodDef methods[] = {
3200 {"dims", (PyCFunction)(void*) _dims<create_dim>, METH_FASTCALL | METH_KEYWORDS, dims_doc},
3201 {"dimlists", (PyCFunction)(void*) _dims<create_dimlist>, METH_FASTCALL | METH_KEYWORDS},
3202 {"_test_c", (PyCFunction)(void*) test_c, METH_FASTCALL | METH_KEYWORDS},
3203 {"_wrap_method", (PyCFunction)(void*) _wrap_method, METH_FASTCALL | METH_KEYWORDS},
3204 {"Tensor_from_positional", (PyCFunction)(void*) py_Tensor_from_positional, METH_FASTCALL | METH_KEYWORDS},
3205 {"__torch_function__", (PyCFunction)(void*) py___torch_function__, METH_FASTCALL | METH_KEYWORDS},
3206 {"tree_flatten", (PyCFunction)(void*) py_tree_flatten, METH_FASTCALL | METH_KEYWORDS},
3207 {"order", (PyCFunction)(void*) order, METH_FASTCALL | METH_KEYWORDS},
3208 {"index", (PyCFunction)(void*) py_index, METH_FASTCALL | METH_KEYWORDS},
3209 {"stack", (PyCFunction)(void*) py_stack, METH_FASTCALL | METH_KEYWORDS},
3210 {"split", (PyCFunction)(void*) py_split, METH_FASTCALL | METH_KEYWORDS},
3211 {"expand", (PyCFunction)(void*) expand, METH_FASTCALL | METH_KEYWORDS},
3212 {"__getitem__", (PyCFunction)(void*) py___getitem__, METH_FASTCALL | METH_KEYWORDS},
3213 {"__setitem__", (PyCFunction)(void*) py___setitem__, METH_FASTCALL | METH_KEYWORDS},
3214 {"_wrap", (PyCFunction)(void*) _wrap, METH_FASTCALL | METH_KEYWORDS},
3215 {"Tensor_sum", (PyCFunction)(void*) Tensor_sum, METH_FASTCALL | METH_KEYWORDS},
3216 {"_parse_test", (PyCFunction)(void*) _parse_test, METH_FASTCALL | METH_KEYWORDS},
3217 {"_set_pointwise_optimize", (PyCFunction)(void*) _set_pointwise_optimize, METH_FASTCALL | METH_KEYWORDS},
3218 {"_patch_tensor_class", (PyCFunction)(void*) _patch_tensor_class, METH_FASTCALL | METH_KEYWORDS},
3219 {NULL, NULL, 0, NULL} /* Sentinel */
3220 };
3221
3222 struct PyModuleDef module_def = {
3223 PyModuleDef_HEAD_INIT,
3224 "_C", /* name of module */
3225 NULL, /* module documentation, may be NULL */
3226 -1, /* size of per-interpreter state of the module,
3227 or -1 if the module keeps state in global variables. */
3228 methods
3229 };
3230 }
3231
PyObject* Dim_init() {
3233 Arena A;
3234 try {
3235 mpy::object mod = mpy::object::checked_steal(PyModule_Create(&module_def));
3236 Dim::ready(mod, "Dim");
3237 DimList::ready(mod, "DimList");
3238 Tensor::ready(mod, "Tensor");
3239 WrappedOperator::ready(mod, "_WrappedOperator");
3240 Py_INCREF(&PyInstanceMethod_Type);
3241 PyModule_AddObject(mod.ptr(), "_instancemethod", (PyObject *)&PyInstanceMethod_Type);
3242
3243 initializeGlobals(A);
3244 return mod.release();
3245 } catch(mpy::exception_set& err) {
3246 return nullptr;
3247 }
3248 }
3249
3250 #endif
3251