xref: /aosp_15_r20/external/pytorch/torch/csrc/xpu/Event.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <pybind11/pybind11.h>
2 #include <torch/csrc/Device.h>
3 #include <torch/csrc/THP.h>
4 #include <torch/csrc/utils/pybind.h>
5 #include <torch/csrc/utils/pycfunction_helpers.h>
6 #include <torch/csrc/utils/python_arg_parser.h>
7 #include <torch/csrc/xpu/Event.h>
8 #include <torch/csrc/xpu/Module.h>
9 #include <torch/csrc/xpu/Stream.h>
10 
11 #include <structmember.h>
12 
13 PyObject* THXPEventClass = nullptr;
14 
THXPEvent_pynew(PyTypeObject * type,PyObject * args,PyObject * kwargs)15 static PyObject* THXPEvent_pynew(
16     PyTypeObject* type,
17     PyObject* args,
18     PyObject* kwargs) {
19   HANDLE_TH_ERRORS
20   unsigned char enable_timing = 0;
21 
22   // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
23   constexpr const char* kwlist[] = {"enable_timing", nullptr};
24   if (!PyArg_ParseTupleAndKeywords(
25           args,
26           kwargs,
27           "|b",
28           // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
29           const_cast<char**>(kwlist),
30           &enable_timing)) {
31     return nullptr;
32   }
33 
34   THPObjectPtr ptr(type->tp_alloc(type, 0));
35   if (!ptr) {
36     return nullptr;
37   }
38 
39   THXPEvent* self = (THXPEvent*)ptr.get();
40 
41   new (&self->xpu_event) at::xpu::XPUEvent(enable_timing);
42 
43   return (PyObject*)ptr.release();
44   END_HANDLE_TH_ERRORS
45 }
46 
THXPEvent_dealloc(THXPEvent * self)47 static void THXPEvent_dealloc(THXPEvent* self) {
48   {
49     pybind11::gil_scoped_release no_gil{};
50     self->xpu_event.~XPUEvent();
51   }
52   Py_TYPE(self)->tp_free((PyObject*)self);
53 }
54 
THXPEvent_get_sycl_event(THXPEvent * self,void * unused)55 static PyObject* THXPEvent_get_sycl_event(THXPEvent* self, void* unused) {
56   HANDLE_TH_ERRORS
57   return PyLong_FromVoidPtr(&self->xpu_event.event());
58   END_HANDLE_TH_ERRORS
59 }
60 
THXPEvent_get_device(THXPEvent * self,void * unused)61 static PyObject* THXPEvent_get_device(THXPEvent* self, void* unused) {
62   HANDLE_TH_ERRORS
63   std::optional<at::Device> device = self->xpu_event.device();
64   if (!device) {
65     Py_RETURN_NONE;
66   }
67   return THPDevice_New(device.value());
68   END_HANDLE_TH_ERRORS
69 }
70 
THXPEvent_record(PyObject * _self,PyObject * _stream)71 static PyObject* THXPEvent_record(PyObject* _self, PyObject* _stream) {
72   HANDLE_TH_ERRORS
73   auto* self = (THXPEvent*)_self;
74   auto* stream = (THXPStream*)_stream;
75   self->xpu_event.record(stream->xpu_stream);
76   Py_RETURN_NONE;
77   END_HANDLE_TH_ERRORS
78 }
79 
THXPEvent_wait(PyObject * _self,PyObject * _stream)80 static PyObject* THXPEvent_wait(PyObject* _self, PyObject* _stream) {
81   HANDLE_TH_ERRORS
82   auto* self = (THXPEvent*)_self;
83   auto* stream = (THXPStream*)_stream;
84   self->xpu_event.block(stream->xpu_stream);
85   Py_RETURN_NONE;
86   END_HANDLE_TH_ERRORS
87 }
88 
THXPEvent_query(PyObject * _self,PyObject * noargs)89 static PyObject* THXPEvent_query(PyObject* _self, PyObject* noargs) {
90   HANDLE_TH_ERRORS
91   auto* self = (THXPEvent*)_self;
92   return PyBool_FromLong(self->xpu_event.query());
93   END_HANDLE_TH_ERRORS
94 }
95 
THXPEvent_elapsed_time(PyObject * _self,PyObject * _other)96 static PyObject* THXPEvent_elapsed_time(PyObject* _self, PyObject* _other) {
97   HANDLE_TH_ERRORS
98   auto* self = (THXPEvent*)_self;
99   auto* other = (THXPEvent*)_other;
100   return PyFloat_FromDouble(self->xpu_event.elapsed_time(other->xpu_event));
101   END_HANDLE_TH_ERRORS
102 }
103 
THXPEvent_synchronize(PyObject * _self,PyObject * noargs)104 static PyObject* THXPEvent_synchronize(PyObject* _self, PyObject* noargs) {
105   HANDLE_TH_ERRORS {
106     pybind11::gil_scoped_release no_gil;
107     auto* self = (THXPEvent*)_self;
108     self->xpu_event.synchronize();
109   }
110   Py_RETURN_NONE;
111   END_HANDLE_TH_ERRORS
112 }
113 
114 // NOLINTNEXTLINE(*c-arrays*, *global-variables)
115 static struct PyGetSetDef THXPEvent_properties[] = {
116     {"device", (getter)THXPEvent_get_device, nullptr, nullptr, nullptr},
117     {"sycl_event", (getter)THXPEvent_get_sycl_event, nullptr, nullptr, nullptr},
118     {nullptr}};
119 
120 // NOLINTNEXTLINE(*c-arrays*, *global-variables)
121 static PyMethodDef THXPEvent_methods[] = {
122     {(char*)"record", THXPEvent_record, METH_O, nullptr},
123     {(char*)"wait", THXPEvent_wait, METH_O, nullptr},
124     {(char*)"query", THXPEvent_query, METH_NOARGS, nullptr},
125     {(char*)"elapsed_time", THXPEvent_elapsed_time, METH_O, nullptr},
126     {(char*)"synchronize", THXPEvent_synchronize, METH_NOARGS, nullptr},
127     {nullptr}};
128 
129 PyTypeObject THXPEventType = {
130     PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._XpuEventBase", /* tp_name */
131     sizeof(THXPEvent), /* tp_basicsize */
132     0, /* tp_itemsize */
133     (destructor)THXPEvent_dealloc, /* tp_dealloc */
134     0, /* tp_vectorcall_offset */
135     nullptr, /* tp_getattr */
136     nullptr, /* tp_setattr */
137     nullptr, /* tp_reserved */
138     nullptr, /* tp_repr */
139     nullptr, /* tp_as_number */
140     nullptr, /* tp_as_sequence */
141     nullptr, /* tp_as_mapping */
142     nullptr, /* tp_hash  */
143     nullptr, /* tp_call */
144     nullptr, /* tp_str */
145     nullptr, /* tp_getattro */
146     nullptr, /* tp_setattro */
147     nullptr, /* tp_as_buffer */
148     Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
149     nullptr, /* tp_doc */
150     nullptr, /* tp_traverse */
151     nullptr, /* tp_clear */
152     nullptr, /* tp_richcompare */
153     0, /* tp_weaklistoffset */
154     nullptr, /* tp_iter */
155     nullptr, /* tp_iternext */
156     THXPEvent_methods, /* tp_methods */
157     nullptr, /* tp_members */
158     THXPEvent_properties, /* tp_getset */
159     nullptr, /* tp_base */
160     nullptr, /* tp_dict */
161     nullptr, /* tp_descr_get */
162     nullptr, /* tp_descr_set */
163     0, /* tp_dictoffset */
164     nullptr, /* tp_init */
165     nullptr, /* tp_alloc */
166     THXPEvent_pynew, /* tp_new */
167 };
168 
THXPEvent_init(PyObject * module)169 void THXPEvent_init(PyObject* module) {
170   THXPEventClass = (PyObject*)&THXPEventType;
171   if (PyType_Ready(&THXPEventType) < 0) {
172     throw python_error();
173   }
174   Py_INCREF(&THXPEventType);
175   if (PyModule_AddObject(module, "_XpuEventBase", (PyObject*)&THXPEventType) <
176       0) {
177     throw python_error();
178   }
179 }
180