// xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/python_variable_indexing.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/SymInt.h>
4 #include <torch/csrc/autograd/python_variable.h>
5 #include <torch/csrc/python_headers.h>
6 #include <torch/csrc/utils/pybind.h>
7 #include <torch/csrc/utils/python_symnode.h>
8 
9 namespace torch::autograd {
10 
11 struct UnpackedSlice {
12   c10::SymInt start;
13   c10::SymInt stop;
14   c10::SymInt step;
15 };
16 
17 // This mirrors Cpython's PySlice_Unpack method
__PySlice_Unpack(PyObject * _r)18 inline UnpackedSlice __PySlice_Unpack(PyObject* _r) {
19   PySliceObject* r = (PySliceObject*)_r;
20   /* this is harder to get right than you might think */
21 
22   c10::SymInt start_sym, stop_sym, step_sym;
23 
24   auto clip_val = [](Py_ssize_t val) {
25     if (val < c10::SymInt::min_representable_int()) {
26       auto r = PyErr_WarnEx(
27           PyExc_UserWarning,
28           "Truncating the start/stop/step "
29           "of slice. This is likely because of "
30           "saved old models when the start/stop/step were larger.",
31           1);
32       if (r != 0) {
33         throw python_error();
34       }
35       return (Py_ssize_t)(c10::SymInt::min_representable_int());
36     }
37     return val;
38   };
39 
40   if (r->step == Py_None) {
41     step_sym = c10::SymInt(1);
42   } else {
43     if (torch::is_symint(r->step)) {
44       step_sym = py::handle(r->step).cast<c10::SymInt>();
45     } else {
46       Py_ssize_t step = 0;
47       if (!_PyEval_SliceIndex(r->step, &step)) {
48         throw python_error();
49       }
50       if (step == 0) {
51         PyErr_SetString(PyExc_ValueError, "slice step cannot be zero");
52       }
53 
54       step = clip_val(step);
55       step_sym = c10::SymInt(step);
56     }
57   }
58 
59   if (torch::is_symint(r->start)) {
60     start_sym = py::handle(r->start).cast<c10::SymInt>();
61   } else if (r->start == Py_None) {
62     start_sym = c10::SymInt(step_sym < 0 ? PY_SSIZE_T_MAX : 0);
63   } else {
64     Py_ssize_t start = 0;
65     if (!_PyEval_SliceIndex(r->start, &start)) {
66       throw python_error();
67     }
68     start = clip_val(start);
69     start_sym = c10::SymInt(start);
70   }
71 
72   if (torch::is_symint(r->stop)) {
73     stop_sym = py::handle(r->stop).cast<c10::SymInt>();
74   } else if (r->stop == Py_None) {
75     stop_sym = c10::SymInt(
76         step_sym < 0 ? c10::SymInt::min_representable_int() : PY_SSIZE_T_MAX);
77   } else {
78     Py_ssize_t stop = 0;
79     if (!_PyEval_SliceIndex(r->stop, &stop)) {
80       throw python_error();
81     }
82     stop = clip_val(stop);
83     stop_sym = c10::SymInt(stop);
84   }
85 
86   return UnpackedSlice{
87       std::move(start_sym), std::move(stop_sym), std::move(step_sym)};
88 }
89 
90 Py_ssize_t THPVariable_length(PyObject* self);
91 PyObject* THPVariable_getitem(PyObject* self, PyObject* index);
92 int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* value);
93 
94 Variable valueToTensor(
95     c10::TensorOptions options,
96     PyObject* value,
97     const at::Device& device);
98 
99 } // namespace torch::autograd
100