1 #include <c10/util/flat_hash_map.h>
2 #include <torch/csrc/Exceptions.h>
3 #include <torch/csrc/python_dimname.h>
4 #include <torch/csrc/utils/python_strings.h>
5
6 namespace torch {
7
8 struct InternedStringsTable {
9 InternedStringsTable() = default;
10 // NOLINTNEXTLINE(bugprone-exception-escape)
11 ~InternedStringsTable();
12 InternedStringsTable(const InternedStringsTable&) = delete;
13 InternedStringsTable& operator=(InternedStringsTable const&) = delete;
14 InternedStringsTable(InternedStringsTable&&) = delete;
15 InternedStringsTable& operator=(InternedStringsTable&&) = delete;
16
17 std::optional<at::Dimname> lookup(PyObject* obj);
18 // Precondition: obj is an interned python string.
19 void addMapping(PyObject* obj, at::Dimname dimname);
20
21 private:
22 ska::flat_hash_map<PyObject*, at::Dimname> py_interned_string_to_dimname_;
23 };
24
25 InternedStringsTable kPyInternedStringToDimname;
26
27 // NOLINTNEXTLINE(bugprone-exception-escape)
~InternedStringsTable()28 InternedStringsTable::~InternedStringsTable() {
29 // If python is already dead, leak the wrapped python objects
30 if (Py_IsInitialized()) {
31 pybind11::gil_scoped_acquire gil;
32 for (auto it = py_interned_string_to_dimname_.begin();
33 it != py_interned_string_to_dimname_.end();
34 ++it) {
35 // See Note [References to python interned strings]
36 Py_DECREF(it->first);
37 }
38 }
39 }
40
lookup(PyObject * obj)41 std::optional<at::Dimname> InternedStringsTable::lookup(PyObject* obj) {
42 auto it = py_interned_string_to_dimname_.find(obj);
43 if (it == py_interned_string_to_dimname_.end()) {
44 return std::nullopt;
45 }
46 return it->second;
47 }
48
addMapping(PyObject * obj,at::Dimname dimname)49 void InternedStringsTable::addMapping(PyObject* obj, at::Dimname dimname) {
50 // Note [References to python interned strings]
51 // If a Python interned string has no references to it, then it gets
52 // deallocated, invalidating this mapping. Let's immortalize the string by
53 // holding a refcount to it and releasing it in the destructor
54 Py_INCREF(obj);
55 py_interned_string_to_dimname_.emplace(obj, dimname);
56 }
57
58 } // namespace torch
59
THPUtils_checkDimname(PyObject * obj)60 bool THPUtils_checkDimname(PyObject* obj) {
61 return obj == Py_None || THPUtils_checkString(obj);
62 }
63
64 // To avoid ambiguity with IntArrayRef, we parse obj as a DimnameList if
65 // it is a list or tuple and its first elt is a Dimname
THPUtils_checkDimnameList(PyObject * obj)66 bool THPUtils_checkDimnameList(PyObject* obj) {
67 auto tuple = PyTuple_Check(obj);
68 if (!tuple && !PyList_Check(obj)) {
69 return false;
70 }
71 // NOLINTNEXTLINE(bugprone-branch-clone)
72 const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
73 if (size == 0) {
74 return true;
75 }
76 PyObject* first_elt =
77 tuple ? PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0);
78 return THPUtils_checkDimname(first_elt);
79 }
80
THPDimname_parse(PyObject * obj)81 at::Dimname THPDimname_parse(PyObject* obj) {
82 if (obj == Py_None) {
83 return at::Dimname::wildcard();
84 }
85
86 TORCH_CHECK_TYPE(
87 THPUtils_checkString(obj),
88 "expected None or string for Dimname but got ",
89 Py_TYPE(obj)->tp_name);
90
91 if (!THPUtils_isInterned(obj)) {
92 // internStringInPlace decrefs obj and increfs the result. Because we're
93 // not actually returning the result to the user, we need to undo these.
94 // See
95 // https://docs.python.org/3/c-api/unicode.html#c.PyUnicode_InternInPlace
96 Py_INCREF(obj);
97 THPUtils_internStringInPlace(&obj);
98 Py_DECREF(obj);
99 }
100
101 auto maybeDimname = torch::kPyInternedStringToDimname.lookup(obj);
102 if (maybeDimname) {
103 return *maybeDimname;
104 }
105
106 const auto name = THPUtils_unpackString(obj);
107 auto dimname = at::Dimname::fromSymbol(at::Symbol::dimname(name));
108 torch::kPyInternedStringToDimname.addMapping(obj, dimname);
109 return dimname;
110 }
111