xref: /aosp_15_r20/external/pytorch/test/cpp/rpc/e2e_test_base.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <torch/csrc/distributed/autograd/context/container.h>
4 #include <torch/csrc/distributed/autograd/context/context.h>
5 #include <torch/csrc/distributed/autograd/engine/dist_engine.h>
6 #include <torch/csrc/distributed/autograd/utils.h>
7 #include <torch/csrc/distributed/c10d/TCPStore.hpp>
8 #include <torch/csrc/distributed/rpc/rref_context.h>
9 #include <torch/csrc/distributed/rpc/script_call.h>
10 #include <torch/csrc/distributed/rpc/script_remote_call.h>
11 #include <torch/csrc/distributed/rpc/script_resp.h>
12 #include <torch/csrc/distributed/rpc/utils.h>
13 #include <torch/csrc/jit/runtime/operator.h>
14 
15 namespace torch {
16 namespace distributed {
17 namespace rpc {
18 
19 using torch::distributed::autograd::DistAutogradContainer;
20 using torch::distributed::autograd::DistAutogradContext;
21 
// Accessor for the DistAutogradContainer used by the fixture below
// (TestE2EBase::SetUp() stores the returned pointer). Only declared here; the
// definition is not part of this header.
22 DistAutogradContainer* getDistAutogradContainer();
23 
24 class TestE2EBase : public ::testing::Test {
25  protected:
SetUp()26   void SetUp() override {
27     // Setup distributed autograd.
28     autogradContainer = getDistAutogradContainer();
29 
30     // Setup server store.
31     c10d::TCPStoreOptions opts{
32         /* port */ 0,
33         /* isServer */ true,
34         numWorkers,
35         /* waitWorkers */ true,
36         /* timeout */ std::chrono::seconds(10)};
37 
38     store = c10::make_intrusive<c10d::TCPStore>(serverAddress, opts);
39 
40     buildRpcAgent();
41 
42     rpcAgentPostProcessing();
43   }
44 
rpcAgentPostProcessing()45   void rpcAgentPostProcessing() {
46     RpcAgent::setCurrentRpcAgent(rpcAgent);
47     std::shared_ptr<TypeResolver> typeResolver =
48         std::make_shared<TypeResolver>([&](const c10::QualifiedName& qn) {
49           // For Dict that is used for device map.
50           auto pos = qn.name().find("Dict");
51           if (pos != std::string::npos) {
52             return c10::StrongTypePtr(
53                 nullptr,
54                 c10::DictType::create(
55                     c10::StringType::get(), c10::StringType::get()));
56           }
57           return c10::StrongTypePtr(
58               nullptr, c10::TensorType::create(at::Tensor()));
59         });
60     rpcAgent->setTypeResolver(typeResolver);
61     rpcAgent->start();
62   }
63 
TearDown()64   void TearDown() override {
65     rpcAgent->join();
66     rpcAgent->shutdown();
67     RpcAgent::setCurrentRpcAgent(nullptr);
68   }
69 
createRemoteRRef(at::Tensor t1,at::Tensor t2,std::shared_ptr<torch::jit::Operator> op)70   c10::intrusive_ptr<OwnerRRef> createRemoteRRef(
71       at::Tensor t1,
72       at::Tensor t2,
73       std::shared_ptr<torch::jit::Operator> op) {
74     auto& ctx = RRefContext::getInstance();
75     auto ownerRRef = ctx.createOwnerRRef(c10::TensorType::create(t1));
76     // prevent this owner RRef being deleted due to other forks
77     ctx.addSelfAsFork(ownerRRef);
78 
79     ScriptRemoteCall scriptRemoteCall(
80         op, {t1, t2, 1}, ownerRRef->rrefId(), ownerRRef->rrefId());
81     auto jitFuture = autograd::sendMessageWithAutograd(
82         *rpcAgent,
83         rpcAgent->getWorkerInfo("worker"),
84         std::move(scriptRemoteCall).toMessage(),
85         false);
86 
87     ownerRRef->registerOwnerCreationFuture(jitFuture);
88 
89     // Builtin operators does not return py::object, and hence does not require
90     // GIL for destructing the potentially deleted OwerRRef.
91     jitFuture->addCallback(
92         [ownerRRefId = ownerRRef->rrefId()](JitFuture& jitFuture) {
93           callback::finishCreatingOwnerRRef(jitFuture, ownerRRefId);
94         });
95     return ownerRRef;
96   }
97 
remoteAdd(at::Tensor t1,at::Tensor t2,std::shared_ptr<torch::jit::Operator> op)98   at::Tensor remoteAdd(
99       at::Tensor t1,
100       at::Tensor t2,
101       std::shared_ptr<torch::jit::Operator> op) {
102     ScriptCall scriptCall(op, {t1, t2, /* alpha */ 1});
103 
104     // Send the RPC and return result.
105     auto response = autograd::sendMessageWithAutograd(
106         *rpcAgent,
107         rpcAgent->getWorkerInfo("worker"),
108         std::move(scriptCall).toMessage());
109     response->waitAndThrow();
110 
111     MessageType messageType = MessageType::FORWARD_AUTOGRAD_RESP;
112     auto wrappedResponse = deserializeResponse(
113         std::move(*response->value().toCustomClass<Message>()), messageType);
114     return static_cast<ScriptResp&>(*wrappedResponse).value().toTensor();
115   }
116 
117   virtual void buildRpcAgent() = 0;
118 
119   class AutogradContextGuard {
120    public:
AutogradContextGuard()121     explicit AutogradContextGuard()
122         : context(DistAutogradContainer::getInstance().newContext()) {}
123 
~AutogradContextGuard()124     ~AutogradContextGuard() {
125       DistAutogradContainer::getInstance().releaseContext(context->contextId());
126     }
127 
128    private:
129     std::shared_ptr<DistAutogradContext> context;
130   };
131 
runTrainingLoop()132   void runTrainingLoop() {
133     auto options = at::TensorOptions().requires_grad(true);
134     auto t1 = torch::ones({3, 3}, options);
135     auto t2 = torch::ones({3, 3}, options);
136 
137     c10::OperatorName full_name("aten::add", "Tensor");
138     auto matchedOp = torch::jit::findOperatorFor(full_name);
139     ASSERT_TRUE(matchedOp);
140 
141     for (size_t i = 0; i < numIters; i++) {
142       // Create the autograd context guard.
143       AutogradContextGuard guard;
144 
145       // Multiple RPCs within one autograd context for the forward pass.
146       auto result = remoteAdd(t1, t2, matchedOp);
147       for (size_t j = 0; j < 5; j++) {
148         result = remoteAdd(t1, result, matchedOp);
149       }
150 
151       auto rref = createRemoteRRef(t1, result, matchedOp);
152       result = rref->getValue().toTensor();
153 
154       // Run backward pass now.
155       autograd::DistEngine::getInstance().execute(
156           DistAutogradContainer::currentContextId(),
157           {torch::sum(result)},
158           /* retainGraph */ false);
159     }
160   }
161 
162   DistAutogradContainer* autogradContainer;
163   std::shared_ptr<RpcAgent> rpcAgent;
164   static const size_t numIters;
165   static const size_t numWorkers;
166   c10::intrusive_ptr<c10d::Store> store;
167   static const char* serverAddress;
168 };
169 
170 } // namespace rpc
171 } // namespace distributed
172 } // namespace torch
173