xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/structseq.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /* Copyright Python Software Foundation
2  *
3  * This file is copy-pasted from CPython source code with modifications:
4  * https://github.com/python/cpython/blob/master/Objects/structseq.c
5  * https://github.com/python/cpython/blob/2.7/Objects/structseq.c
6  *
 * The purpose of this file is to override the default repr behavior of
 * structseq, to provide better printing for the structseq objects returned
 * from operators, aka torch.return_types.*
10  *
11  * For more information on copyright of CPython, see:
12  * https://github.com/python/cpython#copyright-and-license-information
13  */
14 
15 #include <torch/csrc/utils/six.h>
16 #include <torch/csrc/utils/structseq.h>
17 #include <sstream>
18 
19 #include <structmember.h>
20 
21 namespace torch::utils {
22 
23 // NOTE: The built-in repr method from PyStructSequence was updated in
24 // https://github.com/python/cpython/commit/c70ab02df2894c34da2223fc3798c0404b41fd79
25 // so this function might not be required in Python 3.8+.
returned_structseq_repr(PyStructSequence * obj)26 PyObject* returned_structseq_repr(PyStructSequence* obj) {
27   PyTypeObject* typ = Py_TYPE(obj);
28   THPObjectPtr tup = six::maybeAsTuple(obj);
29   if (tup == nullptr) {
30     return nullptr;
31   }
32 
33   std::stringstream ss;
34   ss << typ->tp_name << "(\n";
35   Py_ssize_t num_elements = Py_SIZE(obj);
36 
37   for (Py_ssize_t i = 0; i < num_elements; i++) {
38     const char* cname = typ->tp_members[i].name;
39     if (cname == nullptr) {
40       PyErr_Format(
41           PyExc_SystemError,
42           "In structseq_repr(), member %zd name is nullptr"
43           " for type %.500s",
44           i,
45           typ->tp_name);
46       return nullptr;
47     }
48 
49     PyObject* val = PyTuple_GetItem(tup.get(), i);
50     if (val == nullptr) {
51       return nullptr;
52     }
53 
54     auto repr = THPObjectPtr(PyObject_Repr(val));
55     if (repr == nullptr) {
56       return nullptr;
57     }
58 
59     const char* crepr = PyUnicode_AsUTF8(repr);
60     if (crepr == nullptr) {
61       return nullptr;
62     }
63 
64     ss << cname << '=' << crepr;
65     if (i < num_elements - 1) {
66       ss << ",\n";
67     }
68   }
69   ss << ")";
70 
71   return PyUnicode_FromString(ss.str().c_str());
72 }
73 
74 } // namespace torch::utils
75