1 #include <torch/csrc/python_headers.h>
2
3 #include <pybind11/chrono.h>
4
5 #include <torch/csrc/jit/python/pybind_utils.h>
6 #include <torch/csrc/utils/pybind.h>
7
8 #include <ATen/cuda/CUDAGraph.h>
9 #include <c10/cuda/CUDAGraphsC10Utils.h>
10
11 // Cargo culted partially from csrc/distributed/c10d/init.cpp
12 // and partially from csrc/cuda/Stream.cpp.
13 // THCPStream_init is also declared at global scope.
14
15 // Because THCPGraph_init is forward declared in the only consumer
16 // (csrc/Module.cpp) I don't think we need a Graph.h.
17
// Shorthand for a pybind11 class bound with a std::shared_ptr holder,
// i.e. Python-side instances share ownership with C++ (used for the
// ::at::cuda::CUDAGraph binding below).
template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
20
THCPGraph_init(PyObject * module)21 void THCPGraph_init(PyObject* module) {
22 // Pybind11 patch notes say "py::module_" is more up-to-date syntax,
23 // but CI linter and some builds prefer "module".
24 auto torch_C_m = py::handle(module).cast<py::module>();
25
26 torch_C_m.def("_graph_pool_handle", &::at::cuda::graph_pool_handle);
27
28 shared_ptr_class_<::at::cuda::CUDAGraph>(torch_C_m, "_CUDAGraph")
29 .def(py::init<>())
30 .def(
31 "capture_begin",
32 [](::at::cuda::CUDAGraph& self,
33 std::optional<c10::cuda::MempoolId_t> pool_opt,
34 std::string capture_error_mode) {
35 cudaStreamCaptureMode capture_mode;
36 c10::cuda::MempoolId_t pool = pool_opt.has_value()
37 ? pool_opt.value()
38 : c10::cuda::MempoolId_t{0, 0};
39 if (capture_error_mode == "global") {
40 capture_mode = cudaStreamCaptureModeGlobal;
41 } else if (capture_error_mode == "thread_local") {
42 capture_mode = cudaStreamCaptureModeThreadLocal;
43 } else if (capture_error_mode == "relaxed") {
44 capture_mode = cudaStreamCaptureModeRelaxed;
45 } else {
46 TORCH_CHECK(
47 false,
48 "Unknown capture error mode. Expected `global`, `thread_local`, or `relaxed`, got ",
49 capture_error_mode);
50 }
51 return self.capture_begin(pool, capture_mode);
52 },
53 py::arg("pool"),
54 py::arg("capture_error_mode"),
55 py::call_guard<py::gil_scoped_release>())
56 .def(
57 "capture_end",
58 torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::capture_end))
59 .def(
60 "register_generator_state",
61 [](::at::cuda::CUDAGraph& self, py::handle raw_generator) {
62 auto generator = THPGenerator_Unwrap(raw_generator.ptr());
63 // We've unwrapped Python object to C++ object,
64 // so we could release GIL before calling into C++
65 py::gil_scoped_release release;
66 return self.register_generator_state(generator);
67 },
68 py::arg("generator"))
69 .def(
70 "replay",
71 torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::replay))
72 .def(
73 "reset",
74 torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::reset))
75 .def(
76 "pool",
77 torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::pool))
78 .def(
79 "debug_dump",
80 torch::wrap_pybind_function_no_gil(
81 &::at::cuda::CUDAGraph::debug_dump))
82 .def(
83 "enable_debug_mode",
84 torch::wrap_pybind_function_no_gil(
85 &::at::cuda::CUDAGraph::enable_debug_mode))
86 .def(
87 "debug_dump",
88 torch::wrap_pybind_function_no_gil(
89 &::at::cuda::CUDAGraph::debug_dump),
90 py::arg("debug_path"));
91 }
92