// xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/autograd/functions/sendrpc_backward.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/autograd/functions/sendrpc_backward.h>
2 
3 namespace torch {
4 namespace distributed {
5 namespace autograd {
6 
apply(torch::autograd::variable_list && inputs)7 torch::autograd::variable_list SendRpcBackward::apply(
8     torch::autograd::variable_list&& inputs) {
9   TORCH_INTERNAL_ASSERT(
10       inputs.empty(), "SendRpcBackward should receive no inputs");
11 
12   // Each grad variable should be valid!
13   for (const auto& grad : grads_) {
14     TORCH_INTERNAL_ASSERT(
15         grad.defined(), "BUG!: SendRpcBackward didn't receive valid gradients");
16   }
17 
18   // Simply forwards the gradients over.
19   return std::move(grads_);
20 }
21 
setGrads(const torch::autograd::variable_list & grads)22 void SendRpcBackward::setGrads(const torch::autograd::variable_list& grads) {
23   grads_ = grads;
24 }
25 
getGrads() const26 const torch::autograd::variable_list& SendRpcBackward::getGrads() const {
27   return grads_;
28 }
29 
30 } // namespace autograd
31 } // namespace distributed
32 } // namespace torch
33