1 /* Copyright Python Software Foundation 2 * 3 * This file is copy-pasted from CPython source code with modifications: 4 * https://github.com/python/cpython/blob/master/Objects/structseq.c 5 * https://github.com/python/cpython/blob/2.7/Objects/structseq.c 6 * 7 * The purpose of this file is to overwrite the default behavior 8 * of repr of structseq to provide better printting for returned 9 * structseq objects from operators, aka torch.return_types.* 10 * 11 * For more information on copyright of CPython, see: 12 * https://github.com/python/cpython#copyright-and-license-information 13 */ 14 15 #include <torch/csrc/utils/six.h> 16 #include <torch/csrc/utils/structseq.h> 17 #include <sstream> 18 19 #include <structmember.h> 20 21 namespace torch::utils { 22 23 // NOTE: The built-in repr method from PyStructSequence was updated in 24 // https://github.com/python/cpython/commit/c70ab02df2894c34da2223fc3798c0404b41fd79 25 // so this function might not be required in Python 3.8+. returned_structseq_repr(PyStructSequence * obj)26PyObject* returned_structseq_repr(PyStructSequence* obj) { 27 PyTypeObject* typ = Py_TYPE(obj); 28 THPObjectPtr tup = six::maybeAsTuple(obj); 29 if (tup == nullptr) { 30 return nullptr; 31 } 32 33 std::stringstream ss; 34 ss << typ->tp_name << "(\n"; 35 Py_ssize_t num_elements = Py_SIZE(obj); 36 37 for (Py_ssize_t i = 0; i < num_elements; i++) { 38 const char* cname = typ->tp_members[i].name; 39 if (cname == nullptr) { 40 PyErr_Format( 41 PyExc_SystemError, 42 "In structseq_repr(), member %zd name is nullptr" 43 " for type %.500s", 44 i, 45 typ->tp_name); 46 return nullptr; 47 } 48 49 PyObject* val = PyTuple_GetItem(tup.get(), i); 50 if (val == nullptr) { 51 return nullptr; 52 } 53 54 auto repr = THPObjectPtr(PyObject_Repr(val)); 55 if (repr == nullptr) { 56 return nullptr; 57 } 58 59 const char* crepr = PyUnicode_AsUTF8(repr); 60 if (crepr == nullptr) { 61 return nullptr; 62 } 63 64 ss << cname << '=' << crepr; 65 if (i < num_elements - 1) { 66 ss << ",\n"; 67 } 68 } 69 ss << ")"; 70 71 return PyUnicode_FromString(ss.str().c_str()); 72 } 73 74 } // namespace torch::utils 75