xref: /aosp_15_r20/external/pytorch/torch/csrc/mps/Module.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/ATen.h>
2 #include <c10/util/CallOnce.h>
3 #include <torch/csrc/Generator.h>
4 #include <torch/csrc/THP.h>
5 #include <torch/csrc/python_headers.h>
6 #include <torch/csrc/utils/python_numbers.h>
7 #include <torch/csrc/utils/python_strings.h>
8 
9 // pthread.h is included for tracking bad forks
10 #ifndef WIN32
11 #include <pthread.h>
12 #endif
13 
14 namespace torch::mps {
15 
16 namespace {
17 // True for children forked after mps init
18 static bool in_bad_fork = false;
19 
20 // Called in the forked child if mps has already been initialized
forked_mps_child()21 static void forked_mps_child() {
22   in_bad_fork = true;
23 }
24 
25 // Should be called before the first mps call.
track_bad_mps_fork()26 static void track_bad_mps_fork() {
27 #ifndef WIN32
28   static c10::once_flag flag;
29   c10::call_once(
30       flag, [] { pthread_atfork(nullptr, nullptr, forked_mps_child); });
31 #endif
32 }
33 } // namespace
34 
MPSModule_isInBadFork(PyObject * self,PyObject * noargs)35 static PyObject* MPSModule_isInBadFork(PyObject* self, PyObject* noargs) {
36   HANDLE_TH_ERRORS
37   return PyBool_FromLong(in_bad_fork);
38   END_HANDLE_TH_ERRORS
39 }
40 
MPSModule_getDefaultMPSGenerator(PyObject * _unused,PyObject * noargs)41 static PyObject* MPSModule_getDefaultMPSGenerator(
42     PyObject* _unused,
43     PyObject* noargs) {
44   HANDLE_TH_ERRORS
45   track_bad_mps_fork();
46   return THPGenerator_initDefaultGenerator(
47       at::detail::getMPSHooks().getDefaultMPSGenerator());
48   END_HANDLE_TH_ERRORS
49 }
50 
MPSModule_isAvailable(PyObject * _unused,PyObject * noargs)51 static PyObject* MPSModule_isAvailable(PyObject* _unused, PyObject* noargs) {
52   HANDLE_TH_ERRORS
53   track_bad_mps_fork();
54   if (at::detail::getMPSHooks().hasMPS()) {
55     Py_RETURN_TRUE;
56   } else {
57     Py_RETURN_FALSE;
58   }
59   END_HANDLE_TH_ERRORS
60 }
61 
MPSModule_isMacOSorNewer(PyObject * _unused,PyObject * args)62 static PyObject* MPSModule_isMacOSorNewer(PyObject* _unused, PyObject* args) {
63   HANDLE_TH_ERRORS
64   size_t major = 0;
65   size_t minor = 0;
66   if (!PyArg_ParseTuple(args, "LL", &major, &minor)) {
67     return nullptr;
68   }
69   if (at::detail::getMPSHooks().isOnMacOSorNewer(major, minor)) {
70     Py_RETURN_TRUE;
71   } else {
72     Py_RETURN_FALSE;
73   }
74   END_HANDLE_TH_ERRORS
75 }
76 
MPSModule_deviceSynchronize(PyObject * _unused,PyObject * noargs)77 static PyObject* MPSModule_deviceSynchronize(
78     PyObject* _unused,
79     PyObject* noargs) {
80   HANDLE_TH_ERRORS
81   at::detail::getMPSHooks().deviceSynchronize();
82   Py_RETURN_NONE;
83   END_HANDLE_TH_ERRORS
84 }
85 
MPSModule_emptyCache(PyObject * _unused,PyObject * noargs)86 static PyObject* MPSModule_emptyCache(PyObject* _unused, PyObject* noargs) {
87   HANDLE_TH_ERRORS
88   at::detail::getMPSHooks().emptyCache();
89   Py_RETURN_NONE;
90   END_HANDLE_TH_ERRORS
91 }
92 
MPSModule_setMemoryFraction(PyObject * _unused,PyObject * args)93 static PyObject* MPSModule_setMemoryFraction(
94     PyObject* _unused,
95     PyObject* args) {
96   HANDLE_TH_ERRORS
97   TORCH_CHECK(
98       THPUtils_checkDouble(args), "invalid argument to setMemoryFraction()");
99   double fraction = THPUtils_unpackDouble(args);
100   at::detail::getMPSHooks().setMemoryFraction(fraction);
101   Py_RETURN_NONE;
102   END_HANDLE_TH_ERRORS
103 }
104 
MPSModule_currentAllocatedMemory(PyObject * _unused,PyObject * noargs)105 static PyObject* MPSModule_currentAllocatedMemory(
106     PyObject* _unused,
107     PyObject* noargs) {
108   HANDLE_TH_ERRORS
109   return THPUtils_packUInt64(
110       at::detail::getMPSHooks().getCurrentAllocatedMemory());
111   END_HANDLE_TH_ERRORS
112 }
113 
MPSModule_driverAllocatedMemory(PyObject * _unused,PyObject * noargs)114 static PyObject* MPSModule_driverAllocatedMemory(
115     PyObject* _unused,
116     PyObject* noargs) {
117   HANDLE_TH_ERRORS
118   return THPUtils_packUInt64(
119       at::detail::getMPSHooks().getDriverAllocatedMemory());
120   END_HANDLE_TH_ERRORS
121 }
122 
MPSModule_recommendedMaxMemory(PyObject * _unused,PyObject * noargs)123 static PyObject* MPSModule_recommendedMaxMemory(
124     PyObject* _unused,
125     PyObject* noargs) {
126   HANDLE_TH_ERRORS
127   return THPUtils_packUInt64(
128       at::detail::getMPSHooks().getRecommendedMaxMemory());
129   END_HANDLE_TH_ERRORS
130 }
131 
MPSModule_profilerStartTrace(PyObject * _unused,PyObject * args)132 static PyObject* MPSModule_profilerStartTrace(
133     PyObject* _unused,
134     PyObject* args) {
135   HANDLE_TH_ERRORS
136   PyObject* mode_string_o = nullptr;
137   PyObject* wait_until_completed_string_o = nullptr;
138   if (!PyArg_ParseTuple(
139           args, "OO", &mode_string_o, &wait_until_completed_string_o)) {
140     return nullptr;
141   }
142   const std::string mode = THPUtils_unpackString(mode_string_o);
143   const bool waitUntilCompleted =
144       THPUtils_unpackBool(wait_until_completed_string_o);
145   at::detail::getMPSHooks().profilerStartTrace(mode, waitUntilCompleted);
146   Py_RETURN_NONE;
147   END_HANDLE_TH_ERRORS
148 }
149 
MPSModule_profilerStopTrace(PyObject * _unused,PyObject * noargs)150 static PyObject* MPSModule_profilerStopTrace(
151     PyObject* _unused,
152     PyObject* noargs) {
153   HANDLE_TH_ERRORS
154   at::detail::getMPSHooks().profilerStopTrace();
155   Py_RETURN_NONE;
156   END_HANDLE_TH_ERRORS
157 }
158 
MPSModule_acquireEvent(PyObject * _unused,PyObject * args)159 static PyObject* MPSModule_acquireEvent(PyObject* _unused, PyObject* args) {
160   HANDLE_TH_ERRORS
161   const bool enable_timing = THPUtils_unpackBool(args);
162   return THPUtils_packUInt32(
163       at::detail::getMPSHooks().acquireEvent(enable_timing));
164   END_HANDLE_TH_ERRORS
165 }
166 
MPSModule_releaseEvent(PyObject * _unused,PyObject * args)167 static PyObject* MPSModule_releaseEvent(PyObject* _unused, PyObject* args) {
168   HANDLE_TH_ERRORS
169   const uint32_t event_id = THPUtils_unpackUInt32(args);
170   at::detail::getMPSHooks().releaseEvent(event_id);
171   Py_RETURN_NONE;
172   END_HANDLE_TH_ERRORS
173 }
174 
MPSModule_recordEvent(PyObject * _unused,PyObject * args)175 static PyObject* MPSModule_recordEvent(PyObject* _unused, PyObject* args) {
176   HANDLE_TH_ERRORS
177   const uint32_t event_id = THPUtils_unpackUInt32(args);
178   at::detail::getMPSHooks().recordEvent(event_id);
179   Py_RETURN_NONE;
180   END_HANDLE_TH_ERRORS
181 }
182 
MPSModule_waitForEvent(PyObject * _unused,PyObject * args)183 static PyObject* MPSModule_waitForEvent(PyObject* _unused, PyObject* args) {
184   HANDLE_TH_ERRORS
185   const uint32_t event_id = THPUtils_unpackUInt32(args);
186   at::detail::getMPSHooks().waitForEvent(event_id);
187   Py_RETURN_NONE;
188   END_HANDLE_TH_ERRORS
189 }
190 
MPSModule_synchronizeEvent(PyObject * _unused,PyObject * args)191 static PyObject* MPSModule_synchronizeEvent(PyObject* _unused, PyObject* args) {
192   HANDLE_TH_ERRORS
193   const uint32_t event_id = THPUtils_unpackUInt32(args);
194   at::detail::getMPSHooks().synchronizeEvent(event_id);
195   Py_RETURN_NONE;
196   END_HANDLE_TH_ERRORS
197 }
198 
MPSModule_queryEvent(PyObject * _unused,PyObject * args)199 static PyObject* MPSModule_queryEvent(PyObject* _unused, PyObject* args) {
200   HANDLE_TH_ERRORS
201   const uint32_t event_id = THPUtils_unpackUInt32(args);
202 
203   if (at::detail::getMPSHooks().queryEvent(event_id)) {
204     Py_RETURN_TRUE;
205   } else {
206     Py_RETURN_FALSE;
207   }
208   END_HANDLE_TH_ERRORS
209 }
210 
MPSModule_elapsedTimeOfEvents(PyObject * _unused,PyObject * args)211 static PyObject* MPSModule_elapsedTimeOfEvents(
212     PyObject* _unused,
213     PyObject* args) {
214   HANDLE_TH_ERRORS
215   PyObject* start_event_o = nullptr;
216   PyObject* end_event_o = nullptr;
217   if (!PyArg_ParseTuple(args, "OO", &start_event_o, &end_event_o)) {
218     return nullptr;
219   }
220   const uint32_t start_event_id = THPUtils_unpackUInt32(start_event_o);
221   const uint32_t end_event_id = THPUtils_unpackUInt32(end_event_o);
222   return PyFloat_FromDouble(at::detail::getMPSHooks().elapsedTimeOfEvents(
223       start_event_id, end_event_id));
224   END_HANDLE_TH_ERRORS
225 }
226 
227 // NOLINTNEXTLINE(*-c-arrays, *-global-variables)
228 static struct PyMethodDef _MPSModule_methods[] = {
229     {"_mps_deviceSynchronize",
230      MPSModule_deviceSynchronize,
231      METH_NOARGS,
232      nullptr},
233     {"_mps_is_in_bad_fork", MPSModule_isInBadFork, METH_NOARGS, nullptr},
234     {"_mps_is_available", MPSModule_isAvailable, METH_NOARGS, nullptr},
235     {"_mps_is_on_macos_or_newer",
236      MPSModule_isMacOSorNewer,
237      METH_VARARGS,
238      nullptr},
239     {"_mps_get_default_generator",
240      MPSModule_getDefaultMPSGenerator,
241      METH_NOARGS,
242      nullptr},
243     {"_mps_emptyCache", MPSModule_emptyCache, METH_NOARGS, nullptr},
244     {"_mps_setMemoryFraction", MPSModule_setMemoryFraction, METH_O, nullptr},
245     {"_mps_currentAllocatedMemory",
246      MPSModule_currentAllocatedMemory,
247      METH_NOARGS,
248      nullptr},
249     {"_mps_driverAllocatedMemory",
250      MPSModule_driverAllocatedMemory,
251      METH_NOARGS,
252      nullptr},
253     {"_mps_recommendedMaxMemory",
254      MPSModule_recommendedMaxMemory,
255      METH_NOARGS,
256      nullptr},
257     {"_mps_profilerStartTrace",
258      MPSModule_profilerStartTrace,
259      METH_VARARGS,
260      nullptr},
261     {"_mps_profilerStopTrace",
262      MPSModule_profilerStopTrace,
263      METH_NOARGS,
264      nullptr},
265     {"_mps_acquireEvent", MPSModule_acquireEvent, METH_O, nullptr},
266     {"_mps_releaseEvent", MPSModule_releaseEvent, METH_O, nullptr},
267     {"_mps_recordEvent", MPSModule_recordEvent, METH_O, nullptr},
268     {"_mps_waitForEvent", MPSModule_waitForEvent, METH_O, nullptr},
269     {"_mps_synchronizeEvent", MPSModule_synchronizeEvent, METH_O, nullptr},
270     {"_mps_queryEvent", MPSModule_queryEvent, METH_O, nullptr},
271     {"_mps_elapsedTimeOfEvents",
272      MPSModule_elapsedTimeOfEvents,
273      METH_VARARGS,
274      nullptr},
275     {nullptr}};
276 
python_functions()277 PyMethodDef* python_functions() {
278   return _MPSModule_methods;
279 }
280 
281 } // namespace torch::mps
282