xref: /aosp_15_r20/external/pytorch/test/cpp/dist_autograd/test_dist_autograd.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <memory>
2 
3 #include <gtest/gtest.h>
4 
5 #include <ATen/ATen.h>
6 #include <torch/csrc/distributed/autograd/context/container.h>
7 #include <torch/csrc/distributed/autograd/context/context.h>
8 #include <torch/csrc/distributed/autograd/engine/dist_engine.h>
9 #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
10 #include <torch/csrc/distributed/autograd/utils.h>
11 #include <torch/torch.h>
12 
13 namespace torch {
14 namespace distributed {
15 namespace autograd {
16 
17 class DistAutogradTest : public ::testing::Test {
18  protected:
SetUpTestCase()19   static void SetUpTestCase() {
20     autogradContainer_ = &DistAutogradContainer::init(0);
21   }
22 
TearDown()23   void TearDown() override {
24     autogradContainer_->releaseContext(
25         autogradContainer_->currentContext()->contextId());
26   }
27 
28   static DistAutogradContainer* autogradContainer_;
29 };
30 
31 DistAutogradContainer* DistAutogradTest::autogradContainer_ = nullptr;
32 
// Checks that a send autograd function attached to a context (a) is
// registered under the autograd message id it was created with, (b) rejects
// being applied with explicit input gradients, and (c) rejects being applied
// when one of its stored gradients is undefined. Also checks that a worker id
// added to the context is reported back by getKnownWorkerIds().
TEST_F(DistAutogradTest, TestSendFunctionInvalidInputs) {
  auto options = at::TensorOptions().requires_grad(true);
  auto in1 = torch::ones({3, 3}, options);
  auto in2 = torch::ones({3, 3}, options);

  autogradContainer_->newContext();
  auto autogradContext = autogradContainer_->currentContext();
  // Attach the send autograd function to tensors.
  std::vector<torch::Tensor> tensors = {in1, in2};
  rpc::worker_id_t worker_id = 1;
  addSendRpcBackward(autogradContext, AutogradMetadata(1, 1), tensors);
  autogradContext->addKnownWorkerId(worker_id);

  // Look the function up through the same accessor the other tests in this
  // file use, instead of indexing a temporary map. The explicit null check
  // turns a missing registration into a test failure rather than a null
  // dereference in the statements below.
  auto send_function = autogradContext->retrieveSendFunction(1);
  ASSERT_NE(nullptr, send_function);

  // Ensure that the worker_ids are recorded.
  auto knownWorkerIds = autogradContext->getKnownWorkerIds();
  ASSERT_TRUE(knownWorkerIds.find(worker_id) != knownWorkerIds.end());
  // Unsigned literal: size() is unsigned, so this avoids a signed/unsigned
  // comparison inside the assertion.
  ASSERT_EQ(knownWorkerIds.size(), 1u);

  // This should fail since the SendRpcBackward function shouldn't receive any
  // inputs grad.
  EXPECT_THROW(send_function->apply({in1, in2}), c10::Error);

  // This should fail since the SendRpcBackward function encounters an
  // undefined grad.
  send_function->setGrads({in1, torch::autograd::Variable()});
  EXPECT_THROW(send_function->apply({}), c10::Error);
}
61 
// Runs a backward pass through DistEngine on a small local graph and checks
// that the engine reports zero backward passes both before and after.
TEST_F(DistAutogradTest, TestInitializedContextCleanup) {
  autogradContainer_->newContext();
  const auto ctxId = autogradContainer_->currentContext()->contextId();
  auto& distEngine = DistEngine::getInstance();
  ASSERT_EQ(0, distEngine.numBackwardPasses());

  // Construct z = sum(x^2 + y^2) so that z carries a grad_fn.
  auto x = torch::randn({2, 2}, torch::requires_grad());
  auto y = torch::randn({2, 2}, torch::requires_grad());
  auto sumOfSquares = x * x + y * y;
  auto z = sumOfSquares.sum();
  ASSERT_NE(nullptr, z.grad_fn());

  // Run the backward pass rooted at z without retaining the graph.
  distEngine.execute(ctxId, {z}, /* retainGraph */ false);

  // No backward pass should still be registered with the engine.
  ASSERT_EQ(0, distEngine.numBackwardPasses());
}
80 
// Executes a send function asynchronously through DistEngine, waits for it to
// finish, and checks that the engine reports zero backward passes both before
// and after.
TEST_F(DistAutogradTest, TestInitializedContextCleanupSendFunction) {
  autogradContainer_->newContext();
  auto ctx = autogradContainer_->currentContext();
  auto& distEngine = DistEngine::getInstance();
  ASSERT_EQ(0, distEngine.numBackwardPasses());

  // Register a send function for a single one-element tensor under autograd
  // message id 0 of the current context.
  auto input = torch::ones({1}, at::TensorOptions().requires_grad(true));
  std::vector<torch::Tensor> sendInputs{input};
  addSendRpcBackward(ctx, AutogradMetadata(ctx->contextId(), 0), sendInputs);

  // Fetch the function that was just registered and give it a gradient.
  auto sendFn = ctx->retrieveSendFunction(0);
  sendFn->setGrads({input});

  // Kick off the asynchronous execution and block until it completes.
  auto completion =
      distEngine.executeSendFunctionAsync(ctx, sendFn, /*retainGraph*/ false);
  completion->wait();

  // No backward pass should still be registered with the engine.
  ASSERT_EQ(0, distEngine.numBackwardPasses());
}
105 
106 } // namespace autograd
107 } // namespace distributed
108 } // namespace torch
109