xref: /aosp_15_r20/external/pytorch/torch/csrc/Event.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <pybind11/pybind11.h>
2 #include <torch/csrc/Device.h>
3 #include <torch/csrc/Event.h>
4 #include <torch/csrc/Stream.h>
5 #include <torch/csrc/THP.h>
6 #include <torch/csrc/utils/pybind.h>
7 #include <torch/csrc/utils/pycfunction_helpers.h>
8 #include <torch/csrc/utils/python_arg_parser.h>
9 
10 #include <c10/core/Event.h>
11 #include <c10/core/Stream.h>
12 
13 #include <c10/core/DeviceType.h>
14 #include <c10/core/impl/DeviceGuardImplInterface.h>
15 #include <structmember.h>
16 #include <string>
17 
18 PyObject* THPEventClass = nullptr;
19 
THPEvent_pynew(PyTypeObject * type,PyObject * args,PyObject * kwargs)20 static PyObject* THPEvent_pynew(
21     PyTypeObject* type,
22     PyObject* args,
23     PyObject* kwargs) {
24   HANDLE_TH_ERRORS
25 
26   unsigned char enable_timing = 0;
27   unsigned char blocking = 0;
28   unsigned char interprocess = 0;
29 
30   static torch::PythonArgParser parser({
31       "Event(Device device=None, *, bool enable_timing=True, bool blocking=False, bool interprocess=False)",
32   });
33 
34   torch::ParsedArgs<4> parsed_args;
35   auto r = parser.parse(args, kwargs, parsed_args);
36 
37   auto device = r.deviceOptional(0);
38 
39   if (!device.has_value()) {
40     device = at::Device(at::getAccelerator(false).value_or(at::kCPU));
41   }
42   enable_timing = r.toBoolWithDefault(1, true);
43   blocking = r.toBoolWithDefault(2, false);
44   interprocess = r.toBoolWithDefault(3, false);
45 
46   THPObjectPtr ptr(type->tp_alloc(type, 0));
47   if (!ptr) {
48     TORCH_CHECK(ptr, "Failed to allocate memory for Event");
49   }
50 
51   THPEvent* self = (THPEvent*)ptr.get();
52 
53   // TODO: blocking and interprocess are not supported yet. To support them, the
54   // flag system of c10::Event needs to be refactored. C10::Event should also
55   // provide a generic constructor to support blocking and interprocess events.
56   (void)blocking;
57   (void)interprocess;
58 
59   new (&self->event) c10::Event(
60       device->type(),
61       // See note [Flags defining the behavior of events]
62       // BACKEND_DEFAULT is a enable-timing flag, and
63       // PYTORCH_DEFAULT is a disable-timing flag.
64       (enable_timing ? c10::EventFlag::BACKEND_DEFAULT
65                      : c10::EventFlag::PYTORCH_DEFAULT));
66 
67   return (PyObject*)ptr.release();
68   END_HANDLE_TH_ERRORS
69 }
70 
THPEvent_new(c10::DeviceType device_type,c10::EventFlag flag)71 PyObject* THPEvent_new(c10::DeviceType device_type, c10::EventFlag flag) {
72   auto type = (PyTypeObject*)&THPEventType;
73   auto self = THPObjectPtr{type->tp_alloc(type, 0)};
74   TORCH_CHECK(self, "Failed to allocate memory for Event");
75   auto self_ = reinterpret_cast<THPEvent*>(self.get());
76   new (&self_->event) c10::Event(device_type, flag);
77   return self.release();
78 }
79 
THPEvent_dealloc(THPEvent * self)80 static void THPEvent_dealloc(THPEvent* self) {
81   {
82     pybind11::gil_scoped_release no_gil{};
83     self->event.~Event();
84   }
85   Py_TYPE(self)->tp_free((PyObject*)self);
86 }
87 
THPEvent_get_device(THPEvent * self,void * unused)88 static PyObject* THPEvent_get_device(THPEvent* self, void* unused) {
89   HANDLE_TH_ERRORS
90   return THPDevice_New(self->event.device());
91   END_HANDLE_TH_ERRORS
92 }
93 
THPEvent_record(PyObject * _self,PyObject * args,PyObject * kwargs)94 static PyObject* THPEvent_record(
95     PyObject* _self,
96     PyObject* args,
97     PyObject* kwargs) {
98   HANDLE_TH_ERRORS
99   auto self = (THPEvent*)_self;
100   PyObject* _stream = Py_None;
101   // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
102   constexpr const char* accepted_args[] = {"stream", nullptr};
103   if (!PyArg_ParseTupleAndKeywords(
104           args,
105           kwargs,
106           "|O",
107           // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
108           const_cast<char**>(accepted_args),
109           &_stream)) {
110     TORCH_WARN("Parsing THPEvent_record arg fails");
111     return nullptr;
112   }
113   if (_stream != Py_None) {
114     auto stream = (THPStream*)_stream;
115     self->event.record(c10::Stream::unpack3(
116         stream->stream_id,
117         stream->device_index,
118         static_cast<c10::DeviceType>(stream->device_type)));
119   } else {
120     c10::impl::VirtualGuardImpl impl{
121         static_cast<c10::DeviceType>(self->event.device_type())};
122     self->event.record(impl.getStream(impl.getDevice()));
123   }
124   Py_RETURN_NONE;
125   END_HANDLE_TH_ERRORS
126 }
127 
THPEvent_from_ipc_handle(PyObject * _type,PyObject * args,PyObject * kwargs)128 static PyObject* THPEvent_from_ipc_handle(
129     PyObject* _type,
130     PyObject* args,
131     PyObject* kwargs) {
132   HANDLE_TH_ERRORS
133   auto type = (PyTypeObject*)_type;
134 
135   static torch::PythonArgParser parser({
136       "from_ipc_handle(Device device, std::string ipc_handle)",
137   });
138   torch::ParsedArgs<2> parsed_args;
139   auto r = parser.parse(args, kwargs, parsed_args);
140 
141   at::Device device = r.device(0);
142   TORCH_CHECK_NOT_IMPLEMENTED(
143       false,
144       "torch.Event ipc is not supported yet, please open an issue if you need this!");
145   THPObjectPtr ptr(type->tp_alloc(type, 0));
146   if (!ptr) {
147     return nullptr;
148   }
149   THPEvent* self = (THPEvent*)ptr.get();
150 
151   // TODO: for constructing event from ipc handle, the c10::Event needs to have
152   // more general constructor to achieve that.
153   new (&self->event) c10::Event(device.type(), c10::EventFlag::PYTORCH_DEFAULT);
154 
155   return (PyObject*)ptr.release();
156   END_HANDLE_TH_ERRORS
157 }
158 
THPEvent_ipc_handle(PyObject * _self,PyObject * noargs)159 static PyObject* THPEvent_ipc_handle(
160     PyObject* _self [[maybe_unused]],
161     PyObject* noargs) {
162   HANDLE_TH_ERRORS
163   TORCH_CHECK_NOT_IMPLEMENTED(
164       false,
165       "torch.Event ipc is not supported yet, please open an issue if you need this!");
166   constexpr const char* handle = "0";
167   return PyBytes_FromStringAndSize(
168       handle, std::char_traits<char>::length(handle));
169   END_HANDLE_TH_ERRORS
170 }
171 
THPEvent_wait(PyObject * _self,PyObject * args,PyObject * kwargs)172 static PyObject* THPEvent_wait(
173     PyObject* _self,
174     PyObject* args,
175     PyObject* kwargs) {
176   HANDLE_TH_ERRORS {
177     auto self = (THPEvent*)_self;
178     PyObject* _stream = Py_None;
179     // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
180     constexpr const char* accepted_args[] = {"stream", nullptr};
181     if (!PyArg_ParseTupleAndKeywords(
182             args,
183             kwargs,
184             "|O",
185             // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
186             const_cast<char**>(accepted_args),
187             &_stream)) {
188       TORCH_WARN("Parsing THPEvent_wait arg fails");
189       return nullptr;
190     }
191     if (_stream != Py_None) {
192       auto stream = (THPStream*)_stream;
193       self->event.block(c10::Stream::unpack3(
194           stream->stream_id,
195           stream->device_index,
196           static_cast<c10::DeviceType>(stream->device_type)));
197     } else {
198       c10::impl::VirtualGuardImpl impl{
199           static_cast<c10::DeviceType>(self->event.device_type())};
200       self->event.block(impl.getStream(impl.getDevice()));
201     }
202   }
203   Py_RETURN_NONE;
204   END_HANDLE_TH_ERRORS
205 }
206 
THPEvent_query(PyObject * _self,PyObject * noargs)207 static PyObject* THPEvent_query(PyObject* _self, PyObject* noargs) {
208   HANDLE_TH_ERRORS
209   auto self = (THPEvent*)_self;
210   return PyBool_FromLong(self->event.query());
211   END_HANDLE_TH_ERRORS
212 }
213 
THPEvent_elapsed_time(PyObject * _self,PyObject * _other)214 static PyObject* THPEvent_elapsed_time(PyObject* _self, PyObject* _other) {
215   HANDLE_TH_ERRORS
216   auto self = (THPEvent*)_self;
217   auto other = (THPEvent*)_other;
218   return PyFloat_FromDouble(self->event.elapsedTime(other->event));
219   END_HANDLE_TH_ERRORS
220 }
221 
THPEvent_synchronize(PyObject * _self,PyObject * noargs)222 static PyObject* THPEvent_synchronize(PyObject* _self, PyObject* noargs) {
223   HANDLE_TH_ERRORS {
224     pybind11::gil_scoped_release no_gil{};
225     auto self = (THPEvent*)_self;
226     self->event.synchronize();
227   }
228   Py_RETURN_NONE;
229   END_HANDLE_TH_ERRORS
230 }
231 
THPEvent_evend_id(PyObject * _self,PyObject * noargs)232 static PyObject* THPEvent_evend_id(PyObject* _self, PyObject* noargs) {
233   HANDLE_TH_ERRORS
234   auto self = (THPEvent*)_self;
235   return PyLong_FromVoidPtr(self->event.eventId());
236   END_HANDLE_TH_ERRORS
237 }
238 
THPEvent_repr(THPEvent * self)239 static PyObject* THPEvent_repr(THPEvent* self) {
240   HANDLE_TH_ERRORS
241   return THPUtils_packString(
242       "torch.Event device_type=" +
243       c10::DeviceTypeName(
244           static_cast<c10::DeviceType>(self->event.device_type()), true) +
245       ", device_index=" + std::to_string(self->event.device_index()) +
246       ", event_flag=" +
247       std::to_string(static_cast<int64_t>(self->event.flag())) + ", event_id=" +
248       std::to_string(reinterpret_cast<int64_t>(self->event.eventId())));
249   END_HANDLE_TH_ERRORS
250 }
251 
252 // NOLINTNEXTLINE(*c-arrays*, *global-variables)
253 static struct PyGetSetDef THPEvent_properties[] = {
254     {"device", (getter)THPEvent_get_device, nullptr, nullptr, nullptr},
255     {"event_id", (getter)THPEvent_evend_id, nullptr, nullptr, nullptr},
256     {nullptr}};
257 
258 // NOLINTNEXTLINE(*c-arrays*, *global-variables)
259 static PyMethodDef THPEvent_methods[] = {
260     {(char*)"from_ipc_handle",
261      castPyCFunctionWithKeywords(THPEvent_from_ipc_handle),
262      METH_CLASS | METH_VARARGS | METH_KEYWORDS,
263      nullptr},
264     {(char*)"record",
265      castPyCFunctionWithKeywords(THPEvent_record),
266      METH_VARARGS | METH_KEYWORDS,
267      nullptr},
268     {(char*)"wait",
269      castPyCFunctionWithKeywords(THPEvent_wait),
270      METH_VARARGS | METH_KEYWORDS,
271      nullptr},
272     {(char*)"query", THPEvent_query, METH_NOARGS, nullptr},
273     {(char*)"elapsed_time", THPEvent_elapsed_time, METH_O, nullptr},
274     {(char*)"synchronize", THPEvent_synchronize, METH_NOARGS, nullptr},
275     {(char*)"ipc_handle", THPEvent_ipc_handle, METH_NOARGS, nullptr},
276     {nullptr}};
277 
278 PyTypeObject THPEventType = {
279     PyVarObject_HEAD_INIT(nullptr, 0) "torch.Event", /* tp_name */
280     sizeof(THPEvent), /* tp_basicsize */
281     0, /* tp_itemsize */
282     (destructor)THPEvent_dealloc, /* tp_dealloc */
283     0, /* tp_vectorcall_offset */
284     nullptr, /* tp_getattr */
285     nullptr, /* tp_setattr */
286     nullptr, /* tp_reserved */
287     (reprfunc)THPEvent_repr, /* tp_repr */
288     nullptr, /* tp_as_number */
289     nullptr, /* tp_as_sequence */
290     nullptr, /* tp_as_mapping */
291     nullptr, /* tp_hash  */
292     nullptr, /* tp_call */
293     nullptr, /* tp_str */
294     nullptr, /* tp_getattro */
295     nullptr, /* tp_setattro */
296     nullptr, /* tp_as_buffer */
297     Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
298     nullptr, /* tp_doc */
299     nullptr, /* tp_traverse */
300     nullptr, /* tp_clear */
301     nullptr, /* tp_richcompare */
302     0, /* tp_weaklistoffset */
303     nullptr, /* tp_iter */
304     nullptr, /* tp_iternext */
305     THPEvent_methods, /* tp_methods */
306     nullptr, /* tp_members */
307     THPEvent_properties, /* tp_getset */
308     nullptr, /* tp_base */
309     nullptr, /* tp_dict */
310     nullptr, /* tp_descr_get */
311     nullptr, /* tp_descr_set */
312     0, /* tp_dictoffset */
313     nullptr, /* tp_init */
314     nullptr, /* tp_alloc */
315     THPEvent_pynew, /* tp_new */
316 };
317 
THPEvent_init(PyObject * module)318 void THPEvent_init(PyObject* module) {
319   THPEventClass = (PyObject*)&THPEventType;
320   if (PyType_Ready(&THPEventType) < 0) {
321     throw python_error();
322   }
323   Py_INCREF(&THPEventType);
324   if (PyModule_AddObject(module, "Event", (PyObject*)&THPEventType) < 0) {
325     throw python_error();
326   }
327 }
328