// torch/csrc/distributed/autograd/autograd.cpp
#include <ATen/record_function.h>
#include <torch/csrc/distributed/autograd/autograd.h>

namespace torch {
namespace distributed {
namespace autograd {

constexpr auto kDistAutogradBackwardProfilingKey =
    "torch::distributed::autograd::backward";

// Kicks off the distributed backward pass on this node for the autograd
// graph recorded under `context_id`, starting from the given `roots`.
void backward(
    int64_t context_id,
    const variable_list& roots,
    bool retain_graph) {
  C10_LOG_API_USAGE_ONCE("torch.distributed.autograd.backward");
  RECORD_FUNCTION(
      kDistAutogradBackwardProfilingKey, std::vector<c10::IValue>());
  try {
    DistEngine::getInstance().execute(context_id, roots, retain_graph);
  } catch (std::exception& e) {
    // FIXME: crashes if exception type is not RuntimeError
    TORCH_CHECK(false, e.what());
  }
}
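
// Illustrative usage sketch (not part of this file): callers are expected to
// hold a valid distributed autograd context and pass its id in. Assuming
// `loss` is a scalar torch::Tensor whose forward pass was recorded under
// that context:
//
//   auto context = DistAutogradContainer::getInstance().newContext();
//   torch::distributed::autograd::backward(
//       context->contextId(), /*roots=*/{loss}, /*retain_graph=*/false);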

} // namespace autograd
} // namespace distributed
} // namespace torch