xref: /aosp_15_r20/external/pytorch/torch/csrc/python_dimname.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/flat_hash_map.h>
2 #include <torch/csrc/Exceptions.h>
3 #include <torch/csrc/python_dimname.h>
4 #include <torch/csrc/utils/python_strings.h>
5 
6 namespace torch {
7 
8 struct InternedStringsTable {
9   InternedStringsTable() = default;
10   // NOLINTNEXTLINE(bugprone-exception-escape)
11   ~InternedStringsTable();
12   InternedStringsTable(const InternedStringsTable&) = delete;
13   InternedStringsTable& operator=(InternedStringsTable const&) = delete;
14   InternedStringsTable(InternedStringsTable&&) = delete;
15   InternedStringsTable& operator=(InternedStringsTable&&) = delete;
16 
17   std::optional<at::Dimname> lookup(PyObject* obj);
18   // Precondition: obj is an interned python string.
19   void addMapping(PyObject* obj, at::Dimname dimname);
20 
21  private:
22   ska::flat_hash_map<PyObject*, at::Dimname> py_interned_string_to_dimname_;
23 };
24 
25 InternedStringsTable kPyInternedStringToDimname;
26 
27 // NOLINTNEXTLINE(bugprone-exception-escape)
~InternedStringsTable()28 InternedStringsTable::~InternedStringsTable() {
29   // If python is already dead, leak the wrapped python objects
30   if (Py_IsInitialized()) {
31     pybind11::gil_scoped_acquire gil;
32     for (auto it = py_interned_string_to_dimname_.begin();
33          it != py_interned_string_to_dimname_.end();
34          ++it) {
35       // See Note [References to python interned strings]
36       Py_DECREF(it->first);
37     }
38   }
39 }
40 
lookup(PyObject * obj)41 std::optional<at::Dimname> InternedStringsTable::lookup(PyObject* obj) {
42   auto it = py_interned_string_to_dimname_.find(obj);
43   if (it == py_interned_string_to_dimname_.end()) {
44     return std::nullopt;
45   }
46   return it->second;
47 }
48 
addMapping(PyObject * obj,at::Dimname dimname)49 void InternedStringsTable::addMapping(PyObject* obj, at::Dimname dimname) {
50   // Note [References to python interned strings]
51   // If a Python interned string has no references to it, then it gets
52   // deallocated, invalidating this mapping. Let's immortalize the string by
53   // holding a refcount to it and releasing it in the destructor
54   Py_INCREF(obj);
55   py_interned_string_to_dimname_.emplace(obj, dimname);
56 }
57 
58 } // namespace torch
59 
THPUtils_checkDimname(PyObject * obj)60 bool THPUtils_checkDimname(PyObject* obj) {
61   return obj == Py_None || THPUtils_checkString(obj);
62 }
63 
64 // To avoid ambiguity with IntArrayRef, we parse obj as a DimnameList if
65 // it is a list or tuple and its first elt is a Dimname
THPUtils_checkDimnameList(PyObject * obj)66 bool THPUtils_checkDimnameList(PyObject* obj) {
67   auto tuple = PyTuple_Check(obj);
68   if (!tuple && !PyList_Check(obj)) {
69     return false;
70   }
71   // NOLINTNEXTLINE(bugprone-branch-clone)
72   const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
73   if (size == 0) {
74     return true;
75   }
76   PyObject* first_elt =
77       tuple ? PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0);
78   return THPUtils_checkDimname(first_elt);
79 }
80 
THPDimname_parse(PyObject * obj)81 at::Dimname THPDimname_parse(PyObject* obj) {
82   if (obj == Py_None) {
83     return at::Dimname::wildcard();
84   }
85 
86   TORCH_CHECK_TYPE(
87       THPUtils_checkString(obj),
88       "expected None or string for Dimname but got ",
89       Py_TYPE(obj)->tp_name);
90 
91   if (!THPUtils_isInterned(obj)) {
92     // internStringInPlace decrefs obj and increfs the result. Because we're
93     // not actually returning the result to the user, we need to undo these.
94     // See
95     // https://docs.python.org/3/c-api/unicode.html#c.PyUnicode_InternInPlace
96     Py_INCREF(obj);
97     THPUtils_internStringInPlace(&obj);
98     Py_DECREF(obj);
99   }
100 
101   auto maybeDimname = torch::kPyInternedStringToDimname.lookup(obj);
102   if (maybeDimname) {
103     return *maybeDimname;
104   }
105 
106   const auto name = THPUtils_unpackString(obj);
107   auto dimname = at::Dimname::fromSymbol(at::Symbol::dimname(name));
108   torch::kPyInternedStringToDimname.addMapping(obj, dimname);
109   return dimname;
110 }
111