1 #include <gtest/gtest.h> 2 3 #include <torch/csrc/distributed/autograd/context/container.h> 4 #include <torch/csrc/distributed/autograd/context/context.h> 5 #include <torch/csrc/distributed/autograd/engine/dist_engine.h> 6 #include <torch/csrc/distributed/autograd/utils.h> 7 #include <torch/csrc/distributed/c10d/TCPStore.hpp> 8 #include <torch/csrc/distributed/rpc/rref_context.h> 9 #include <torch/csrc/distributed/rpc/script_call.h> 10 #include <torch/csrc/distributed/rpc/script_remote_call.h> 11 #include <torch/csrc/distributed/rpc/script_resp.h> 12 #include <torch/csrc/distributed/rpc/utils.h> 13 #include <torch/csrc/jit/runtime/operator.h> 14 15 namespace torch { 16 namespace distributed { 17 namespace rpc { 18 19 using torch::distributed::autograd::DistAutogradContainer; 20 using torch::distributed::autograd::DistAutogradContext; 21 22 DistAutogradContainer* getDistAutogradContainer(); 23 24 class TestE2EBase : public ::testing::Test { 25 protected: SetUp()26 void SetUp() override { 27 // Setup distributed autograd. 28 autogradContainer = getDistAutogradContainer(); 29 30 // Setup server store. 31 c10d::TCPStoreOptions opts{ 32 /* port */ 0, 33 /* isServer */ true, 34 numWorkers, 35 /* waitWorkers */ true, 36 /* timeout */ std::chrono::seconds(10)}; 37 38 store = c10::make_intrusive<c10d::TCPStore>(serverAddress, opts); 39 40 buildRpcAgent(); 41 42 rpcAgentPostProcessing(); 43 } 44 rpcAgentPostProcessing()45 void rpcAgentPostProcessing() { 46 RpcAgent::setCurrentRpcAgent(rpcAgent); 47 std::shared_ptr<TypeResolver> typeResolver = 48 std::make_shared<TypeResolver>([&](const c10::QualifiedName& qn) { 49 // For Dict that is used for device map. 50 auto pos = qn.name().find("Dict"); 51 if (pos != std::string::npos) { 52 return c10::StrongTypePtr( 53 nullptr, 54 c10::DictType::create( 55 c10::StringType::get(), c10::StringType::get())); 56 } 57 return c10::StrongTypePtr( 58 nullptr, c10::TensorType::create(at::Tensor())); 59 }); 60 rpcAgent->setTypeResolver(typeResolver); 61 rpcAgent->start(); 62 } 63 TearDown()64 void TearDown() override { 65 rpcAgent->join(); 66 rpcAgent->shutdown(); 67 RpcAgent::setCurrentRpcAgent(nullptr); 68 } 69 createRemoteRRef(at::Tensor t1,at::Tensor t2,std::shared_ptr<torch::jit::Operator> op)70 c10::intrusive_ptr<OwnerRRef> createRemoteRRef( 71 at::Tensor t1, 72 at::Tensor t2, 73 std::shared_ptr<torch::jit::Operator> op) { 74 auto& ctx = RRefContext::getInstance(); 75 auto ownerRRef = ctx.createOwnerRRef(c10::TensorType::create(t1)); 76 // prevent this owner RRef being deleted due to other forks 77 ctx.addSelfAsFork(ownerRRef); 78 79 ScriptRemoteCall scriptRemoteCall( 80 op, {t1, t2, 1}, ownerRRef->rrefId(), ownerRRef->rrefId()); 81 auto jitFuture = autograd::sendMessageWithAutograd( 82 *rpcAgent, 83 rpcAgent->getWorkerInfo("worker"), 84 std::move(scriptRemoteCall).toMessage(), 85 false); 86 87 ownerRRef->registerOwnerCreationFuture(jitFuture); 88 89 // Builtin operators does not return py::object, and hence does not require 90 // GIL for destructing the potentially deleted OwerRRef. 91 jitFuture->addCallback( 92 [ownerRRefId = ownerRRef->rrefId()](JitFuture& jitFuture) { 93 callback::finishCreatingOwnerRRef(jitFuture, ownerRRefId); 94 }); 95 return ownerRRef; 96 } 97 remoteAdd(at::Tensor t1,at::Tensor t2,std::shared_ptr<torch::jit::Operator> op)98 at::Tensor remoteAdd( 99 at::Tensor t1, 100 at::Tensor t2, 101 std::shared_ptr<torch::jit::Operator> op) { 102 ScriptCall scriptCall(op, {t1, t2, /* alpha */ 1}); 103 104 // Send the RPC and return result. 105 auto response = autograd::sendMessageWithAutograd( 106 *rpcAgent, 107 rpcAgent->getWorkerInfo("worker"), 108 std::move(scriptCall).toMessage()); 109 response->waitAndThrow(); 110 111 MessageType messageType = MessageType::FORWARD_AUTOGRAD_RESP; 112 auto wrappedResponse = deserializeResponse( 113 std::move(*response->value().toCustomClass<Message>()), messageType); 114 return static_cast<ScriptResp&>(*wrappedResponse).value().toTensor(); 115 } 116 117 virtual void buildRpcAgent() = 0; 118 119 class AutogradContextGuard { 120 public: AutogradContextGuard()121 explicit AutogradContextGuard() 122 : context(DistAutogradContainer::getInstance().newContext()) {} 123 ~AutogradContextGuard()124 ~AutogradContextGuard() { 125 DistAutogradContainer::getInstance().releaseContext(context->contextId()); 126 } 127 128 private: 129 std::shared_ptr<DistAutogradContext> context; 130 }; 131 runTrainingLoop()132 void runTrainingLoop() { 133 auto options = at::TensorOptions().requires_grad(true); 134 auto t1 = torch::ones({3, 3}, options); 135 auto t2 = torch::ones({3, 3}, options); 136 137 c10::OperatorName full_name("aten::add", "Tensor"); 138 auto matchedOp = torch::jit::findOperatorFor(full_name); 139 ASSERT_TRUE(matchedOp); 140 141 for (size_t i = 0; i < numIters; i++) { 142 // Create the autograd context guard. 143 AutogradContextGuard guard; 144 145 // Multiple RPCs within one autograd context for the forward pass. 146 auto result = remoteAdd(t1, t2, matchedOp); 147 for (size_t j = 0; j < 5; j++) { 148 result = remoteAdd(t1, result, matchedOp); 149 } 150 151 auto rref = createRemoteRRef(t1, result, matchedOp); 152 result = rref->getValue().toTensor(); 153 154 // Run backward pass now. 155 autograd::DistEngine::getInstance().execute( 156 DistAutogradContainer::currentContextId(), 157 {torch::sum(result)}, 158 /* retainGraph */ false); 159 } 160 } 161 162 DistAutogradContainer* autogradContainer; 163 std::shared_ptr<RpcAgent> rpcAgent; 164 static const size_t numIters; 165 static const size_t numWorkers; 166 c10::intrusive_ptr<c10d::Store> store; 167 static const char* serverAddress; 168 }; 169 170 } // namespace rpc 171 } // namespace distributed 172 } // namespace torch 173