#include <torch/csrc/python_headers.h>

#include <pybind11/chrono.h>

#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/pybind.h>

#include <ATen/cuda/CUDAGraph.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>

// Cargo culted partially from csrc/distributed/c10d/init.cpp
// and partially from csrc/cuda/Stream.cpp.
// THCPStream_init is also declared at global scope.

// Because THCPGraph_init is forward declared in its only consumer
// (csrc/Module.cpp), I don't think we need a Graph.h.

template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;

void THCPGraph_init(PyObject* module) {
  // Pybind11 patch notes say "py::module_" is more up-to-date syntax,
  // but CI linter and some builds prefer "module".
  auto torch_C_m = py::handle(module).cast<py::module>();

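  // graph_pool_handle() hands out a fresh MempoolId_t; Python callers can
  // pass the same id to capture_begin on several graphs so the captures
  // share one memory pool (see torch.cuda.graph_pool_handle).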
  torch_C_m.def("_graph_pool_handle", &::at::cuda::graph_pool_handle);

  shared_ptr_class_<::at::cuda::CUDAGraph>(torch_C_m, "_CUDAGraph")
      .def(py::init<>())
      .def(
          "capture_begin",
          [](::at::cuda::CUDAGraph& self,
             std::optional<c10::cuda::MempoolId_t> pool_opt,
             std::string capture_error_mode) {
            cudaStreamCaptureMode capture_mode;
            c10::cuda::MempoolId_t pool = pool_opt.has_value()
                ? pool_opt.value()
                : c10::cuda::MempoolId_t{0, 0};
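            // A {0, 0} MempoolId_t is the "no shared pool requested" default;
            // the strings below map onto CUDA's stream-capture modes
            // (global is the strictest, relaxed the most permissive).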
            if (capture_error_mode == "global") {
              capture_mode = cudaStreamCaptureModeGlobal;
            } else if (capture_error_mode == "thread_local") {
              capture_mode = cudaStreamCaptureModeThreadLocal;
            } else if (capture_error_mode == "relaxed") {
              capture_mode = cudaStreamCaptureModeRelaxed;
            } else {
              TORCH_CHECK(
                  false,
                  "Unknown capture error mode. Expected `global`, `thread_local`, or `relaxed`, got ",
                  capture_error_mode);
            }
            return self.capture_begin(pool, capture_mode);
          },
          py::arg("pool"),
          py::arg("capture_error_mode"),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "capture_end",
          torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::capture_end))
      .def(
          "register_generator_state",
          [](::at::cuda::CUDAGraph& self, py::handle raw_generator) {
            auto generator = THPGenerator_Unwrap(raw_generator.ptr());
            // We've unwrapped the Python object into a C++ object,
            // so we can release the GIL before calling into C++.
            py::gil_scoped_release release;
            return self.register_generator_state(generator);
          },
          py::arg("generator"))
      .def(
          "replay",
          torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::replay))
      .def(
          "reset",
          torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::reset))
      .def(
          "pool",
          torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::pool))
      .def(
          "debug_dump",
          torch::wrap_pybind_function_no_gil(
              &::at::cuda::CUDAGraph::debug_dump))
      .def(
          "enable_debug_mode",
          torch::wrap_pybind_function_no_gil(
              &::at::cuda::CUDAGraph::enable_debug_mode))
      .def(
          "debug_dump",
          torch::wrap_pybind_function_no_gil(
              &::at::cuda::CUDAGraph::debug_dump),
          py::arg("debug_path"));
}
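
// A minimal sketch of how Python-side code exercises these bindings. This is
// an illustrative comment only, not copied from torch/cuda/graphs.py; the
// torch.cuda.CUDAGraph / torch.cuda.graph wrappers layer stream handling and
// memory-pool bookkeeping on top of the raw calls shown here:
//
//   g = torch._C._CUDAGraph()
//   pool = torch._C._graph_pool_handle()  # or None to let capture pick one
//   g.capture_begin(pool, "global")       # pool, capture_error_mode
//   ...                                   # enqueue the CUDA work to capture
//   g.capture_end()
//   g.replay()                            # re-launch the captured work
//   g.reset()                             # free the captured graph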