1 #ifdef _WIN32 2 #include <wchar.h> // _wgetenv for nvtx 3 #endif 4 #ifdef TORCH_CUDA_USE_NVTX3 5 #include <nvtx3/nvtx3.hpp> 6 #else 7 #include <nvToolsExt.h> 8 #endif 9 #include <torch/csrc/utils/pybind.h> 10 11 namespace torch::cuda::shared { 12 initNvtxBindings(PyObject * module)13void initNvtxBindings(PyObject* module) { 14 auto m = py::handle(module).cast<py::module>(); 15 16 #ifdef TORCH_CUDA_USE_NVTX3 17 auto nvtx = m.def_submodule("_nvtx", "nvtx3 bindings"); 18 #else 19 auto nvtx = m.def_submodule("_nvtx", "libNvToolsExt.so bindings"); 20 #endif 21 nvtx.def("rangePushA", nvtxRangePushA); 22 nvtx.def("rangePop", nvtxRangePop); 23 nvtx.def("rangeStartA", nvtxRangeStartA); 24 nvtx.def("rangeEnd", nvtxRangeEnd); 25 nvtx.def("markA", nvtxMarkA); 26 } 27 28 } // namespace torch::cuda::shared 29