// torch/csrc/distributed/autograd/context/container.cpp
#include <torch/csrc/distributed/autograd/context/container.h>

#include <c10/util/Exception.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>

namespace torch {
namespace distributed {
namespace autograd {

constexpr int kAutoIncrementBits = 48;
constexpr int64_t kAutoIncrementMask = (1LL << kAutoIncrementBits) - 1;
constexpr int kMaxWorkerId = 65535;
constexpr int kNumCleanupContextRetries = 20;

constexpr int64_t kInvalidContextId = -1;
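// [Editorial note] Context ids and autograd message ids share a 64-bit layout,
// as set up in init() below: the worker id occupies the upper 16 bits and a
// per-worker auto-increment counter occupies the lower kAutoIncrementBits (48)
// bits. For example, a worker with id 3 hands out ids in the range
// [0x0003000000000000, 0x0003FFFFFFFFFFFF], so ids are globally unique across
// workers without any coordination.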

// Each thread has a single autograd_context_id valid at any point in time.
static thread_local int64_t current_context_id_ = kInvalidContextId;

// Lock to ensure DistAutogradContainer is initialized only once.
static std::mutex dist_container_init_lock_;

DistAutogradContainer::DistAutogradContainer(uint32_t num_shards)
    : next_context_id_(0),
      worker_id_(0),
      initialized_(false),
      autograd_contexts_(num_shards),
      num_shards_(num_shards),
      next_autograd_message_id_(0),
      max_id_(0) {
  // num_shards has to be a power of 2 for the modulo trick in 'getShard'
  // to work.
  TORCH_INTERNAL_ASSERT((num_shards & (num_shards - 1)) == 0);
}

DistAutogradContainer& DistAutogradContainer::init(int64_t worker_id) {
  std::lock_guard<std::mutex> guard(dist_container_init_lock_);

  TORCH_CHECK(
      worker_id >= 0 && worker_id <= kMaxWorkerId,
      "worker_id needs to be in the range [0, 65535]");

  auto& container = getInstanceInternal();
  TORCH_CHECK(
      !container.initialized_ || (worker_id == container.worker_id_),
      "Container is already initialized with worker_id: ",
      container.worker_id_,
      ", cannot initialize with different worker_id: ",
      worker_id);

  if (container.initialized_) {
    LOG(INFO) << "DistAutogradContainer is already initialized";
    return container;
  }

  container.worker_id_ = worker_id;
  container.next_context_id_ = static_cast<int64_t>(worker_id)
      << kAutoIncrementBits;
  container.next_autograd_message_id_ = static_cast<int64_t>(worker_id)
      << kAutoIncrementBits;
  container.max_id_ =
      (kAutoIncrementMask |
       (static_cast<int64_t>(worker_id) << kAutoIncrementBits));
  container.initialized_ = true;
  return container;
}

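// [Editorial note] A worked example of the heuristic below: with
// hardware_concurrency() == 12 the loop doubles num_shards until it is at
// least 24, yielding 32 (1 -> 2 -> 4 -> 8 -> 16 -> 32). If the hardware
// concurrency cannot be determined (0), kNumDefaultShards is used instead.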
uint32_t DistAutogradContainer::computeNumShards() {
  uint32_t num_shards = 1;
  auto num_hw_threads = std::thread::hardware_concurrency();
  if (num_hw_threads == 0) {
    num_shards = kNumDefaultShards;
  } else {
    // Compute the smallest power of 2 which is at least twice the hardware
    // concurrency.
    while (num_shards < num_hw_threads * 2) {
      num_shards <<= 1;
    }
  }
  VLOG(1) << "Number of shards for DistAutogradContainer: " << num_shards;
  return num_shards;
}

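// [Editorial note] Because num_shards_ is a power of 2 (asserted in the
// constructor), masking with (num_shards_ - 1) is equivalent to
// context_id % num_shards_ for non-negative ids. E.g. with 8 shards the mask
// is 0b111, so an id of 13 maps to shard 5.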
inline DistAutogradContainer::ContextsShard& DistAutogradContainer::getShard(
    int64_t context_id) {
  // num_shards_ has to be a power of 2 for this modulo trick to work
  // (validated in the constructor).
  return autograd_contexts_[context_id & (num_shards_ - 1)];
}

DistAutogradContainer& DistAutogradContainer::getInstance() {
  auto& instance = getInstanceInternal();
  TORCH_CHECK(
      instance.initialized_,
      "Need to initialize distributed autograd using "
      "torch.distributed.autograd.init()");
  return instance;
}

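// [Editorial note] The container is created on first use and intentionally
// never destroyed. Leaking the singleton avoids destruction-order problems at
// process exit, e.g. RPC threads touching the container after static
// destructors have started running.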
DistAutogradContainer& DistAutogradContainer::getInstanceInternal() {
  // Leaky singleton to avoid module destructor race.
  static DistAutogradContainer* container =
      new DistAutogradContainer(computeNumShards());
  return *container;
}

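// [Editorial note] Autograd message ids use the same [worker_id | counter]
// layout as context ids (see init()), so they are likewise unique across
// workers.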
int64_t DistAutogradContainer::newAutogradMessageId() {
  // Check for overflow into the worker_id_ section.
  TORCH_INTERNAL_ASSERT(next_autograd_message_id_ < max_id_);
  return next_autograd_message_id_++;
}

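// [Editorial note] Unlike newContext(), this accepts a context id that was
// created elsewhere; it is typically used on the receiving side of an RPC
// that carries autograd metadata from another worker.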
ContextPtr DistAutogradContainer::getOrCreateContext(int64_t context_id) {
  auto& shard = getShard(context_id);
  std::lock_guard<std::mutex> guard(shard.lock);
  auto it = shard.contexts.find(context_id);
  if (it != shard.contexts.end()) {
    return it->second;
  }

  auto& context =
      shard.contexts
          .emplace(
              std::piecewise_construct,
              std::forward_as_tuple(context_id),
              std::forward_as_tuple(
                  std::make_shared<DistAutogradContext>(context_id)))
          .first->second;
  return context;
}

rpc::worker_id_t DistAutogradContainer::getWorkerId() const {
  return worker_id_;
}

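// [Editorial note] Creates a fresh context for the calling thread; this is
// the path behind the Python context manager
// `with torch.distributed.autograd.context() as context_id:` referenced in
// currentContext() below. A minimal C++ usage sketch (assuming
// DistAutogradContext exposes contextId(), as declared in context.h):
//
//   auto& container = DistAutogradContainer::getInstance();
//   auto ctx = container.newContext();      // thread-local context id is set
//   int64_t id = ctx->contextId();
//   // ... run forward/backward work under this context ...
//   container.releaseContext(id);           // clears it and notifies peers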
const ContextPtr DistAutogradContainer::newContext() {
  TORCH_CHECK(
      current_context_id_ == kInvalidContextId,
      "Already have an autograd context id for this thread.");

  auto context_id = next_context_id_++;
  current_context_id_ = context_id;

  // Check for overflow into the worker_id_ section.
  TORCH_INTERNAL_ASSERT(context_id < max_id_);

  auto& shard = getShard(context_id);
  std::lock_guard<std::mutex> guard(shard.lock);
  auto& context =
      shard.contexts
          .emplace(
              std::piecewise_construct,
              std::forward_as_tuple(context_id),
              std::forward_as_tuple(
                  std::make_shared<DistAutogradContext>(context_id)))
          .first->second;

  return context;
}

bool DistAutogradContainer::hasValidContext() const {
  return current_context_id_ != kInvalidContextId;
}

ContextPtr DistAutogradContainer::currentContext() {
  TORCH_CHECK(
      hasValidContext(),
      "Current thread doesn't have a valid autograd context. Please wrap your "
      "code using: `with torch.distributed.autograd.context() as context_id` "
      "to generate a valid context");

  auto& shard = getShard(current_context_id_);
  std::lock_guard<std::mutex> guard(shard.lock);
  auto it = shard.contexts.find(current_context_id_);
  TORCH_CHECK(
      it != shard.contexts.end(),
      "Couldn't find autograd context "
      "data for current autograd context id");
  return it->second;
}

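// [Editorial note] releaseContextIfPresent() and releaseContext() differ only
// in how they treat a missing context: the former silently returns (e.g. for
// cleanup requests that may arrive after the context is already gone), while
// the latter treats a missing context as an error.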
void DistAutogradContainer::releaseContextIfPresent(int64_t context_id) {
  auto& shard = getShard(context_id);
  std::unique_lock<std::mutex> lock(shard.lock);
  auto it = shard.contexts.find(context_id);

  // No-op if the context does not exist on this node. This could happen if an
  // in-flight RPC has already released the context here.
  if (it == shard.contexts.end()) {
    return;
  }

  auto knownWorkerIds = it->second->getKnownWorkerIds();
  eraseContextIdAndReset(shard, context_id);

  // Unlock since we no longer need the lock.
  lock.unlock();
  sendReleaseContextRpc(knownWorkerIds, context_id);
}

void DistAutogradContainer::releaseContext(int64_t context_id) {
  auto& shard = getShard(context_id);
  std::unique_lock<std::mutex> lock(shard.lock);
  auto it = shard.contexts.find(context_id);

  TORCH_CHECK(
      it != shard.contexts.end(),
      "Could not find autograd context with id: ",
      context_id);

  auto knownWorkerIds = it->second->getKnownWorkerIds();
  eraseContextIdAndReset(shard, context_id);

  // Unlock since we no longer need the lock.
  lock.unlock();
  sendReleaseContextRpc(knownWorkerIds, context_id);
}

void DistAutogradContainer::sendReleaseContextRpc(
    const std::unordered_set<rpc::worker_id_t>& workerIds,
    int64_t context_id) {
  // Best-effort notification to other workers to clean up their dist autograd
  // context, in order to reduce memory usage.
  // agent.send() or getCurrentRpcAgent() may throw an error in the case of an
  // ungraceful shutdown, where we are shutting down RPC and also processing
  // this message in a separate thread concurrently. In this case, don't throw
  // here.
  std::shared_ptr<rpc::RpcAgent> agent;
  try {
    agent = rpc::RpcAgent::getCurrentRpcAgent();
  } catch (const std::exception& e) {
    LOG(INFO)
        << "Failed to send RPC to clear Dist Autograd context to all workers: "
        << e.what();
    return;
  }

  TORCH_INTERNAL_ASSERT(agent, "RPC Agent should be set.");

  rpc::RpcRetryOptions options;
  options.maxRetries = kNumCleanupContextRetries;
  for (const auto& worker_id : workerIds) {
    try {
      auto cleanupFuture = agent->sendWithRetries(
          agent->getWorkerInfo(worker_id),
          CleanupAutogradContextReq(context_id).toMessage(),
          options);

      cleanupFuture->addCallback([worker_id](rpc::JitFuture& future) {
        if (future.hasError()) {
          std::string errorMsg = c10::str(
              "Could not release Dist Autograd Context on node ",
              worker_id,
              ": ",
              future.tryRetrieveErrorMessage());
          LOG(ERROR) << errorMsg;
          return;
        }
      });
    } catch (const std::exception& e) {
      LOG(INFO)
          << "Failed to send RPC to clear Dist Autograd context to worker id: "
          << worker_id << " : " << e.what();
    }
  }
}

void DistAutogradContainer::eraseContextIdAndReset(
    DistAutogradContainer::ContextsShard& shard,
    int64_t context_id) {
  // We already have the shard lock here.
  shard.contexts.erase(context_id);

  if (current_context_id_ == context_id) {
    // Reset the thread_local current context id, since it is no longer valid.
    current_context_id_ = kInvalidContextId;
  }
}

void DistAutogradContainer::isValidContext(int64_t context_id) {
  auto& shard = getShard(context_id);
  std::lock_guard<std::mutex> guard(shard.lock);
  TORCH_CHECK(
      shard.contexts.find(context_id) != shard.contexts.end(),
      "Could not find autograd context with id: ",
      context_id);
}

ContextPtr DistAutogradContainer::retrieveContext(int64_t context_id) {
  auto& shard = getShard(context_id);
  std::lock_guard<std::mutex> guard(shard.lock);
  auto it = shard.contexts.find(context_id);
  TORCH_CHECK(
      it != shard.contexts.end(),
      "Could not find autograd context with id: ",
      context_id);
  return it->second;
}

int64_t DistAutogradContainer::getMaxId() {
  return max_id_;
}

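// [Editorial note] forceCurrentContextId() overwrites the thread-local id
// unconditionally, whereas setCurrentContextId() below asserts that no
// context id is set for the thread yet.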
void DistAutogradContainer::forceCurrentContextId(int64_t contextId) {
  current_context_id_ = contextId;
}

void DistAutogradContainer::setCurrentContextId(int64_t contextId) {
  TORCH_INTERNAL_ASSERT(
      current_context_id_ == kInvalidContextId,
      "Already have an autograd context id for this thread.");
  current_context_id_ = contextId;
}

void DistAutogradContainer::clearCurrentContext() {
  current_context_id_ = kInvalidContextId;
}

size_t DistAutogradContainer::numAutogradContexts() const {
  size_t ret = 0;
  for (const auto& shard : autograd_contexts_) {
    std::lock_guard<std::mutex> guard(shard.lock);
    ret += shard.contexts.size();
  }
  return ret;
}

int64_t DistAutogradContainer::currentContextId() {
  return current_context_id_;
}

} // namespace autograd
} // namespace distributed
} // namespace torch