xref: /aosp_15_r20/external/pytorch/torch/csrc/cuda/shared/nvtx.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #ifdef _WIN32
2 #include <wchar.h> // _wgetenv for nvtx
3 #endif
4 #ifdef TORCH_CUDA_USE_NVTX3
5 #include <nvtx3/nvtx3.hpp>
6 #else
7 #include <nvToolsExt.h>
8 #endif
9 #include <torch/csrc/utils/pybind.h>
10 
11 namespace torch::cuda::shared {
12 
initNvtxBindings(PyObject * module)13 void initNvtxBindings(PyObject* module) {
14   auto m = py::handle(module).cast<py::module>();
15 
16 #ifdef TORCH_CUDA_USE_NVTX3
17   auto nvtx = m.def_submodule("_nvtx", "nvtx3 bindings");
18 #else
19   auto nvtx = m.def_submodule("_nvtx", "libNvToolsExt.so bindings");
20 #endif
21   nvtx.def("rangePushA", nvtxRangePushA);
22   nvtx.def("rangePop", nvtxRangePop);
23   nvtx.def("rangeStartA", nvtxRangeStartA);
24   nvtx.def("rangeEnd", nvtxRangeEnd);
25   nvtx.def("markA", nvtxMarkA);
26 }
27 
28 } // namespace torch::cuda::shared
29