#include <torch/csrc/distributed/autograd/context/container.h>

#include <c10/util/Exception.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>

// Used directly below: std::mutex / std::lock_guard and
// std::thread::hardware_concurrency.
#include <mutex>
#include <thread>

namespace torch {
namespace distributed {
namespace autograd {

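// Layout of every 64-bit id handed out by this container: the worker id
// occupies the high 16 bits and an auto-incrementing counter the low 48 bits,
//   [ 16-bit worker_id | 48-bit counter ]
// so each worker draws ids from a disjoint range. kMaxWorkerId bounds the
// high bits and kAutoIncrementMask selects the low bits.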
constexpr int kAutoIncrementBits = 48;
constexpr int64_t kAutoIncrementMask = (1LL << kAutoIncrementBits) - 1;
constexpr int kMaxWorkerId = 65535;
constexpr int kNumCleanupContextRetries = 20;

constexpr int64_t kInvalidContextId = -1;

// Each thread has a single autograd_context_id valid at any point in time.
static thread_local int64_t current_context_id_ = kInvalidContextId;

// Lock to ensure DistAutogradContainer is initialized only once.
static std::mutex dist_container_init_lock_;

DistAutogradContainer::DistAutogradContainer(uint32_t num_shards)
    : next_context_id_(0),
      worker_id_(0),
      initialized_(false),
      autograd_contexts_(num_shards),
      num_shards_(num_shards),
      next_autograd_message_id_(0),
      max_id_(0) {
  // num_shards has to be a power of 2 for the modulo trick in 'getShard'
  // to work.
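  // For a power of two, n & (n - 1) clears the only set bit and yields 0
  // (e.g. 8 & 7 == 0), whereas a non-power such as 6 gives 6 & 5 == 4.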
  TORCH_INTERNAL_ASSERT((num_shards & (num_shards - 1)) == 0);
}

DistAutogradContainer& DistAutogradContainer::init(int64_t worker_id) {
  std::lock_guard<std::mutex> guard(dist_container_init_lock_);

  TORCH_CHECK(
      worker_id >= 0 && worker_id <= kMaxWorkerId,
      "worker_id needs to be in the range [0, 65535]");

  auto& container = getInstanceInternal();
  TORCH_CHECK(
      !container.initialized_ || (worker_id == container.worker_id_),
      "Container is already initialized with worker_id: ",
      container.worker_id_,
      ", cannot initialize with different worker_id: ",
      worker_id);

  if (container.initialized_) {
    LOG(INFO) << "DistAutogradContainer is already initialized";
    return container;
  }

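  // Seed both id counters with the worker id in the high 16 bits; the low 48
  // bits auto-increment from zero. max_id_ is the largest id this worker may
  // hand out before the counter would spill into the worker id bits.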
  container.worker_id_ = worker_id;
  container.next_context_id_ = static_cast<int64_t>(worker_id)
      << kAutoIncrementBits;
  container.next_autograd_message_id_ = static_cast<int64_t>(worker_id)
      << kAutoIncrementBits;
  container.max_id_ =
      (kAutoIncrementMask |
       (static_cast<int64_t>(worker_id) << kAutoIncrementBits));
  container.initialized_ = true;
  return container;
}

uint32_t DistAutogradContainer::computeNumShards() {
  uint32_t num_shards = 1;
  auto num_hw_threads = std::thread::hardware_concurrency();
  if (num_hw_threads == 0) {
    num_shards = kNumDefaultShards;
  } else {
    // Compute the smallest power of 2 that is at least twice the hardware
    // concurrency.
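    // (More shards than hardware threads presumably keeps per-shard lock
    // contention low when many threads touch contexts concurrently.)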
    while (num_shards < num_hw_threads * 2) {
      num_shards <<= 1;
    }
  }
  VLOG(1) << "Number of shards for DistAutogradContainer: " << num_shards;
  return num_shards;
}

inline DistAutogradContainer::ContextsShard& DistAutogradContainer::getShard(
    int64_t context_id) {
  // num_shards_ has to be a power of 2 for this modulo trick to work
  // (validated during init).
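  // e.g. with num_shards_ == 8, (context_id & 7) == (context_id % 8).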
  return autograd_contexts_[context_id & (num_shards_ - 1)];
}

DistAutogradContainer& DistAutogradContainer::getInstance() {
  auto& instance = getInstanceInternal();
  TORCH_CHECK(
      instance.initialized_,
      "Need to initialize distributed autograd using "
      "torch.distributed.autograd.init()");
  return instance;
}

DistAutogradContainer& DistAutogradContainer::getInstanceInternal() {
  // Leaky singleton to avoid module destructor race.
  static DistAutogradContainer* container =
      new DistAutogradContainer(computeNumShards());
  return *container;
}

int64_t DistAutogradContainer::newAutogradMessageId() {
  // Check for overflow into workerId_ section.
  TORCH_INTERNAL_ASSERT(next_autograd_message_id_ < max_id_);
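  // Post-increment of the shared counter; assuming next_autograd_message_id_
  // is declared std::atomic in container.h, no extra locking is needed here.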
  return next_autograd_message_id_++;
}

ContextPtr DistAutogradContainer::getOrCreateContext(int64_t context_id) {
  auto& shard = getShard(context_id);
  std::lock_guard<std::mutex> guard(shard.lock);
  auto it = shard.contexts.find(context_id);
  if (it != shard.contexts.end()) {
    return it->second;
  }

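  // No context exists for this id yet; construct one in place while still
  // holding the shard lock so concurrent callers observe a single instance.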
  auto& context =
      shard.contexts
          .emplace(
              std::piecewise_construct,
              std::forward_as_tuple(context_id),
              std::forward_as_tuple(
                  std::make_shared<DistAutogradContext>(context_id)))
          .first->second;
  return context;
}

rpc::worker_id_t DistAutogradContainer::getWorkerId() const {
  return worker_id_;
}

const ContextPtr DistAutogradContainer::newContext() {
  TORCH_CHECK(
      current_context_id_ == kInvalidContextId,
      "Already have an autograd context id for this thread.");

  auto context_id = next_context_id_++;
  current_context_id_ = context_id;

  // Check for overflow into workerId_ section.
  TORCH_INTERNAL_ASSERT(context_id < max_id_);

  auto& shard = getShard(context_id);
  std::lock_guard<std::mutex> guard(shard.lock);
  auto& context =
      shard.contexts
          .emplace(
              std::piecewise_construct,
              std::forward_as_tuple(context_id),
              std::forward_as_tuple(
                  std::make_shared<DistAutogradContext>(context_id)))
          .first->second;

  return context;
}

bool DistAutogradContainer::hasValidContext() const {
  return current_context_id_ != kInvalidContextId;
}

ContextPtr DistAutogradContainer::currentContext() {
  TORCH_CHECK(
      hasValidContext(),
      "Current thread doesn't have a valid autograd context. Please wrap your "
      "code using: `with torch.distributed.autograd.context() as context_id` "
      "to generate a valid context");

  auto& shard = getShard(current_context_id_);
  std::lock_guard<std::mutex> guard(shard.lock);
  auto it = shard.contexts.find(current_context_id_);
  TORCH_CHECK(
      it != shard.contexts.end(),
      "Couldn't find autograd context "
      "data for current autograd context id");
  return it->second;
}

void DistAutogradContainer::releaseContextIfPresent(int64_t context_id) {
  auto& shard = getShard(context_id);
  std::unique_lock<std::mutex> lock(shard.lock);
  auto it = shard.contexts.find(context_id);

  // No-op if the context does not exist on this worker. This could happen if
  // an in-flight RPC has already released the context here.
  if (it == shard.contexts.end()) {
    return;
  }

  auto knownWorkerIds = it->second->getKnownWorkerIds();
  eraseContextIdAndReset(shard, context_id);

  // Unlock since we no longer need the lock.
  lock.unlock();
  sendReleaseContextRpc(knownWorkerIds, context_id);
}

void DistAutogradContainer::releaseContext(int64_t context_id) {
  auto& shard = getShard(context_id);
  std::unique_lock<std::mutex> lock(shard.lock);
  auto it = shard.contexts.find(context_id);

  TORCH_CHECK(
      it != shard.contexts.end(),
      "Could not find autograd context with id: ",
      context_id);

  auto knownWorkerIds = it->second->getKnownWorkerIds();
  eraseContextIdAndReset(shard, context_id);

  // Unlock since we no longer need the lock.
  lock.unlock();
  sendReleaseContextRpc(knownWorkerIds, context_id);
}

void DistAutogradContainer::sendReleaseContextRpc(
    const std::unordered_set<rpc::worker_id_t>& workerIds,
    int64_t context_id) {
  // Best-effort notification to other workers to clean up their Dist autograd
  // context, in order to reduce memory usage.
  // agent.send() or getCurrentRpcAgent may throw an error in the case of an
  // ungraceful shutdown, where we are shutting down RPC and also processing
  // this message in a separate thread concurrently. In this case, don't throw
  // here.
  std::shared_ptr<rpc::RpcAgent> agent;
  try {
    agent = rpc::RpcAgent::getCurrentRpcAgent();
  } catch (const std::exception& e) {
    LOG(INFO)
        << "Failed to send RPC to clear Dist Autograd context to all workers: "
        << e.what();
    return;
  }

  TORCH_INTERNAL_ASSERT(agent, "RPC Agent should be set.");

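  // Retry the cleanup request up to kNumCleanupContextRetries times; failures
  // are only logged below since the release is best effort.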
  rpc::RpcRetryOptions options;
  options.maxRetries = kNumCleanupContextRetries;
  for (const auto& worker_id : workerIds) {
    try {
      auto cleanupFuture = agent->sendWithRetries(
          agent->getWorkerInfo(worker_id),
          CleanupAutogradContextReq(context_id).toMessage(),
          options);

      cleanupFuture->addCallback([worker_id](rpc::JitFuture& future) {
        if (future.hasError()) {
          std::string errorMsg = c10::str(
              "Could not release Dist Autograd Context on node ",
              worker_id,
              ": ",
              future.tryRetrieveErrorMessage());
          LOG(ERROR) << errorMsg;
          return;
        }
      });
    } catch (const std::exception& e) {
      LOG(INFO)
          << "Failed to send RPC to clear Dist Autograd context to worker id: "
          << worker_id << " : " << e.what();
    }
  }
}

void DistAutogradContainer::eraseContextIdAndReset(
    DistAutogradContainer::ContextsShard& shard,
    int64_t context_id) {
  // We already have the shard lock here.
  shard.contexts.erase(context_id);

  if (current_context_id_ == context_id) {
    // Reset the thread_local current context id, since it is no longer valid.
    current_context_id_ = kInvalidContextId;
  }
}

void DistAutogradContainer::isValidContext(int64_t context_id) {
  auto& shard = getShard(context_id);
  std::lock_guard<std::mutex> guard(shard.lock);
  TORCH_CHECK(
      shard.contexts.find(context_id) != shard.contexts.end(),
      "Could not find autograd context with id: ",
      context_id);
}

ContextPtr DistAutogradContainer::retrieveContext(int64_t context_id) {
  auto& shard = getShard(context_id);
  std::lock_guard<std::mutex> guard(shard.lock);
  auto it = shard.contexts.find(context_id);
  TORCH_CHECK(
      it != shard.contexts.end(),
      "Could not find autograd context with id: ",
      context_id);
  return it->second;
}

int64_t DistAutogradContainer::getMaxId() {
  return max_id_;
}

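// Unlike setCurrentContextId, this overwrites the thread-local context id
// unconditionally, even if the thread already has a valid context.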
void DistAutogradContainer::forceCurrentContextId(int64_t contextId) {
  current_context_id_ = contextId;
}

void DistAutogradContainer::setCurrentContextId(int64_t contextId) {
  TORCH_INTERNAL_ASSERT(
      current_context_id_ == kInvalidContextId,
      "Already have an autograd context id for this thread.");
  current_context_id_ = contextId;
}

void DistAutogradContainer::clearCurrentContext() {
  current_context_id_ = kInvalidContextId;
}

size_t DistAutogradContainer::numAutogradContexts() const {
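  // Locks each shard in turn, so the total is not an atomic snapshot if
  // contexts are created or released concurrently.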
  size_t ret = 0;
  for (const auto& shard : autograd_contexts_) {
    std::lock_guard<std::mutex> guard(shard.lock);
    ret += shard.contexts.size();
  }
  return ret;
}

int64_t DistAutogradContainer::currentContextId() {
  return current_context_id_;
}

} // namespace autograd
} // namespace distributed
} // namespace torch