xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/python_arg_parser.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 // Parse arguments to Python functions implemented in C++
4 // This is similar to PyArg_ParseTupleAndKeywords(), but specifically handles
5 // the types relevant to PyTorch and distinguishes between overloaded function
6 // signatures.
7 //
8 // Example:
9 //
10 //   static PythonArgParser parser({
11 //     "norm(Scalar p, int64_t dim, bool keepdim=False)",
12 //     "norm(Scalar p=2)",
13 //   });
14 //   ParsedArgs<3> parsed_args;
15 //   auto r = parser.parse(args, kwargs, parsed_args);
16 //   if (r.idx == 0) {
17 //     norm(r.scalar(0), r.toInt64(1), r.toBool(2));
18 //   } else {
19 //     norm(r.scalar(0));
20 //   }
21 //
22 // We auto-generate most uses of PythonArgParser; the generated files
23 // are torch/csrc/autograd/generated/python_*.cpp
24 //
25 // Some gotchas that you should watch out for:
26 //
27 //    - Note [Order of overloads matters]
28 //      Order of overloads matters.  A set of input arguments may
29 //      bind to multiple argument specs; we will always pick the
30 //      first one in PythonArgParser.  However, when you are writing
31 //      overloads in, e.g., native_functions.yaml, you don't have to
32 //      worry about what order you write them, because the code
33 //      generation logic always gives the overloads a canonical
34 //      order, where Tensor overloads come first, before Scalar overloads.
35 //      This logic is in sort_declarations in
36 //      tools/autograd/gen_python_functions.py
37 //
38 //    - Zero-dim tensors (e.g., torch.tensor(2)) bind to both
39 //      Scalar and Tensor, UNLESS they require grad (in which case
40 //      they only bind to Tensor).
41 
42 #include <pybind11/pytypes.h>
43 #include <torch/csrc/python_headers.h>
44 
45 #include <torch/csrc/Device.h>
46 #include <torch/csrc/Dtype.h>
47 #include <torch/csrc/DynamicTypes.h>
48 #include <torch/csrc/Exceptions.h>
49 #include <torch/csrc/Export.h>
50 #include <torch/csrc/Generator.h>
51 #include <torch/csrc/Layout.h>
52 #include <torch/csrc/MemoryFormat.h>
53 #include <torch/csrc/QScheme.h>
54 #include <torch/csrc/Stream.h>
55 #include <torch/csrc/autograd/python_variable.h>
56 #include <torch/csrc/autograd/variable.h>
57 #include <torch/csrc/dynamo/eval_frame.h>
58 #include <torch/csrc/jit/frontend/tracer.h>
59 #include <torch/csrc/python_dimname.h>
60 #include <torch/csrc/tensor/python_tensor.h>
61 #include <torch/csrc/utils/disable_torch_function.h>
62 #include <torch/csrc/utils/object_ptr.h>
63 #include <torch/csrc/utils/pybind.h>
64 #include <torch/csrc/utils/python_numbers.h>
65 #include <torch/csrc/utils/python_strings.h>
66 #include <torch/csrc/utils/python_symnode.h>
67 #include <torch/csrc/utils/six.h>
68 
69 #include <ATen/DeviceAccelerator.h>
70 #include <ATen/PythonTorchFunctionTLS.h>
71 #include <ATen/core/Tensor.h>
72 #include <c10/util/Exception.h>
73 #include <c10/util/irange.h>
74 
75 #include <c10/core/SymFloat.h>
76 #include <c10/core/SymNodeImpl.h>
77 
78 #include <c10/core/DispatchKeySet.h>
79 #include <array>
80 #include <cstddef>
81 #include <string>
82 #include <vector>
83 
THPUtils_checkScalar(PyObject * obj)84 inline bool THPUtils_checkScalar(PyObject* obj) {
85 #ifdef USE_NUMPY
86   if (torch::utils::is_numpy_scalar(obj)) {
87     return true;
88   }
89 #endif
90   return PyFloat_Check(obj) || PyLong_Check(obj) || PyComplex_Check(obj) ||
91       torch::is_symint(py::handle(obj)) ||
92       torch::is_symfloat(py::handle(obj)) || torch::is_symbool(py::handle(obj));
93 }
94 
95 namespace torch {
96 
// Declared here and defined elsewhere. Judging by its name and by
// FunctionParameter::allow_numbers_as_tensors, this presumably reports whether
// the function `function_name` lets plain numbers bind to Tensor parameters.
// TODO confirm against the definition.
bool should_allow_numbers_as_tensors(const std::string& function_name);
98 
// Kind of a formal parameter in a signature string. The values are spelled
// out explicitly and match the implicit numbering of the original ordering,
// so nothing that depends on the underlying values changes.
enum class ParameterType {
  TENSOR = 0,
  SCALAR = 1,
  INT64 = 2,
  SYM_INT = 3,
  DOUBLE = 4,
  COMPLEX = 5,
  TENSOR_LIST = 6,
  INT_LIST = 7,
  GENERATOR = 8,
  BOOL = 9,
  STORAGE = 10,
  PYOBJECT = 11,
  SCALARTYPE = 12,
  LAYOUT = 13,
  MEMORY_FORMAT = 14,
  DEVICE = 15,
  STREAM = 16,
  STRING = 17,
  DIMNAME = 18,
  DIMNAME_LIST = 19,
  QSCHEME = 20,
  FLOAT_LIST = 21,
  SCALAR_LIST = 22,
  SYM_INT_LIST = 23,
  DISPATCH_KEY_SET = 24
};
126 
// The parser, signature, parameter and parsed-argument types refer to one
// another, so they are introduced before any of them is defined.
struct PythonArgs;
struct FunctionSignature;
struct FunctionParameter;
130 
131 // Contains bound Python arguments in declaration order
132 template <int N>
133 struct ParsedArgs {
ParsedArgsParsedArgs134   ParsedArgs() : args() {}
135   // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
136   PyObject* args[N];
137 };
138 
139 // A PythonArgParser contains a list of valid signatures. Instances are
140 // typically global variables and should be immutable.
141 struct PYBIND11_EXPORT PythonArgParser {
142   explicit PythonArgParser(
143       const std::vector<std::string>& fmts,
144       bool traceable = false);
145 
146   // meant only for `torch` functions.
147   template <int N>
148   inline PythonArgs parse(
149       PyObject* self,
150       PyObject* args,
151       PyObject* kwargs,
152       ParsedArgs<N>& dst);
153 
154   template <int N>
155   inline PythonArgs parse(PyObject* args, PyObject* kwargs, ParsedArgs<N>& dst);
156 
157   inline PythonArgs parse(PyObject* self, ParsedArgs<0>& dst);
158 
159   // Formatted strings of non-hidden signatures
160   std::vector<std::string> get_signatures() const;
161 
162  private:
163   [[noreturn]] void print_error(
164       PyObject* self,
165       PyObject* args,
166       PyObject* kwargs,
167       // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
168       PyObject* parsed_args[]);
169   void check_deprecated(const FunctionSignature& signature);
170   PythonArgs raw_parse(
171       PyObject* self,
172       PyObject* args,
173       PyObject* kwargs,
174       // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
175       PyObject* parsed_args[]);
176 
177   std::vector<FunctionSignature> signatures_;
178   std::string function_name;
179   size_t max_args;
180   bool traceable;
181 };
182 
183 // FunctionSignature represents a single valid signature for a Python function.
184 // It is immutable once constructed. The contained data can be concurrently
185 // accessed by multiple calls.
186 struct FunctionSignature {
187   explicit FunctionSignature(const std::string& fmt, int index);
188 
189   bool parse(
190       PyObject* self,
191       PyObject* args,
192       PyObject* kwargs,
193       // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
194       PyObject* dst[],
195       std::vector<PyObject*>& overloaded_args,
196       bool raise_exception);
197 
198   std::string toString() const;
199 
200   std::string name;
201   std::vector<FunctionParameter> params;
202   size_t min_args;
203   size_t max_args;
204   size_t max_pos_args;
205   int index;
206   bool hidden;
207   bool deprecated;
208 };
209 
210 // PythonArgs contains bound Python arguments for an actual invocation
211 // along with references to the matched signature.
212 struct PythonArgs {
PythonArgsPythonArgs213   PythonArgs(
214       bool traceable,
215       const FunctionSignature& signature,
216       PyObject** args,
217       std::vector<PyObject*> overloaded_args)
218       : idx(signature.index),
219         traceable(traceable),
220         signature(signature),
221         args(args),
222         overloaded_args(std::move(overloaded_args)) {}
223 
224   int idx;
225   bool traceable;
226   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
227   const FunctionSignature& signature;
228   PyObject** args;
229   std::vector<PyObject*> overloaded_args; // NOTE: borrowed references
230 
231   inline bool has_torch_function();
232   inline std::string get_func_name();
233   inline at::Tensor tensor(int i);
234   inline std::optional<at::Tensor> optionalTensor(int i);
235   inline at::Scalar scalar(int i);
236   inline at::Scalar scalarWithDefault(int i, const at::Scalar& default_scalar);
237   inline std::vector<at::Scalar> scalarlist(int i);
238   inline std::vector<at::Tensor> tensorlist(int i);
239   inline torch::List<std::optional<at::Tensor>> list_of_optional_tensors(int i);
240   template <int N>
241   inline std::array<at::Tensor, N> tensorlist_n(int i);
242   inline std::vector<int64_t> intlist(int i);
243   inline std::vector<c10::SymInt> symintlist(int i);
244   inline c10::OptionalArray<int64_t> intlistOptional(int i);
245   inline c10::OptionalArray<c10::SymInt> symintlistOptional(int i);
246   inline std::vector<int64_t> intlistWithDefault(
247       int i,
248       std::vector<int64_t> default_intlist);
249   inline std::optional<at::Generator> generator(int i);
250   inline at::Storage storage(int i);
251   inline at::Storage storage(
252       int i,
253       at::ScalarType& storage_scalar_type,
254       bool& is_typed_storage);
255   inline c10::Stream stream(int i);
256   inline at::ScalarType scalartype(int i);
257   inline at::ScalarType scalartypeWithDefault(
258       int i,
259       at::ScalarType default_scalartype);
260   inline std::optional<at::ScalarType> scalartypeOptional(int i);
261   inline std::optional<at::Scalar> scalarOptional(int i);
262   inline std::optional<int64_t> toInt64Optional(int i);
263   inline std::optional<c10::SymInt> toSymIntOptional(int i);
264   inline std::optional<bool> toBoolOptional(int i);
265   inline std::optional<double> toDoubleOptional(int i);
266   inline c10::OptionalArray<double> doublelistOptional(int i);
267   inline std::vector<double> doublelist(int i);
268   inline std::vector<double> getDoublelist(int i);
269   inline at::Layout layout(int i);
270   inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
271   inline std::optional<at::Layout> layoutOptional(int i);
272   inline at::Device device(int i);
273   inline at::Device deviceWithDefault(int i, const at::Device& default_device);
274   inline std::optional<at::Device> deviceOptional(int i);
275   inline at::Dimname dimname(int i);
276   inline std::vector<at::Dimname> dimnamelist(int i);
277   inline std::optional<std::vector<at::Dimname>> toDimnameListOptional(int i);
278   inline at::MemoryFormat memoryformat(int i);
279   inline std::optional<at::MemoryFormat> memoryformatOptional(int i);
280   inline at::QScheme toQScheme(int i);
281   inline std::string string(int i);
282   inline std::string stringWithDefault(int i, const std::string& default_str);
283   inline std::optional<std::string> stringOptional(int i);
284   inline c10::string_view stringView(int i);
285   inline c10::string_view stringViewWithDefault(
286       int i,
287       const c10::string_view default_str);
288   inline std::optional<c10::string_view> stringViewOptional(int i);
289   inline PyObject* pyobject(int i);
290   inline int64_t toInt64(int i);
291   inline c10::SymInt toSymInt(int i);
292   inline c10::SymBool toSymBool(int i);
293   inline int64_t toInt64WithDefault(int i, int64_t default_int);
294   inline double toDouble(int i);
295   inline double toDoubleWithDefault(int i, double default_double);
296   inline c10::complex<double> toComplex(int i);
297   inline c10::complex<double> toComplexWithDefault(
298       int i,
299       c10::complex<double> default_complex);
300   inline bool toBool(int i);
301   inline bool toBoolWithDefault(int i, bool default_bool);
302   inline bool isNone(int i);
303   inline std::optional<c10::DispatchKeySet> toDispatchKeySetOptional(int i);
304 
305  private:
306   at::Tensor tensor_slow(int i);
307   at::Scalar scalar_slow(int i);
308   at::Scalar scalar_slow(PyObject* arg);
309 };
310 
311 // FunctionParameter is a single formal parameter of a Python function.
312 // It is immutable once constructed.
313 struct FunctionParameter {
314   FunctionParameter(const std::string& fmt, bool keyword_only);
315 
316   bool check(
317       PyObject* obj,
318       std::vector<PyObject*>& overloaded_args,
319       int argnum,
320       int64_t* failed_idx = nullptr);
321 
322   void set_default_str(const std::string& str);
323   std::string type_name() const;
324 
325   ParameterType type_;
326   bool optional;
327   bool allow_none;
328   bool keyword_only;
329   bool allow_numbers_as_tensors = false;
330   int size;
331   std::string name;
332   // having this as a raw PyObject * will presumably leak it, but these are only
333   // held by static objects anyway, and Py_Finalize can already be called when
334   // this is destructed.
335   PyObject* python_name;
336   // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
337   at::SmallVector<PyObject*, 5> numpy_python_names;
338   at::Scalar default_scalar;
339   std::vector<int64_t> default_intlist;
340   std::string default_string;
341   union {
342     bool default_bool;
343     int64_t default_int;
344     double default_double;
345     // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
346     double default_complex[2]; // see Scalar
347     at::ScalarType default_scalartype;
348     at::Layout default_layout;
349   };
350   std::string default_value;
351 };
352 
353 template <int N>
parse(PyObject * self,PyObject * args,PyObject * kwargs,ParsedArgs<N> & dst)354 inline PythonArgs PythonArgParser::parse(
355     PyObject* self,
356     PyObject* args,
357     PyObject* kwargs,
358     ParsedArgs<N>& dst) {
359   TORCH_CHECK_VALUE(
360       N >= max_args,
361       "PythonArgParser: dst ParsedArgs buffer does not have enough capacity, expected ",
362       max_args,
363       " (got ",
364       N,
365       ")");
366   return raw_parse(self, args, kwargs, dst.args);
367 }
368 
369 template <int N>
parse(PyObject * args,PyObject * kwargs,ParsedArgs<N> & dst)370 inline PythonArgs PythonArgParser::parse(
371     PyObject* args,
372     PyObject* kwargs,
373     ParsedArgs<N>& dst) {
374   return parse(nullptr, args, kwargs, dst);
375 }
376 
parse(PyObject * self,ParsedArgs<0> & dst)377 inline PythonArgs PythonArgParser::parse(PyObject* self, ParsedArgs<0>& dst) {
378   return parse(self, nullptr, nullptr, dst);
379 }
380 
has_torch_function()381 inline bool PythonArgs::has_torch_function() {
382   return !overloaded_args.empty() || at::impl::torch_function_mode_enabled();
383 }
384 
get_func_name()385 inline std::string PythonArgs::get_func_name() {
386   return signature.name;
387 }
388 
389 // TODO: this can return MaybeOwned
tensor(int i)390 inline at::Tensor PythonArgs::tensor(int i) {
391   if (args[i] && THPVariable_CheckExact(args[i])) {
392     return THPVariable_Unpack(args[i]);
393   }
394   return tensor_slow(i);
395 }
396 
optionalTensor(int i)397 inline std::optional<at::Tensor> PythonArgs::optionalTensor(int i) {
398   at::Tensor t = tensor(i);
399   // NOLINTNEXTLINE(bugprone-branch-clone)
400   if (t.defined()) {
401     return t;
402   } else {
403     return std::nullopt;
404   }
405 }
406 
scalar(int i)407 inline at::Scalar PythonArgs::scalar(int i) {
408   if (!args[i])
409     return signature.params[i].default_scalar;
410   return scalar_slow(i);
411 }
412 
scalarlist(int i)413 inline std::vector<at::Scalar> PythonArgs::scalarlist(int i) {
414   if (!args[i])
415     return std::vector<at::Scalar>();
416   auto tuple = six::isTuple(args[i]);
417   THPObjectPtr arg = six::maybeAsTuple(args[i]);
418   // NOLINTNEXTLINE(bugprone-branch-clone)
419   auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
420   std::vector<at::Scalar> res(size);
421   for (const auto idx : c10::irange(size)) {
422     PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx)
423                           : PyList_GET_ITEM(arg.get(), idx);
424     res[idx] = scalar_slow(obj);
425   }
426   return res;
427 }
428 
// Scalar argument i, or the caller-supplied default when omitted.
inline at::Scalar PythonArgs::scalarWithDefault(
    int i,
    const at::Scalar& default_scalar) {
  return args[i] ? scalar_slow(i) : default_scalar;
}
436 
scalarOptional(int i)437 inline std::optional<at::Scalar> PythonArgs::scalarOptional(int i) {
438   if (!args[i])
439     return std::nullopt;
440   return scalar_slow(i);
441 }
442 
tensorlist(int i)443 inline std::vector<at::Tensor> PythonArgs::tensorlist(int i) {
444   if (!args[i])
445     return std::vector<at::Tensor>();
446   auto tuple = six::isTuple(args[i]);
447   THPObjectPtr arg = six::maybeAsTuple(args[i]);
448   // NOLINTNEXTLINE(bugprone-branch-clone)
449   auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
450   std::vector<at::Tensor> res(size);
451   for (const auto idx : c10::irange(size)) {
452     PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx)
453                           : PyList_GET_ITEM(arg.get(), idx);
454     // This is checked by the argument parser so it's safe to cast without
455     // checking if this is a tensor first
456     res[idx] = THPVariable_Unpack(obj);
457   }
458   return res;
459 }
460 
461 inline torch::List<std::optional<at::Tensor>> PythonArgs::
list_of_optional_tensors(int i)462     list_of_optional_tensors(int i) {
463   if (!args[i])
464     return torch::List<std::optional<at::Tensor>>();
465   auto tuple = six::isTuple(args[i]);
466   THPObjectPtr arg = six::maybeAsTuple(args[i]);
467   // NOLINTNEXTLINE(bugprone-branch-clone)
468   auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
469   torch::List<std::optional<at::Tensor>> res;
470   res.reserve(size);
471   for (const auto idx : c10::irange(size)) {
472     PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx)
473                           : PyList_GET_ITEM(arg.get(), idx);
474     // This is checked by the argument parser so it's safe to cast without
475     // checking if this is a tensor first
476     res.push_back(THPVariable_Unpack(obj));
477   }
478   return res;
479 }
480 
481 template <int N>
tensorlist_n(int i)482 inline std::array<at::Tensor, N> PythonArgs::tensorlist_n(int i) {
483   auto res = std::array<at::Tensor, N>();
484   if (!args[i])
485     return res;
486   auto tuple = six::isTuple(args[i]);
487   THPObjectPtr arg = six::maybeAsTuple(args[i]);
488   // NOLINTNEXTLINE(bugprone-branch-clone)
489   auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
490   if (size != N) {
491     throw TypeError("expected tuple of %d elements but got %d", N, (int)size);
492   }
493   for (const auto idx : c10::irange(size)) {
494     PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx)
495                           : PyList_GET_ITEM(arg.get(), idx);
496     // This is checked by the argument parser so it's safe to cast without
497     // checking if this is a tensor first
498     res[idx] = THPVariable_Unpack(obj);
499   }
500   return res;
501 }
502 
intlist(int i)503 inline std::vector<int64_t> PythonArgs::intlist(int i) {
504   return intlistWithDefault(i, signature.params[i].default_intlist);
505 }
506 
toPyObject(const c10::SymInt & symint)507 inline PyObject* toPyObject(const c10::SymInt& symint) {
508   if (symint.is_symbolic()) {
509     auto r = py::cast(symint).release().ptr();
510     TORCH_INTERNAL_ASSERT(r);
511     return r;
512   } else {
513     auto m = symint.maybe_as_int();
514     return THPUtils_packInt64(*m);
515   }
516 }
517 
518 inline void throw_intlist_exception(
519     const torch::PythonArgs* args,
520     size_t i,
521     PyObject* obj,
522     size_t idx,
523     const std::exception& e = python_error()) {
524   std::string error = strlen(e.what())
525       ? e.what()
526       : std::string("type must be ") + args->signature.params[i].type_name() +
527           ",but got " + Py_TYPE(obj)->tp_name;
528   throw TypeError(
529       "%s(): argument '%s' failed to unpack the object at pos %zu with error \"%s\"",
530       args->signature.name.c_str(),
531       args->signature.params[i].name.c_str(),
532       idx + 1,
533       error.c_str());
534 }
535 
symintlist(int i)536 inline std::vector<c10::SymInt> PythonArgs::symintlist(int i) {
537   if (!args[i]) {
538     return c10::fmap(signature.params[i].default_intlist, [](int64_t di) {
539       return c10::SymInt(di);
540     });
541   }
542 
543   const auto size1 = signature.params[i].size;
544   if (size1 > 0 && THPUtils_checkLong(args[i])) {
545     return std::vector<c10::SymInt>(
546         size1, c10::SymInt(THPUtils_unpackLong(args[i])));
547   }
548 
549   if (size1 > 0 && torch::is_symint(py::handle(args[i]))) {
550     auto si = py::handle(args[i]).cast<c10::SymInt>();
551     return std::vector<c10::SymInt>(size1, si);
552   }
553 
554   PyObject* arg = args[i];
555   auto tuple = PyTuple_Check(arg);
556   // NOLINTNEXTLINE(bugprone-branch-clone)
557   const auto size2 = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
558   std::vector<c10::SymInt> res;
559   res.reserve(size2);
560   for (const auto idx : c10::irange(size2)) {
561     PyObject* obj =
562         tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
563 
564     // Elements of torch.Size are tensors during tracing, and we need to
565     // record extra information before they are turned into an IntArrayRef
566     if (traceable && jit::tracer::isTracing() && THPVariable_Check(obj)) {
567       auto& var = THPVariable_Unpack(obj);
568       jit::tracer::ArgumentStash::stashIntArrayRefElem(
569           signature.params[i].name, size2, idx, var);
570       try {
571         res.emplace_back(var.item<int64_t>());
572         continue;
573       } catch (std::exception& e) {
574         throw_intlist_exception(this, i, obj, idx, e);
575       }
576       continue;
577     } else {
578       // convert tensor to scalar outside of try / catch,
579       // so that Tensor subclass exceptions will not be caught.
580       if (THPUtils_checkLongExact(obj)) {
581         // Fast path for plain numbers
582         try {
583           res.emplace_back(THPUtils_unpackLong(obj));
584         } catch (std::exception& e) {
585           throw_intlist_exception(this, i, obj, idx, e);
586         }
587       } else if (THPVariable_Check(obj)) {
588         auto& var = THPVariable_Unpack(obj);
589         if (var.numel() != 1 ||
590             !at::isIntegralType(
591                 var.dtype().toScalarType(), /*include_bool*/ true)) {
592           throw_intlist_exception(this, i, obj, idx);
593         }
594         auto scalar = var.item();
595         TORCH_CHECK(scalar.isIntegral(/*include bool*/ false));
596         res.push_back(scalar.toSymInt());
597       } else {
598         try {
599           if (is_symint(py::handle(obj))) {
600             res.push_back(py::handle(obj).cast<c10::SymInt>());
601           } else {
602             res.emplace_back(THPUtils_unpackIndex(obj));
603           }
604         } catch (std::exception& e) {
605           throw_intlist_exception(this, i, obj, idx, e);
606         }
607       }
608     }
609   }
610 
611   return res;
612 }
613 
intlistWithDefault(int i,std::vector<int64_t> default_intlist)614 inline std::vector<int64_t> PythonArgs::intlistWithDefault(
615     int i,
616     std::vector<int64_t> default_intlist) {
617   if (!args[i])
618     return default_intlist;
619   PyObject* arg = args[i];
620   const auto size1 = signature.params[i].size;
621   if (size1 > 0 && THPUtils_checkLong(arg)) {
622     return std::vector<int64_t>(size1, THPUtils_unpackLong(arg));
623   }
624   if (size1 > 0 && torch::is_symint(py::handle(arg))) {
625     return std::vector<int64_t>(
626         size1,
627         py::handle(arg).cast<c10::SymInt>().guard_int(__FILE__, __LINE__));
628   }
629   auto tuple = PyTuple_Check(arg);
630   // NOLINTNEXTLINE(bugprone-branch-clone)
631   const auto size2 = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
632   std::vector<int64_t> res(size2);
633   for (const auto idx : c10::irange(size2)) {
634     PyObject* obj =
635         tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
636     // Elements of torch.Size are tensors during tracing, and we need to
637     // record extra information before they are turned into an IntArrayRef
638     if (traceable && jit::tracer::isTracing() && THPVariable_Check(obj)) {
639       auto& var = THPVariable_Unpack(obj);
640       jit::tracer::ArgumentStash::stashIntArrayRefElem(
641           signature.params[i].name, size2, idx, var);
642       try {
643         res[idx] = var.item<int64_t>();
644         continue;
645       } catch (std::exception& e) {
646         throw_intlist_exception(this, i, obj, idx, e);
647       }
648     } else {
649       // convert tensor to scalar outside of try / catch,
650       // so that Tensor subclass exceptions will not be caught.
651       if (THPUtils_checkLongExact(obj)) {
652         // Fast path for plain numbers
653         try {
654           res[idx] = THPUtils_unpackLong(obj);
655         } catch (std::exception& e) {
656           throw_intlist_exception(this, i, obj, idx, e);
657         }
658       } else if (torch::is_symint(py::handle(obj))) {
659         res[idx] = py::cast<c10::SymInt>(py::handle(obj))
660                        .guard_int(__FILE__, __LINE__);
661       } else if (THPVariable_Check(obj)) {
662         auto& var = THPVariable_Unpack(obj);
663         if (var.numel() != 1 ||
664             !at::isIntegralType(
665                 var.dtype().toScalarType(), /*include_bool*/ true)) {
666           throw_intlist_exception(this, i, obj, idx);
667         }
668         res[idx] = var.item<int64_t>();
669       } else {
670         try {
671           res[idx] = THPUtils_unpackIndex(obj);
672         } catch (std::exception& e) {
673           throw_intlist_exception(this, i, obj, idx, e);
674         }
675       }
676     }
677   }
678   return res;
679 }
680 
intlistOptional(int i)681 inline c10::OptionalArray<int64_t> PythonArgs::intlistOptional(int i) {
682   if (!args[i]) {
683     return {};
684   }
685   return intlist(i);
686 }
687 
symintlistOptional(int i)688 inline c10::OptionalArray<c10::SymInt> PythonArgs::symintlistOptional(int i) {
689   if (!args[i]) {
690     return {};
691   }
692   return symintlist(i);
693 }
694 
getDoublelist(int i)695 inline std::vector<double> PythonArgs::getDoublelist(int i) {
696   PyObject* arg = args[i];
697   auto tuple = PyTuple_Check(arg);
698   // NOLINTNEXTLINE(bugprone-branch-clone)
699   auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
700   std::vector<double> res(size);
701   for (const auto idx : c10::irange(size)) {
702     PyObject* obj =
703         tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
704     try {
705       res[idx] = THPUtils_unpackDouble(obj);
706     } catch (const std::exception&) {
707       throw TypeError(
708           "%s(): argument '%s' must be %s, but found element of type %s at pos %zu",
709           signature.name.c_str(),
710           signature.params[i].name.c_str(),
711           signature.params[i].type_name().c_str(),
712           Py_TYPE(obj)->tp_name,
713           idx + 1);
714     }
715   }
716   return res;
717 }
718 
doublelistOptional(int i)719 inline c10::OptionalArray<double> PythonArgs::doublelistOptional(int i) {
720   if (!args[i]) {
721     return {};
722   }
723   return this->getDoublelist(i);
724 }
725 
doublelist(int i)726 inline std::vector<double> PythonArgs::doublelist(int i) {
727   if (!args[i]) {
728     return {};
729   }
730   return this->getDoublelist(i);
731 }
732 
toDispatchKeySetOptional(int i)733 inline std::optional<c10::DispatchKeySet> PythonArgs::toDispatchKeySetOptional(
734     int i) {
735   if (!args[i]) {
736     return {};
737   }
738   return py::cast<c10::DispatchKeySet>(py::handle(args[i]));
739 }
740 
scalartypeWithDefault(int i,at::ScalarType default_scalartype)741 inline at::ScalarType PythonArgs::scalartypeWithDefault(
742     int i,
743     at::ScalarType default_scalartype) {
744   if (!args[i])
745     return default_scalartype;
746   return scalartype(i);
747 }
748 
toScalarType(PyObject * obj)749 inline at::ScalarType toScalarType(PyObject* obj) {
750   if (obj == (PyObject*)&PyFloat_Type) {
751     return at::ScalarType::Double;
752   }
753   if (obj == (PyObject*)&PyBool_Type) {
754     return at::ScalarType::Bool;
755   }
756   if (obj == (PyObject*)&PyLong_Type) {
757     return at::ScalarType::Long;
758   }
759   if (obj == (PyObject*)&PyComplex_Type) {
760     return at::ScalarType::ComplexDouble;
761   }
762   return reinterpret_cast<THPDtype*>(obj)->scalar_type;
763 }
764 
scalartype(int i)765 inline at::ScalarType PythonArgs::scalartype(int i) {
766   if (!args[i]) {
767     auto scalartype = signature.params[i].default_scalartype;
768     return (scalartype == at::ScalarType::Undefined)
769         ? torch::tensors::get_default_scalar_type()
770         : scalartype;
771   }
772   PyObject* obj = args[i];
773   return toScalarType(obj);
774 }
775 
scalartypeOptional(int i)776 inline std::optional<at::ScalarType> PythonArgs::scalartypeOptional(int i) {
777   if (!args[i])
778     return std::nullopt;
779   return scalartype(i);
780 }
781 
toLayout(PyObject * obj)782 inline at::Layout toLayout(PyObject* obj) {
783   const auto layout = reinterpret_cast<THPLayout*>(obj);
784   return layout->layout;
785 }
786 
layout(int i)787 inline at::Layout PythonArgs::layout(int i) {
788   if (!args[i])
789     return signature.params[i].default_layout;
790   return toLayout(args[i]);
791 }
792 
layoutWithDefault(int i,at::Layout default_layout)793 inline at::Layout PythonArgs::layoutWithDefault(
794     int i,
795     at::Layout default_layout) {
796   if (!args[i])
797     return default_layout;
798   return layout(i);
799 }
800 
layoutOptional(int i)801 inline std::optional<at::Layout> PythonArgs::layoutOptional(int i) {
802   if (!args[i])
803     return std::nullopt;
804   return layout(i);
805 }
806 
deviceFromLong(int64_t device_index)807 inline at::Device deviceFromLong(int64_t device_index) {
808   TORCH_CHECK(device_index >= 0, "Device index must not be negative");
809   return at::Device(
810       at::getAccelerator(true).value(),
811       static_cast<c10::DeviceIndex>(device_index));
812 }
813 
toDevice(PyObject * obj)814 inline at::Device toDevice(PyObject* obj) {
815   if (THPDevice_Check(obj)) {
816     const auto device = reinterpret_cast<THPDevice*>(obj);
817     return device->device;
818   }
819   if (THPUtils_checkLong(obj)) {
820     return deviceFromLong(THPUtils_unpackLong(obj));
821   }
822   if (torch::is_symint(py::handle(obj))) {
823     auto device_index =
824         py::cast<c10::SymInt>(py::handle(obj)).guard_int(__FILE__, __LINE__);
825     return deviceFromLong(device_index);
826   }
827   const std::string& device_str = THPUtils_unpackString(obj);
828   return at::Device(device_str);
829 }
830 
device(int i)831 inline at::Device PythonArgs::device(int i) {
832   if (!args[i]) {
833     return torch::tensors::get_default_device();
834   }
835   return toDevice(args[i]);
836 }
837 
// Device argument i, or the caller-supplied default when omitted.
inline at::Device PythonArgs::deviceWithDefault(
    int i,
    const at::Device& default_device) {
  return args[i] ? device(i) : default_device;
}
845 
deviceOptional(int i)846 inline std::optional<at::Device> PythonArgs::deviceOptional(int i) {
847   if (!args[i])
848     return std::nullopt;
849   return device(i);
850 }
851 
dimname(int i)852 inline at::Dimname PythonArgs::dimname(int i) {
853   TORCH_INTERNAL_ASSERT(args[i] != nullptr);
854   return THPDimname_parse(args[i]);
855 }
856 
parseDimnameList(PyObject * arg)857 inline std::vector<at::Dimname> parseDimnameList(PyObject* arg) {
858   auto tuple = PyTuple_Check(arg);
859   // NOLINTNEXTLINE(bugprone-branch-clone)
860   auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
861   std::vector<at::Dimname> res;
862   res.reserve(size);
863   for (const auto idx : c10::irange(size)) {
864     PyObject* obj =
865         tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
866     res.push_back(THPDimname_parse(obj));
867   }
868   return res;
869 }
870 
871 inline std::optional<std::vector<at::Dimname>> PythonArgs::
toDimnameListOptional(int i)872     toDimnameListOptional(int i) {
873   if (!args[i])
874     return std::nullopt;
875   return parseDimnameList(args[i]);
876 }
877 
dimnamelist(int i)878 inline std::vector<at::Dimname> PythonArgs::dimnamelist(int i) {
879   TORCH_INTERNAL_ASSERT(args[i]);
880   PyObject* arg = args[i];
881   auto size = signature.params[i].size;
882   TORCH_INTERNAL_ASSERT(size == 0 || size == 1);
883   if (size == 1 && THPUtils_checkDimname(arg)) {
884     return {THPDimname_parse(arg)};
885   }
886   return parseDimnameList(arg);
887 }
888 
memoryformat(int i)889 inline at::MemoryFormat PythonArgs::memoryformat(int i) {
890   if (!args[i])
891     return at::MemoryFormat::Contiguous;
892   TORCH_CHECK(
893       THPMemoryFormat_Check(args[i]),
894       "memory_format arg must be an instance of the torch.memory_format");
895   const auto memory_format = reinterpret_cast<THPMemoryFormat*>(args[i]);
896   return memory_format->memory_format;
897 }
898 
memoryformatOptional(int i)899 inline std::optional<at::MemoryFormat> PythonArgs::memoryformatOptional(int i) {
900   if (!args[i])
901     return std::nullopt;
902   return memoryformat(i);
903 }
904 
toQScheme(int i)905 inline at::QScheme PythonArgs::toQScheme(int i) {
906   if (!args[i])
907     return at::kPerTensorAffine;
908   TORCH_CHECK(
909       THPQScheme_Check(args[i]),
910       "qscheme arg must be an instance of the torch.qscheme");
911   const auto qscheme = reinterpret_cast<THPQScheme*>(args[i]);
912   return qscheme->qscheme;
913 }
914 
string(int i)915 inline std::string PythonArgs::string(int i) {
916   return stringWithDefault(i, signature.params[i].default_string);
917 }
918 
stringWithDefault(int i,const std::string & default_str)919 inline std::string PythonArgs::stringWithDefault(
920     int i,
921     const std::string& default_str) {
922   if (!args[i])
923     return default_str;
924   return THPUtils_unpackString(args[i]);
925 }
926 
stringOptional(int i)927 inline std::optional<std::string> PythonArgs::stringOptional(int i) {
928   if (!args[i])
929     return std::nullopt;
930   return THPUtils_unpackString(args[i]);
931 }
932 
stringView(int i)933 inline c10::string_view PythonArgs::stringView(int i) {
934   return stringViewWithDefault(i, signature.params[i].default_string);
935 }
936 
stringViewWithDefault(int i,const c10::string_view default_str)937 inline c10::string_view PythonArgs::stringViewWithDefault(
938     int i,
939     const c10::string_view default_str) {
940   if (!args[i])
941     return default_str;
942   return THPUtils_unpackStringView(args[i]);
943 }
944 
stringViewOptional(int i)945 inline std::optional<c10::string_view> PythonArgs::stringViewOptional(int i) {
946   if (!args[i])
947     return std::nullopt;
948   return THPUtils_unpackStringView(args[i]);
949 }
950 
toInt64(int i)951 inline int64_t PythonArgs::toInt64(int i) {
952   if (!args[i])
953     return signature.params[i].default_int;
954   if (traceable && jit::tracer::isTracing() && THPVariable_Check(args[i])) {
955     auto& var = THPVariable_Unpack(args[i]);
956     jit::tracer::ArgumentStash::stashValue(
957         signature.params[i].name, idx, var, c10::IntType::get());
958   }
959   if (torch::is_symint(py::handle(args[i]))) {
960     return py::cast<c10::SymInt>(py::handle(args[i]))
961         .guard_int(__FILE__, __LINE__);
962   }
963   return THPUtils_unpackLong(args[i]);
964 }
965 
toSymInt(int i)966 inline c10::SymInt PythonArgs::toSymInt(int i) {
967   if (!args[i]) {
968     return c10::SymInt(signature.params[i].default_int);
969   }
970 
971   if (traceable && jit::tracer::isTracing() && THPVariable_Check(args[i])) {
972     auto& var = THPVariable_Unpack(args[i]);
973     jit::tracer::ArgumentStash::stashValue(
974         signature.params[i].name, idx, var, c10::IntType::get());
975   }
976 
977   return py::cast<c10::SymInt>(py::handle(args[i]));
978 }
979 
toSymBool(int i)980 inline c10::SymBool PythonArgs::toSymBool(int i) {
981   if (!args[i]) {
982     return c10::SymBool(signature.params[i].default_bool);
983   }
984   if (traceable && jit::tracer::isTracing() && THPVariable_Check(args[i])) {
985     auto& var = THPVariable_Unpack(args[i]);
986     jit::tracer::ArgumentStash::stashValue(
987         signature.params[i].name, idx, var, c10::BoolType::get());
988   }
989 
990   return py::cast<c10::SymBool>(py::handle(args[i]));
991 }
992 
toInt64WithDefault(int i,int64_t default_int)993 inline int64_t PythonArgs::toInt64WithDefault(int i, int64_t default_int) {
994   if (!args[i])
995     return default_int;
996   return toInt64(i);
997 }
998 
toInt64Optional(int i)999 inline std::optional<int64_t> PythonArgs::toInt64Optional(int i) {
1000   if (!args[i])
1001     return std::nullopt;
1002   return toInt64(i);
1003 }
1004 
toSymIntOptional(int i)1005 inline std::optional<c10::SymInt> PythonArgs::toSymIntOptional(int i) {
1006   if (!args[i])
1007     return std::nullopt;
1008   return toSymInt(i);
1009 }
1010 
toBoolOptional(int i)1011 inline std::optional<bool> PythonArgs::toBoolOptional(int i) {
1012   if (!args[i]) {
1013     return std::nullopt;
1014   }
1015   return toBool(i);
1016 }
1017 
toDoubleOptional(int i)1018 inline std::optional<double> PythonArgs::toDoubleOptional(int i) {
1019   if (!args[i]) {
1020     return std::nullopt;
1021   }
1022   return toDouble(i);
1023 }
1024 
toDouble(int i)1025 inline double PythonArgs::toDouble(int i) {
1026   if (!args[i])
1027     return signature.params[i].default_double;
1028   if (torch::is_symfloat(py::handle(args[i]))) {
1029     return py::cast<c10::SymFloat>(py::handle(args[i]))
1030         .guard_float(__FILE__, __LINE__);
1031   }
1032   if (torch::is_symint(py::handle(args[i]))) {
1033     return static_cast<double>(py::cast<c10::SymInt>(py::handle(args[i]))
1034                                    .guard_int(__FILE__, __LINE__));
1035   }
1036   return THPUtils_unpackDouble(args[i]);
1037 }
1038 
toBool(int i)1039 inline bool PythonArgs::toBool(int i) {
1040   if (!args[i])
1041     return signature.params[i].default_bool;
1042   if (torch::is_symbool(py::handle(args[i]))) {
1043     return py::cast<c10::SymBool>(py::handle(args[i]))
1044         .guard_bool(__FILE__, __LINE__);
1045   }
1046   return args[i] == Py_True;
1047 }
1048 
toDoubleWithDefault(int i,double default_double)1049 inline double PythonArgs::toDoubleWithDefault(int i, double default_double) {
1050   if (!args[i])
1051     return default_double;
1052   return toDouble(i);
1053 }
1054 
toComplex(int i)1055 inline c10::complex<double> PythonArgs::toComplex(int i) {
1056   if (!args[i])
1057     return *(reinterpret_cast<const c10::complex<double>*>(
1058         signature.params[i].default_complex));
1059   return THPUtils_unpackComplexDouble(args[i]);
1060 }
1061 
toComplexWithDefault(int i,c10::complex<double> default_value)1062 inline c10::complex<double> PythonArgs::toComplexWithDefault(
1063     int i,
1064     c10::complex<double> default_value) {
1065   if (!args[i])
1066     return default_value;
1067   return toComplex(i);
1068 }
1069 
toBoolWithDefault(int i,bool default_bool)1070 inline bool PythonArgs::toBoolWithDefault(int i, bool default_bool) {
1071   if (!args[i])
1072     return default_bool;
1073   return toBool(i);
1074 }
1075 
isNone(int i)1076 inline bool PythonArgs::isNone(int i) {
1077   return args[i] == nullptr;
1078 }
1079 
generator(int i)1080 inline std::optional<at::Generator> PythonArgs::generator(int i) {
1081   if (!args[i])
1082     return std::nullopt;
1083   return reinterpret_cast<THPGenerator*>(args[i])->cdata;
1084 }
1085 
storage(int i)1086 inline at::Storage PythonArgs::storage(int i) {
1087   if (!args[i])
1088     return at::Storage();
1089   return createStorage(args[i]);
1090 }
1091 
storage(int i,at::ScalarType & storage_scalar_type,bool & is_typed_storage)1092 inline at::Storage PythonArgs::storage(
1093     int i,
1094     at::ScalarType& storage_scalar_type,
1095     bool& is_typed_storage) {
1096   at::Storage storage;
1097   if (!args[i]) {
1098     storage = at::Storage();
1099     is_typed_storage = false;
1100     storage_scalar_type = at::ScalarType::Undefined;
1101   } else {
1102     std::tie(storage, storage_scalar_type, is_typed_storage) =
1103         createStorageGetType(args[i]);
1104   }
1105   return storage;
1106 }
1107 
stream(int i)1108 inline c10::Stream PythonArgs::stream(int i) {
1109   if (!args[i])
1110     return c10::Stream(
1111         c10::Stream::Default::DEFAULT, c10::Device(c10::DeviceType::CPU, -1));
1112   if (!THPStream_Check(args[i])) {
1113     throw TypeError(
1114         "expected Stream object. Got '%s'", Py_TYPE(args[i])->tp_name);
1115   }
1116   return c10::Stream::unpack3(
1117       ((THPStream*)args[i])->stream_id,
1118       static_cast<c10::DeviceIndex>(((THPStream*)args[i])->device_index),
1119       static_cast<c10::DeviceType>(((THPStream*)args[i])->device_type));
1120 }
1121 
pyobject(int i)1122 inline PyObject* PythonArgs::pyobject(int i) {
1123   if (!args[i])
1124     return Py_None;
1125   return args[i];
1126 }
1127 
1128 /*
1129  *
1130  * Handle __torch_function__ overrides if we know that there are overloaded
1131  * arguments.  All objects stored in r.overloaded_args must have a
1132  * __torch_function__ implementation and the arguments must be ordered in order
1133  * of precedence. Precedence goes from left to right in the order of the
1134  * signature of the function the overloaded arguments were passed to, except
1135  * subclasses are always considered before superclasses.
1136  *
1137  * If the result of calling __torch_function__ is NotImplemented, the
1138  * next implementation in the precedence order is called. If all
1139  * arguments return NotImplemented from their __torch_function__
1140  * implementation, a TypeError is raised in Python.
1141  *
1142  * Assumes overloaded_args has at least one entry. All entries must have
1143  * a __torch_function__ attribute that resolves to a callable that
1144  * accepts a torch API function, a tuple of arguments, and a dict of
1145  * keyword arguments for the torch API function.
1146  *
1147  * It is sufficient to call PythonArgs::has_torch_function before
1148  * calling this function to verify that there are valid arguments
1149  * present. If that is not done then special care must be taken to
1150  * ensure there are arguments that are overloaded with
1151  * __torch_function__.
1152  *
1153  * See torch._overrides.handle_torch_function for the equivalent
1154  * code in the pure-python implementation.
1155  *
1156  * 'r' is a parsed PythonArgs instance, returned from
1157  * PythonArgParser::parse.
1158  *
1159  * 'args' is a reference to the python tuple of arguments to the torch
1160  * API function.
1161  *
1162  * 'kwargs' is a reference to the python dict of keyword arguments to
1163  * the torch API function.
1164  *
1165  * 'torch_api' is a reference to a python torch API namespace.
1166  *
1167  * 'torch_api_function' is the reference to the original torch method, usually,
1168  * we can use torch_api and func_name to get torch_api_function. In some cases,
1169  * e.g., torch custom op, we create the function in C++, if we still use
1170  * torch_api and func_name to fetch original api, a cyclic call will happen.
1171  *
1172  * 'overloaded_args' is the args which have overloaded __torch_function__.
1173  *
1174  * 'func_name' is the named of the original torch method.
1175  *
1176  * TODO: we could use different names for the following 'handle_torch_function'
1177  * instead of overloading.
1178  *
1179  */
1180 // Used for Tensor methods with arguments.
1181 auto handle_torch_function(
1182     PythonArgs& r,
1183     PyObject* self,
1184     PyObject* args,
1185     PyObject* kwargs,
1186     PyObject* torch_api,
1187     const char* module_name,
1188     const char* func_name_override = nullptr) -> PyObject*;
1189 
1190 // Used for functions which needs to parse python args.
1191 auto handle_torch_function(
1192     PythonArgs& r,
1193     PyObject* args,
1194     PyObject* kwargs,
1195     PyObject* torch_api,
1196     const char* module_name,
1197     const char* func_name_override = nullptr) -> PyObject*;
1198 
1199 // Used for functions that have no argument parsing.
1200 auto handle_torch_function(
1201     PyObject* self,
1202     const std::string& func_name,
1203     PyObject* args = nullptr,
1204     PyObject* kwargs = nullptr,
1205     PyObject* torch_api = THPVariableClass,
1206     const std::string& module_name = "torch.Tensor") -> PyObject*;
1207 
1208 // Used for functions created in C++, e.g., C++ custom op, which doesn't use
1209 // PythonArgParser to get overloaded_args.
1210 enum class TorchFunctionName { TorchFunction, TorchDispatch };
1211 
1212 auto TORCH_PYTHON_API handle_torch_function_no_python_arg_parser(
1213     at::ArrayRef<PyObject*> overloaded_args,
1214     PyObject* args,
1215     PyObject* kwargs,
1216     const char* func_name,
1217     PyObject* torch_api_function,
1218     const char* module_name,
1219     TorchFunctionName torch_function_name = TorchFunctionName::TorchFunction)
1220     -> PyObject*;
1221 
1222 // Used for getters of Tensor properties
1223 auto handle_torch_function_getter(
1224     THPVariable* self,
1225     const std::string& property_name) -> PyObject*;
1226 
1227 // Used for setters of Tensor properties.
1228 auto handle_torch_function_setter(
1229     THPVariable* self,
1230     const std::string& property_name,
1231     PyObject* value) -> int;
1232 
1233 // Used for __getitem__ and __setitem__
1234 auto handle_torch_function_indexing(
1235     PyObject* self,
1236     PyObject* index,
1237     PyObject* val = nullptr) -> PyObject*;
1238 
1239 /*
1240  * Check if the input obj is Tensor type, including its subclass, or overloaded
1241  * type. If the type defines __torch_function__, it also returns true.
1242  * Otherwise returns flase. If the class is not torch.Tensor, and it defines
1243  * __torch_function__, we append obj to overloaded_args.
1244  *
1245  * 'obj': the input argument to be checked
1246  * 'overloaded_args': the vector to append the overloaded args.
1247  */
1248 bool is_tensor_and_append_overloaded(
1249     PyObject* obj,
1250     std::vector<PyObject*>* overloaded_args);
1251 
1252 /*
1253  * Check if the input obj is Tensor List or Tensor Tuple type. First check
1254  * whether obj is Tuple or List type, if true, iterate over each element and
1255  * check whether it is Tensor type, including its subclass or overloaded type.
1256  * At the same time, the overloaded arg is appended to the overloaded_args.
1257  *
1258  * 'obj': the input argument to be checked
1259  * 'overloaded_args': the vector to append the overloaded args.
1260  * 'argnum': the number of total arguments of the function being checked.
1261  * 'throw_error': whether throw error if any element in the list or tuple is
1262  *                not tensor type or overloaded.
1263  */
1264 bool is_tensor_list_and_append_overloaded(
1265     PyObject* obj,
1266     std::vector<PyObject*>* overloaded_args,
1267     size_t argnum,
1268     bool throw_error);
1269 
1270 /* Given an argument that is definitely a tensor and is definitely overloaded,
1271  * append it to the overloaded arguments list.  Use this instead of
1272  * is_tensor_and_append_overloaded in situations where you have a PyObject
1273  * and you know it definitely is a Tensor and it is definitely overloaded.
1274  *
1275  * 'overloaded_args': the vector to append the overloaded args
1276  * 'obj': the input tensor that is overloaded
1277  */
1278 void append_overloaded_tensor(
1279     std::vector<PyObject*>* overloaded_args,
1280     PyObject* obj);
1281 
1282 /* Given an argument that is definitely a type and is definitely overloaded,
1283  * append it to the overloaded arguments list. Use this only with
1284  * __torch_dispatch__, where we operate on classes that have a
1285  * __torch_dispatch__ classmethod.
1286  *
1287  * 'overloaded_args': the vector to append the overloaded type
1288  * 'obj': the input class that has a __torch_dispatch__ classmethod.
1289  */
1290 void append_overloaded_type(
1291     std::vector<PyObject*>* overloaded_args,
1292     PyObject* obj);
1293 
1294 } // namespace torch
1295