xref: /aosp_15_r20/external/pytorch/test/cpp/rpc/test_e2e_tensorpipe.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include "e2e_test_base.h"
4 
5 #include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>
6 #include <torch/csrc/distributed/rpc/request_callback_no_python.h>
7 #include <torch/csrc/distributed/rpc/tensorpipe_agent.h>
8 #include <torch/torch.h>
9 
10 namespace torch {
11 namespace distributed {
12 namespace rpc {
13 
14 #ifdef USE_TENSORPIPE
15 
16 class TestE2ETensorPipe : public TestE2EBase {
17  protected:
buildRpcAgent()18   void buildRpcAgent() override {
19     auto options = c10d::ProcessGroupGloo::Options::create();
20     options->devices.push_back(
21         ::c10d::ProcessGroupGloo::createDeviceForHostname(serverAddress));
22     float rpcTimeout = 30;
23 
24     TensorPipeRpcBackendOptions opts(
25         /*numWorkerThreads=*/std::max(16U, std::thread::hardware_concurrency()),
26         /*transports=*/nullopt,
27         /*channels=*/nullopt,
28         /*rpc_timeout=*/rpcTimeout,
29         /*init_method=*/"unused");
30 
31     rpcAgent = std::make_shared<TensorPipeAgent>(
32         store,
33         "worker",
34         0,
35         numWorkers,
36         opts,
37         std::unordered_map<std::string, DeviceMap>{},
38         std::vector<c10::Device>{},
39         std::make_unique<RequestCallbackNoPython>());
40   }
41 };
42 
43 // End to end training loop test in C++ so that we can run LSAN on this test to
44 // catch memory leaks. Enabling LSAN with python multiprocessing has been
45 // challenging and we don't have a good solution yet.
TEST_F(TestE2ETensorPipe,TestTrainingLoop)46 TEST_F(TestE2ETensorPipe, TestTrainingLoop) {
47   runTrainingLoop();
48   // Ensure the tensorpipe internal state is cleared up.
49   auto tensorpipeAgent = std::static_pointer_cast<TensorPipeAgent>(rpcAgent);
50 
51   // Shutdown RPC agent for all RPCs to clean up.
52   tensorpipeAgent->join();
53   tensorpipeAgent->shutdown();
54   ASSERT_EQ(0, tensorpipeAgent->numPendingResponses());
55   ASSERT_EQ(0, tensorpipeAgent->timeoutMapSize());
56   ASSERT_EQ(0, tensorpipeAgent->messageIdToTimeoutMapSize());
57 }
58 
59 #endif
60 
61 } // namespace rpc
62 } // namespace distributed
63 } // namespace torch
64