#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/distributed/autograd/autograd.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/types.h>

namespace torch {
namespace distributed {
namespace autograd {

namespace {

template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;

PyObject* dist_autograd_init(PyObject* _unused, PyObject* noargs) {
  auto autograd_module =
      THPObjectPtr(PyImport_ImportModule("torch.distributed.autograd"));
  if (!autograd_module) {
    throw python_error();
  }

  auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
  if (!torch_C_module) {
    throw python_error();
  }

  auto torch_C_m = py::handle(torch_C_module).cast<py::module>();
  auto m = torch_C_m.def_submodule(
      "_distributed_autograd", "distributed autograd bindings");

  auto module = py::handle(m).cast<py::module>();

  auto distAutogradContext =
      shared_ptr_class_<DistAutogradContext>(module, "DistAutogradContext")
          .def(
              "_context_id",
              &DistAutogradContext::contextId,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_recv_functions",
              [](const DistAutogradContext& ctx) {
                std::map<int64_t, py::object> funcs;
                auto recvFunctions = ctx.recvFunctions();

                // Acquire GIL only when necessary to avoid deadlocks.
                pybind11::gil_scoped_acquire ag;
                for (const auto& map_entry : recvFunctions) {
                  funcs.emplace(
                      map_entry.first,
                      py::reinterpret_steal<py::object>(
                          torch::autograd::functionToPyObject(
                              map_entry.second)));
                }
                return funcs;
              },
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_send_functions",
              [](const ContextPtr& ctx) {
                std::map<int64_t, py::object> funcs;
                auto sendFunctions = ctx->sendFunctions();

                // Acquire GIL only when necessary to avoid deadlocks.
                pybind11::gil_scoped_acquire ag;
                for (const auto& map_entry : sendFunctions) {
                  funcs.emplace(
                      map_entry.first,
                      py::reinterpret_steal<py::object>(
                          torch::autograd::functionToPyObject(
                              map_entry.second)));
                }
                return funcs;
              },
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_known_worker_ids",
              &DistAutogradContext::getKnownWorkerIds,
              py::call_guard<py::gil_scoped_release>());
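
  // The methods above are private helpers (note the leading underscore),
  // used by the Python-side context manager and by tests to inspect a
  // context. A rough sketch of how they can be exercised from Python
  // (the ctx variable is illustrative; _new_context is bound further below):
  //
  //   ctx = torch._C._distributed_autograd._new_context()
  //   ctx._context_id()        # id to pass to backward()/get_gradients()
  //   ctx._send_functions()    # autograd message id -> send backward function
  //   ctx._recv_functions()    # autograd message id -> recv backward function
  //   ctx._known_worker_ids()  # ids of workers involved in this context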

  module.def(
      "_new_context",
      []() -> const ContextPtr {
        return DistAutogradContainer::getInstance().newContext();
      },
      py::return_value_policy::reference,
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_release_context",
      [](int64_t context_id) {
        return DistAutogradContainer::getInstance().releaseContext(context_id);
      },
      py::call_guard<py::gil_scoped_release>());
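
  // For reference, the public torch.distributed.autograd.context() context
  // manager is expected to pair the two bindings above roughly as in this
  // sketch (not the exact implementation; see torch/distributed/autograd):
  //
  //   class context:
  //       def __enter__(self):
  //           self.autograd_context = _new_context()
  //           return self.autograd_context._context_id()
  //
  //       def __exit__(self, exc_type, exc_value, traceback):
  //           _release_context(self.autograd_context._context_id())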

  module.def(
      "_get_max_id",
      []() { return DistAutogradContainer::getInstance().getMaxId(); },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_is_valid_context",
      [](int64_t context_id) {
        // isValidContext() raises an error for an unknown context id; the
        // binding itself returns None.
        DistAutogradContainer::getInstance().isValidContext(context_id);
      },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_retrieve_context",
      [](int64_t context_id) -> const ContextPtr {
        return DistAutogradContainer::getInstance().retrieveContext(context_id);
      },
      py::return_value_policy::reference,
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_current_context",
      []() -> const ContextPtr {
        return DistAutogradContainer::getInstance().currentContext();
      },
      py::return_value_policy::reference,
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_init",
      [](int64_t worker_id) { DistAutogradContainer::init(worker_id); },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_get_debug_info",
      []() { return DistEngine::getInstance().getDebugInfo(); },
      py::call_guard<py::gil_scoped_release>());

  py::options options;
  options.disable_function_signatures();

  module.def(
      "backward",
      backward,
      R"(
backward(context_id: int, roots: List[Tensor], retain_graph: bool = False) -> None

Kicks off the distributed backward pass using the provided roots. This
currently implements the :ref:`fast-mode-algorithm` which
assumes all RPC messages sent in the same distributed autograd context
across workers are part of the autograd graph during the backward pass.

We use the provided roots to discover the autograd graph and compute
appropriate dependencies. This method blocks until the entire
autograd computation is done.

We accumulate the gradients in the appropriate
:class:`torch.distributed.autograd.context` on each of the nodes. The autograd
context to be used is looked up given the ``context_id`` that is passed in when
:meth:`torch.distributed.autograd.backward` is called. If there is no valid
autograd context corresponding to the given ID, we throw an error. You can
retrieve the accumulated gradients using the
:meth:`~torch.distributed.autograd.get_gradients` API.

Arguments:
    context_id (int): The autograd context id for which the backward pass
                      should be run.
    roots (list): Tensors which represent the roots of the autograd
                  computation. All the tensors should be scalars.
    retain_graph (bool, optional): If False, the graph used to compute the grad
                  will be freed. Note that in nearly all cases setting this
                  option to True is not needed and often can be worked around
                  in a much more efficient way. Usually, you need to set this
                  to True to run backward multiple times.

Example::
    >>> import torch.distributed.autograd as dist_autograd
    >>> with dist_autograd.context() as context_id:
    >>>     pred = model.forward()
    >>>     loss = loss_func(pred, target)
    >>>     dist_autograd.backward(context_id, [loss])
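
    >>> # To run the backward pass more than once in the same context,
    >>> # retain the graph on the earlier calls. This is a sketch reusing
    >>> # the hypothetical model, loss_func and target from above:
    >>> with dist_autograd.context() as context_id:
    >>>     loss = loss_func(model.forward(), target)
    >>>     dist_autograd.backward(context_id, [loss], retain_graph=True)
    >>>     dist_autograd.backward(context_id, [loss])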
178 )",
179       py::arg("contextId"),
180       py::arg("roots"),
181       py::arg("retain_graph") = false,
182       py::call_guard<py::gil_scoped_release>());

  module.def(
      "get_gradients",
      [](int64_t contextId) -> py::dict {
        const auto& autogradContext =
            DistAutogradContainer::getInstance().retrieveContext(contextId);
        auto ival = IValue(autogradContext->getGradients());

        // Acquire GIL only for pyobject conversion.
        pybind11::gil_scoped_acquire ag;
        return torch::jit::toPyObject(ival);
      },
      R"(
get_gradients(context_id: int) -> Dict[Tensor, Tensor]

Retrieves a map from Tensor to the appropriate gradient for that Tensor
accumulated in the autograd context corresponding to the given ``context_id``
as part of the distributed autograd backward pass.

Arguments:
    context_id (int): The autograd context id for which we should retrieve the
                      gradients.

Returns:
    A map where the key is the Tensor and the value is the associated gradient
    for that Tensor.

Example::
    >>> import torch.distributed.autograd as dist_autograd
    >>> with dist_autograd.context() as context_id:
    >>>     t1 = torch.rand((3, 3), requires_grad=True)
    >>>     t2 = torch.rand((3, 3), requires_grad=True)
    >>>     loss = t1 + t2
    >>>     dist_autograd.backward(context_id, [loss.sum()])
    >>>     grads = dist_autograd.get_gradients(context_id)
    >>>     print(grads[t1])
    >>>     print(grads[t2])
)",
      py::arg("context_id"),
      py::call_guard<py::gil_scoped_release>());

  Py_RETURN_TRUE;
}
} // namespace

static PyMethodDef methods[] = { // NOLINT
    {"_dist_autograd_init", dist_autograd_init, METH_NOARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};

PyMethodDef* python_functions() {
  return methods;
}

} // namespace autograd
} // namespace distributed
} // namespace torch