xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/python_torch_functions.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <Python.h>
2 
3 namespace torch::autograd {
4 
5 extern PyObject* THPVariableFunctionsModule;
6 
7 // Wrapper converts a raised TypeError into returning NotImplemented
8 // Used to implement binary arithmetic operators
9 template <PyObject* (*Func)(PyObject*, PyObject*, PyObject*)>
TypeError_to_NotImplemented_(PyObject * self,PyObject * args,PyObject * kwargs)10 inline PyObject* TypeError_to_NotImplemented_(
11     PyObject* self,
12     PyObject* args,
13     PyObject* kwargs) {
14   PyObject* ret = Func(self, args, kwargs);
15   if (!ret && PyErr_ExceptionMatches(PyExc_TypeError)) {
16     PyErr_Clear();
17     Py_INCREF(Py_NotImplemented);
18     ret = Py_NotImplemented;
19   }
20   return ret;
21 }
22 
23 void initTorchFunctions();
24 
25 } // namespace torch::autograd
26