1 #include <ATen/ATen.h>
2 #include <c10/util/CallOnce.h>
3 #include <torch/csrc/Generator.h>
4 #include <torch/csrc/THP.h>
5 #include <torch/csrc/python_headers.h>
6 #include <torch/csrc/utils/python_numbers.h>
7 #include <torch/csrc/utils/python_strings.h>
8
9 // pthread.h is included for tracking bad forks
10 #ifndef WIN32
11 #include <pthread.h>
12 #endif
13
14 namespace torch::mps {
15
16 namespace {
17 // True for children forked after mps init
18 static bool in_bad_fork = false;
19
20 // Called in the forked child if mps has already been initialized
forked_mps_child()21 static void forked_mps_child() {
22 in_bad_fork = true;
23 }
24
25 // Should be called before the first mps call.
track_bad_mps_fork()26 static void track_bad_mps_fork() {
27 #ifndef WIN32
28 static c10::once_flag flag;
29 c10::call_once(
30 flag, [] { pthread_atfork(nullptr, nullptr, forked_mps_child); });
31 #endif
32 }
33 } // namespace
34
MPSModule_isInBadFork(PyObject * self,PyObject * noargs)35 static PyObject* MPSModule_isInBadFork(PyObject* self, PyObject* noargs) {
36 HANDLE_TH_ERRORS
37 return PyBool_FromLong(in_bad_fork);
38 END_HANDLE_TH_ERRORS
39 }
40
MPSModule_getDefaultMPSGenerator(PyObject * _unused,PyObject * noargs)41 static PyObject* MPSModule_getDefaultMPSGenerator(
42 PyObject* _unused,
43 PyObject* noargs) {
44 HANDLE_TH_ERRORS
45 track_bad_mps_fork();
46 return THPGenerator_initDefaultGenerator(
47 at::detail::getMPSHooks().getDefaultMPSGenerator());
48 END_HANDLE_TH_ERRORS
49 }
50
MPSModule_isAvailable(PyObject * _unused,PyObject * noargs)51 static PyObject* MPSModule_isAvailable(PyObject* _unused, PyObject* noargs) {
52 HANDLE_TH_ERRORS
53 track_bad_mps_fork();
54 if (at::detail::getMPSHooks().hasMPS()) {
55 Py_RETURN_TRUE;
56 } else {
57 Py_RETURN_FALSE;
58 }
59 END_HANDLE_TH_ERRORS
60 }
61
MPSModule_isMacOSorNewer(PyObject * _unused,PyObject * args)62 static PyObject* MPSModule_isMacOSorNewer(PyObject* _unused, PyObject* args) {
63 HANDLE_TH_ERRORS
64 size_t major = 0;
65 size_t minor = 0;
66 if (!PyArg_ParseTuple(args, "LL", &major, &minor)) {
67 return nullptr;
68 }
69 if (at::detail::getMPSHooks().isOnMacOSorNewer(major, minor)) {
70 Py_RETURN_TRUE;
71 } else {
72 Py_RETURN_FALSE;
73 }
74 END_HANDLE_TH_ERRORS
75 }
76
MPSModule_deviceSynchronize(PyObject * _unused,PyObject * noargs)77 static PyObject* MPSModule_deviceSynchronize(
78 PyObject* _unused,
79 PyObject* noargs) {
80 HANDLE_TH_ERRORS
81 at::detail::getMPSHooks().deviceSynchronize();
82 Py_RETURN_NONE;
83 END_HANDLE_TH_ERRORS
84 }
85
MPSModule_emptyCache(PyObject * _unused,PyObject * noargs)86 static PyObject* MPSModule_emptyCache(PyObject* _unused, PyObject* noargs) {
87 HANDLE_TH_ERRORS
88 at::detail::getMPSHooks().emptyCache();
89 Py_RETURN_NONE;
90 END_HANDLE_TH_ERRORS
91 }
92
MPSModule_setMemoryFraction(PyObject * _unused,PyObject * args)93 static PyObject* MPSModule_setMemoryFraction(
94 PyObject* _unused,
95 PyObject* args) {
96 HANDLE_TH_ERRORS
97 TORCH_CHECK(
98 THPUtils_checkDouble(args), "invalid argument to setMemoryFraction()");
99 double fraction = THPUtils_unpackDouble(args);
100 at::detail::getMPSHooks().setMemoryFraction(fraction);
101 Py_RETURN_NONE;
102 END_HANDLE_TH_ERRORS
103 }
104
MPSModule_currentAllocatedMemory(PyObject * _unused,PyObject * noargs)105 static PyObject* MPSModule_currentAllocatedMemory(
106 PyObject* _unused,
107 PyObject* noargs) {
108 HANDLE_TH_ERRORS
109 return THPUtils_packUInt64(
110 at::detail::getMPSHooks().getCurrentAllocatedMemory());
111 END_HANDLE_TH_ERRORS
112 }
113
MPSModule_driverAllocatedMemory(PyObject * _unused,PyObject * noargs)114 static PyObject* MPSModule_driverAllocatedMemory(
115 PyObject* _unused,
116 PyObject* noargs) {
117 HANDLE_TH_ERRORS
118 return THPUtils_packUInt64(
119 at::detail::getMPSHooks().getDriverAllocatedMemory());
120 END_HANDLE_TH_ERRORS
121 }
122
MPSModule_recommendedMaxMemory(PyObject * _unused,PyObject * noargs)123 static PyObject* MPSModule_recommendedMaxMemory(
124 PyObject* _unused,
125 PyObject* noargs) {
126 HANDLE_TH_ERRORS
127 return THPUtils_packUInt64(
128 at::detail::getMPSHooks().getRecommendedMaxMemory());
129 END_HANDLE_TH_ERRORS
130 }
131
MPSModule_profilerStartTrace(PyObject * _unused,PyObject * args)132 static PyObject* MPSModule_profilerStartTrace(
133 PyObject* _unused,
134 PyObject* args) {
135 HANDLE_TH_ERRORS
136 PyObject* mode_string_o = nullptr;
137 PyObject* wait_until_completed_string_o = nullptr;
138 if (!PyArg_ParseTuple(
139 args, "OO", &mode_string_o, &wait_until_completed_string_o)) {
140 return nullptr;
141 }
142 const std::string mode = THPUtils_unpackString(mode_string_o);
143 const bool waitUntilCompleted =
144 THPUtils_unpackBool(wait_until_completed_string_o);
145 at::detail::getMPSHooks().profilerStartTrace(mode, waitUntilCompleted);
146 Py_RETURN_NONE;
147 END_HANDLE_TH_ERRORS
148 }
149
MPSModule_profilerStopTrace(PyObject * _unused,PyObject * noargs)150 static PyObject* MPSModule_profilerStopTrace(
151 PyObject* _unused,
152 PyObject* noargs) {
153 HANDLE_TH_ERRORS
154 at::detail::getMPSHooks().profilerStopTrace();
155 Py_RETURN_NONE;
156 END_HANDLE_TH_ERRORS
157 }
158
MPSModule_acquireEvent(PyObject * _unused,PyObject * args)159 static PyObject* MPSModule_acquireEvent(PyObject* _unused, PyObject* args) {
160 HANDLE_TH_ERRORS
161 const bool enable_timing = THPUtils_unpackBool(args);
162 return THPUtils_packUInt32(
163 at::detail::getMPSHooks().acquireEvent(enable_timing));
164 END_HANDLE_TH_ERRORS
165 }
166
MPSModule_releaseEvent(PyObject * _unused,PyObject * args)167 static PyObject* MPSModule_releaseEvent(PyObject* _unused, PyObject* args) {
168 HANDLE_TH_ERRORS
169 const uint32_t event_id = THPUtils_unpackUInt32(args);
170 at::detail::getMPSHooks().releaseEvent(event_id);
171 Py_RETURN_NONE;
172 END_HANDLE_TH_ERRORS
173 }
174
MPSModule_recordEvent(PyObject * _unused,PyObject * args)175 static PyObject* MPSModule_recordEvent(PyObject* _unused, PyObject* args) {
176 HANDLE_TH_ERRORS
177 const uint32_t event_id = THPUtils_unpackUInt32(args);
178 at::detail::getMPSHooks().recordEvent(event_id);
179 Py_RETURN_NONE;
180 END_HANDLE_TH_ERRORS
181 }
182
MPSModule_waitForEvent(PyObject * _unused,PyObject * args)183 static PyObject* MPSModule_waitForEvent(PyObject* _unused, PyObject* args) {
184 HANDLE_TH_ERRORS
185 const uint32_t event_id = THPUtils_unpackUInt32(args);
186 at::detail::getMPSHooks().waitForEvent(event_id);
187 Py_RETURN_NONE;
188 END_HANDLE_TH_ERRORS
189 }
190
MPSModule_synchronizeEvent(PyObject * _unused,PyObject * args)191 static PyObject* MPSModule_synchronizeEvent(PyObject* _unused, PyObject* args) {
192 HANDLE_TH_ERRORS
193 const uint32_t event_id = THPUtils_unpackUInt32(args);
194 at::detail::getMPSHooks().synchronizeEvent(event_id);
195 Py_RETURN_NONE;
196 END_HANDLE_TH_ERRORS
197 }
198
MPSModule_queryEvent(PyObject * _unused,PyObject * args)199 static PyObject* MPSModule_queryEvent(PyObject* _unused, PyObject* args) {
200 HANDLE_TH_ERRORS
201 const uint32_t event_id = THPUtils_unpackUInt32(args);
202
203 if (at::detail::getMPSHooks().queryEvent(event_id)) {
204 Py_RETURN_TRUE;
205 } else {
206 Py_RETURN_FALSE;
207 }
208 END_HANDLE_TH_ERRORS
209 }
210
MPSModule_elapsedTimeOfEvents(PyObject * _unused,PyObject * args)211 static PyObject* MPSModule_elapsedTimeOfEvents(
212 PyObject* _unused,
213 PyObject* args) {
214 HANDLE_TH_ERRORS
215 PyObject* start_event_o = nullptr;
216 PyObject* end_event_o = nullptr;
217 if (!PyArg_ParseTuple(args, "OO", &start_event_o, &end_event_o)) {
218 return nullptr;
219 }
220 const uint32_t start_event_id = THPUtils_unpackUInt32(start_event_o);
221 const uint32_t end_event_id = THPUtils_unpackUInt32(end_event_o);
222 return PyFloat_FromDouble(at::detail::getMPSHooks().elapsedTimeOfEvents(
223 start_event_id, end_event_id));
224 END_HANDLE_TH_ERRORS
225 }
226
227 // NOLINTNEXTLINE(*-c-arrays, *-global-variables)
228 static struct PyMethodDef _MPSModule_methods[] = {
229 {"_mps_deviceSynchronize",
230 MPSModule_deviceSynchronize,
231 METH_NOARGS,
232 nullptr},
233 {"_mps_is_in_bad_fork", MPSModule_isInBadFork, METH_NOARGS, nullptr},
234 {"_mps_is_available", MPSModule_isAvailable, METH_NOARGS, nullptr},
235 {"_mps_is_on_macos_or_newer",
236 MPSModule_isMacOSorNewer,
237 METH_VARARGS,
238 nullptr},
239 {"_mps_get_default_generator",
240 MPSModule_getDefaultMPSGenerator,
241 METH_NOARGS,
242 nullptr},
243 {"_mps_emptyCache", MPSModule_emptyCache, METH_NOARGS, nullptr},
244 {"_mps_setMemoryFraction", MPSModule_setMemoryFraction, METH_O, nullptr},
245 {"_mps_currentAllocatedMemory",
246 MPSModule_currentAllocatedMemory,
247 METH_NOARGS,
248 nullptr},
249 {"_mps_driverAllocatedMemory",
250 MPSModule_driverAllocatedMemory,
251 METH_NOARGS,
252 nullptr},
253 {"_mps_recommendedMaxMemory",
254 MPSModule_recommendedMaxMemory,
255 METH_NOARGS,
256 nullptr},
257 {"_mps_profilerStartTrace",
258 MPSModule_profilerStartTrace,
259 METH_VARARGS,
260 nullptr},
261 {"_mps_profilerStopTrace",
262 MPSModule_profilerStopTrace,
263 METH_NOARGS,
264 nullptr},
265 {"_mps_acquireEvent", MPSModule_acquireEvent, METH_O, nullptr},
266 {"_mps_releaseEvent", MPSModule_releaseEvent, METH_O, nullptr},
267 {"_mps_recordEvent", MPSModule_recordEvent, METH_O, nullptr},
268 {"_mps_waitForEvent", MPSModule_waitForEvent, METH_O, nullptr},
269 {"_mps_synchronizeEvent", MPSModule_synchronizeEvent, METH_O, nullptr},
270 {"_mps_queryEvent", MPSModule_queryEvent, METH_O, nullptr},
271 {"_mps_elapsedTimeOfEvents",
272 MPSModule_elapsedTimeOfEvents,
273 METH_VARARGS,
274 nullptr},
275 {nullptr}};
276
python_functions()277 PyMethodDef* python_functions() {
278 return _MPSModule_methods;
279 }
280
281 } // namespace torch::mps
282