1 #include <pybind11/pybind11.h>
2 #include <torch/csrc/Device.h>
3 #include <torch/csrc/Event.h>
4 #include <torch/csrc/Stream.h>
5 #include <torch/csrc/THP.h>
6 #include <torch/csrc/utils/pybind.h>
7 #include <torch/csrc/utils/pycfunction_helpers.h>
8 #include <torch/csrc/utils/python_arg_parser.h>
9
10 #include <c10/core/Event.h>
11 #include <c10/core/Stream.h>
12
13 #include <c10/core/DeviceType.h>
14 #include <c10/core/impl/DeviceGuardImplInterface.h>
15 #include <structmember.h>
16 #include <string>
17
18 PyObject* THPEventClass = nullptr;
19
THPEvent_pynew(PyTypeObject * type,PyObject * args,PyObject * kwargs)20 static PyObject* THPEvent_pynew(
21 PyTypeObject* type,
22 PyObject* args,
23 PyObject* kwargs) {
24 HANDLE_TH_ERRORS
25
26 unsigned char enable_timing = 0;
27 unsigned char blocking = 0;
28 unsigned char interprocess = 0;
29
30 static torch::PythonArgParser parser({
31 "Event(Device device=None, *, bool enable_timing=True, bool blocking=False, bool interprocess=False)",
32 });
33
34 torch::ParsedArgs<4> parsed_args;
35 auto r = parser.parse(args, kwargs, parsed_args);
36
37 auto device = r.deviceOptional(0);
38
39 if (!device.has_value()) {
40 device = at::Device(at::getAccelerator(false).value_or(at::kCPU));
41 }
42 enable_timing = r.toBoolWithDefault(1, true);
43 blocking = r.toBoolWithDefault(2, false);
44 interprocess = r.toBoolWithDefault(3, false);
45
46 THPObjectPtr ptr(type->tp_alloc(type, 0));
47 if (!ptr) {
48 TORCH_CHECK(ptr, "Failed to allocate memory for Event");
49 }
50
51 THPEvent* self = (THPEvent*)ptr.get();
52
53 // TODO: blocking and interprocess are not supported yet. To support them, the
54 // flag system of c10::Event needs to be refactored. C10::Event should also
55 // provide a generic constructor to support blocking and interprocess events.
56 (void)blocking;
57 (void)interprocess;
58
59 new (&self->event) c10::Event(
60 device->type(),
61 // See note [Flags defining the behavior of events]
62 // BACKEND_DEFAULT is a enable-timing flag, and
63 // PYTORCH_DEFAULT is a disable-timing flag.
64 (enable_timing ? c10::EventFlag::BACKEND_DEFAULT
65 : c10::EventFlag::PYTORCH_DEFAULT));
66
67 return (PyObject*)ptr.release();
68 END_HANDLE_TH_ERRORS
69 }
70
THPEvent_new(c10::DeviceType device_type,c10::EventFlag flag)71 PyObject* THPEvent_new(c10::DeviceType device_type, c10::EventFlag flag) {
72 auto type = (PyTypeObject*)&THPEventType;
73 auto self = THPObjectPtr{type->tp_alloc(type, 0)};
74 TORCH_CHECK(self, "Failed to allocate memory for Event");
75 auto self_ = reinterpret_cast<THPEvent*>(self.get());
76 new (&self_->event) c10::Event(device_type, flag);
77 return self.release();
78 }
79
THPEvent_dealloc(THPEvent * self)80 static void THPEvent_dealloc(THPEvent* self) {
81 {
82 pybind11::gil_scoped_release no_gil{};
83 self->event.~Event();
84 }
85 Py_TYPE(self)->tp_free((PyObject*)self);
86 }
87
THPEvent_get_device(THPEvent * self,void * unused)88 static PyObject* THPEvent_get_device(THPEvent* self, void* unused) {
89 HANDLE_TH_ERRORS
90 return THPDevice_New(self->event.device());
91 END_HANDLE_TH_ERRORS
92 }
93
THPEvent_record(PyObject * _self,PyObject * args,PyObject * kwargs)94 static PyObject* THPEvent_record(
95 PyObject* _self,
96 PyObject* args,
97 PyObject* kwargs) {
98 HANDLE_TH_ERRORS
99 auto self = (THPEvent*)_self;
100 PyObject* _stream = Py_None;
101 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
102 constexpr const char* accepted_args[] = {"stream", nullptr};
103 if (!PyArg_ParseTupleAndKeywords(
104 args,
105 kwargs,
106 "|O",
107 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
108 const_cast<char**>(accepted_args),
109 &_stream)) {
110 TORCH_WARN("Parsing THPEvent_record arg fails");
111 return nullptr;
112 }
113 if (_stream != Py_None) {
114 auto stream = (THPStream*)_stream;
115 self->event.record(c10::Stream::unpack3(
116 stream->stream_id,
117 stream->device_index,
118 static_cast<c10::DeviceType>(stream->device_type)));
119 } else {
120 c10::impl::VirtualGuardImpl impl{
121 static_cast<c10::DeviceType>(self->event.device_type())};
122 self->event.record(impl.getStream(impl.getDevice()));
123 }
124 Py_RETURN_NONE;
125 END_HANDLE_TH_ERRORS
126 }
127
THPEvent_from_ipc_handle(PyObject * _type,PyObject * args,PyObject * kwargs)128 static PyObject* THPEvent_from_ipc_handle(
129 PyObject* _type,
130 PyObject* args,
131 PyObject* kwargs) {
132 HANDLE_TH_ERRORS
133 auto type = (PyTypeObject*)_type;
134
135 static torch::PythonArgParser parser({
136 "from_ipc_handle(Device device, std::string ipc_handle)",
137 });
138 torch::ParsedArgs<2> parsed_args;
139 auto r = parser.parse(args, kwargs, parsed_args);
140
141 at::Device device = r.device(0);
142 TORCH_CHECK_NOT_IMPLEMENTED(
143 false,
144 "torch.Event ipc is not supported yet, please open an issue if you need this!");
145 THPObjectPtr ptr(type->tp_alloc(type, 0));
146 if (!ptr) {
147 return nullptr;
148 }
149 THPEvent* self = (THPEvent*)ptr.get();
150
151 // TODO: for constructing event from ipc handle, the c10::Event needs to have
152 // more general constructor to achieve that.
153 new (&self->event) c10::Event(device.type(), c10::EventFlag::PYTORCH_DEFAULT);
154
155 return (PyObject*)ptr.release();
156 END_HANDLE_TH_ERRORS
157 }
158
THPEvent_ipc_handle(PyObject * _self,PyObject * noargs)159 static PyObject* THPEvent_ipc_handle(
160 PyObject* _self [[maybe_unused]],
161 PyObject* noargs) {
162 HANDLE_TH_ERRORS
163 TORCH_CHECK_NOT_IMPLEMENTED(
164 false,
165 "torch.Event ipc is not supported yet, please open an issue if you need this!");
166 constexpr const char* handle = "0";
167 return PyBytes_FromStringAndSize(
168 handle, std::char_traits<char>::length(handle));
169 END_HANDLE_TH_ERRORS
170 }
171
THPEvent_wait(PyObject * _self,PyObject * args,PyObject * kwargs)172 static PyObject* THPEvent_wait(
173 PyObject* _self,
174 PyObject* args,
175 PyObject* kwargs) {
176 HANDLE_TH_ERRORS {
177 auto self = (THPEvent*)_self;
178 PyObject* _stream = Py_None;
179 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
180 constexpr const char* accepted_args[] = {"stream", nullptr};
181 if (!PyArg_ParseTupleAndKeywords(
182 args,
183 kwargs,
184 "|O",
185 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
186 const_cast<char**>(accepted_args),
187 &_stream)) {
188 TORCH_WARN("Parsing THPEvent_wait arg fails");
189 return nullptr;
190 }
191 if (_stream != Py_None) {
192 auto stream = (THPStream*)_stream;
193 self->event.block(c10::Stream::unpack3(
194 stream->stream_id,
195 stream->device_index,
196 static_cast<c10::DeviceType>(stream->device_type)));
197 } else {
198 c10::impl::VirtualGuardImpl impl{
199 static_cast<c10::DeviceType>(self->event.device_type())};
200 self->event.block(impl.getStream(impl.getDevice()));
201 }
202 }
203 Py_RETURN_NONE;
204 END_HANDLE_TH_ERRORS
205 }
206
THPEvent_query(PyObject * _self,PyObject * noargs)207 static PyObject* THPEvent_query(PyObject* _self, PyObject* noargs) {
208 HANDLE_TH_ERRORS
209 auto self = (THPEvent*)_self;
210 return PyBool_FromLong(self->event.query());
211 END_HANDLE_TH_ERRORS
212 }
213
THPEvent_elapsed_time(PyObject * _self,PyObject * _other)214 static PyObject* THPEvent_elapsed_time(PyObject* _self, PyObject* _other) {
215 HANDLE_TH_ERRORS
216 auto self = (THPEvent*)_self;
217 auto other = (THPEvent*)_other;
218 return PyFloat_FromDouble(self->event.elapsedTime(other->event));
219 END_HANDLE_TH_ERRORS
220 }
221
THPEvent_synchronize(PyObject * _self,PyObject * noargs)222 static PyObject* THPEvent_synchronize(PyObject* _self, PyObject* noargs) {
223 HANDLE_TH_ERRORS {
224 pybind11::gil_scoped_release no_gil{};
225 auto self = (THPEvent*)_self;
226 self->event.synchronize();
227 }
228 Py_RETURN_NONE;
229 END_HANDLE_TH_ERRORS
230 }
231
THPEvent_evend_id(PyObject * _self,PyObject * noargs)232 static PyObject* THPEvent_evend_id(PyObject* _self, PyObject* noargs) {
233 HANDLE_TH_ERRORS
234 auto self = (THPEvent*)_self;
235 return PyLong_FromVoidPtr(self->event.eventId());
236 END_HANDLE_TH_ERRORS
237 }
238
THPEvent_repr(THPEvent * self)239 static PyObject* THPEvent_repr(THPEvent* self) {
240 HANDLE_TH_ERRORS
241 return THPUtils_packString(
242 "torch.Event device_type=" +
243 c10::DeviceTypeName(
244 static_cast<c10::DeviceType>(self->event.device_type()), true) +
245 ", device_index=" + std::to_string(self->event.device_index()) +
246 ", event_flag=" +
247 std::to_string(static_cast<int64_t>(self->event.flag())) + ", event_id=" +
248 std::to_string(reinterpret_cast<int64_t>(self->event.eventId())));
249 END_HANDLE_TH_ERRORS
250 }
251
252 // NOLINTNEXTLINE(*c-arrays*, *global-variables)
253 static struct PyGetSetDef THPEvent_properties[] = {
254 {"device", (getter)THPEvent_get_device, nullptr, nullptr, nullptr},
255 {"event_id", (getter)THPEvent_evend_id, nullptr, nullptr, nullptr},
256 {nullptr}};
257
258 // NOLINTNEXTLINE(*c-arrays*, *global-variables)
259 static PyMethodDef THPEvent_methods[] = {
260 {(char*)"from_ipc_handle",
261 castPyCFunctionWithKeywords(THPEvent_from_ipc_handle),
262 METH_CLASS | METH_VARARGS | METH_KEYWORDS,
263 nullptr},
264 {(char*)"record",
265 castPyCFunctionWithKeywords(THPEvent_record),
266 METH_VARARGS | METH_KEYWORDS,
267 nullptr},
268 {(char*)"wait",
269 castPyCFunctionWithKeywords(THPEvent_wait),
270 METH_VARARGS | METH_KEYWORDS,
271 nullptr},
272 {(char*)"query", THPEvent_query, METH_NOARGS, nullptr},
273 {(char*)"elapsed_time", THPEvent_elapsed_time, METH_O, nullptr},
274 {(char*)"synchronize", THPEvent_synchronize, METH_NOARGS, nullptr},
275 {(char*)"ipc_handle", THPEvent_ipc_handle, METH_NOARGS, nullptr},
276 {nullptr}};
277
278 PyTypeObject THPEventType = {
279 PyVarObject_HEAD_INIT(nullptr, 0) "torch.Event", /* tp_name */
280 sizeof(THPEvent), /* tp_basicsize */
281 0, /* tp_itemsize */
282 (destructor)THPEvent_dealloc, /* tp_dealloc */
283 0, /* tp_vectorcall_offset */
284 nullptr, /* tp_getattr */
285 nullptr, /* tp_setattr */
286 nullptr, /* tp_reserved */
287 (reprfunc)THPEvent_repr, /* tp_repr */
288 nullptr, /* tp_as_number */
289 nullptr, /* tp_as_sequence */
290 nullptr, /* tp_as_mapping */
291 nullptr, /* tp_hash */
292 nullptr, /* tp_call */
293 nullptr, /* tp_str */
294 nullptr, /* tp_getattro */
295 nullptr, /* tp_setattro */
296 nullptr, /* tp_as_buffer */
297 Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
298 nullptr, /* tp_doc */
299 nullptr, /* tp_traverse */
300 nullptr, /* tp_clear */
301 nullptr, /* tp_richcompare */
302 0, /* tp_weaklistoffset */
303 nullptr, /* tp_iter */
304 nullptr, /* tp_iternext */
305 THPEvent_methods, /* tp_methods */
306 nullptr, /* tp_members */
307 THPEvent_properties, /* tp_getset */
308 nullptr, /* tp_base */
309 nullptr, /* tp_dict */
310 nullptr, /* tp_descr_get */
311 nullptr, /* tp_descr_set */
312 0, /* tp_dictoffset */
313 nullptr, /* tp_init */
314 nullptr, /* tp_alloc */
315 THPEvent_pynew, /* tp_new */
316 };
317
THPEvent_init(PyObject * module)318 void THPEvent_init(PyObject* module) {
319 THPEventClass = (PyObject*)&THPEventType;
320 if (PyType_Ready(&THPEventType) < 0) {
321 throw python_error();
322 }
323 Py_INCREF(&THPEventType);
324 if (PyModule_AddObject(module, "Event", (PyObject*)&THPEventType) < 0) {
325 throw python_error();
326 }
327 }
328