1 #include <torch/csrc/autograd/engine.h>
2
3 #include <torch/csrc/autograd/anomaly_mode.h>
4 #include <torch/csrc/autograd/autograd.h>
5 #include <torch/csrc/autograd/function.h>
6 #include <torch/csrc/autograd/functions/basic_ops.h>
7 #include <torch/csrc/autograd/grad_mode.h>
8 #include <torch/csrc/autograd/variable.h>
9 #include <torch/csrc/dynamo/compiled_autograd.h>
10
11 #include <ATen/DeviceAccelerator.h>
12 #include <ATen/DeviceGuard.h>
13 #include <ATen/ExpandUtils.h>
14 #include <ATen/Parallel.h>
15 #include <ATen/SparseCsrTensorUtils.h>
16 #include <ATen/detail/CUDAHooksInterface.h>
17 #include <ATen/detail/PrivateUse1HooksInterface.h>
18
19 #ifndef AT_PER_OPERATOR_HEADERS
20 #include <ATen/Functions.h>
21 #else
22 #include <ATen/ops/isnan.h>
23 #endif
24
25 #include <c10/core/DeviceGuard.h>
26 #include <c10/core/Event.h>
27 #include <c10/core/Stream.h>
28 #include <c10/core/StreamGuard.h>
29 #include <c10/util/AbortHandler.h>
30 #include <c10/util/Exception.h>
31 #include <c10/util/ThreadLocal.h>
32 #include <c10/util/irange.h>
33 #include <c10/util/thread_name.h>
34
35 #include <atomic>
36 #include <chrono>
37 #include <cstdint>
38 #include <functional>
39 #include <memory>
40 #include <mutex>
41 #include <optional>
42 #include <string>
43 #include <thread>
44 #include <unordered_set>
45 #include <utility>
46
47 namespace torch::autograd {
48
49 namespace {
50 static bool in_bad_autograd_fork =
51 false; // True for children forked after engine's thread pool init
52
53 // Called in the forked child if engine's thread pool has already been
54 // initialized
static void forked_autograd_child() {
56 in_bad_autograd_fork = true;
57 }
58
// Should be called before any call that is unsafe across forks
// (i.e. thread pool creation)
static void track_bad_autograd_forks() {
61 #if !defined(WIN32)
62 static c10::once_flag flag;
63 c10::call_once(
64 flag, [&] { pthread_atfork(nullptr, nullptr, forked_autograd_child); });
65 #endif
66 }
67
inline bool should_run_in_cpu_ready_queue(c10::DeviceType device) {
69 if (device == c10::kCPU || device == c10::kMeta || device == c10::kLazy) {
70 return true;
71 } else {
72 return false;
73 }
74 }
75
76 std::atomic<Engine::compiled_autograd_fn> the_compiled_autograd = nullptr;
77 #define COMPILED_AUTOGRAD_POISON \
78 reinterpret_cast<Engine::compiled_autograd_fn>(1)
79 std::atomic<int32_t> num_threads_in_backwards;
80 struct CompiledAutogradThreadingDebugCheck {
  CompiledAutogradThreadingDebugCheck() {
82 num_threads_in_backwards++;
83 }
  ~CompiledAutogradThreadingDebugCheck() {
85 release();
86 }
  void release() {
88 if (std::exchange(incremented, false)) {
89 num_threads_in_backwards--;
90 }
91 }
92
93 private:
94 bool incremented{true};
95 };
96
97 } // namespace
98
99 // Threads spawned by the engine are assigned a 'worker_device' specifying
100 // what device they process work for. This variable is initialized at:
101 // 1. thread creation time for CUDA, XLA device threads, as they are
//    spinning threads waiting for work on their devices.
103 // 2. before the graph task execution for CPU threads, as for each
104 // backward call we use the caller thread to drive engine execution.
105 // This is used when handling reentrant backwards calls;
106 // See Note [Reentrant backwards]
107 static thread_local int worker_device = NO_DEVICE;
108
109 // This variable is true if ALL invocations in the stack of re-entrant engine
110 // invocations are imperative backwards. This special variable is needed for the
111 // gradient checkpointing feature only.
112 static thread_local bool checkpoint_valid = true;
113
114 // Number of nested reentrant backwards calls currently on this thread
115 static thread_local int current_depth = 0;
116
// For all device threads (i.e. CUDA, XLA), total_depth represents the total
// nested reentrant backwards depth over all device threads.
// For CPU devices, it is the total depth associated with the original backward
// call.
122 static thread_local int total_depth = 0;
123
124 // The current GraphTask being executed by this thread. This helps
125 // queue_callback() to find the target GraphTask to append final callbacks.
126 C10_DEFINE_TLS_static(std::shared_ptr<GraphTask>, tls_current_graph_task);
127 #define current_graph_task (tls_current_graph_task.get())
128
129 // Every autograd worker thread is associated with a ready queue, which
130 // specifies the stream of work of this thread to do. This shared_ptr is a
131 // thread_local pointer to each thread's ready_queue, and it should be
132 // initialized via the Engine::init_local_ready_queue() call in each
133 // corresponding thread before execution.
134 //
135 // The CUDA, XLA threads are shared among all invocations of backwards via
136 // device_ready_queues_, while the caller thread is dedicated to processing work
137 // for devices returning true in should_run_in_cpu_ready_queue (most notably the
138 // CPU device). So any given graph task maintains its own cpu_ready_queue_ where
139 // you should send work for it to be done.
140 //
// For reentrant backward calls, if we spawn a new thread from the current
// thread because we reached the maximum depth, the new thread will just reuse
// the same ReadyQueue as the parent thread for better performance.
// See Note [Reentrant backwards] for more details.
145 C10_DEFINE_TLS_static(std::shared_ptr<ReadyQueue>, tls_local_ready_queue);
146 #define local_ready_queue (tls_local_ready_queue.get())
147
148 // Note [Reentrant backwards]
149 // ~~~~~~~~~~~~~~~~~~~~~~~~~~
150 // To understand the reentrant backwards problem, we have to notice two
151 // aspects of how the autograd engine is implemented today:
152 //
153 // 1. When you call Engine::execute(), you want to block until
154 // differentiation finishes so that you can get the final result variables
155 // of the backwards pass.
156 //
157 // 2. The engine operates by having a single worker thread per work queue,
158 // and every work queue is pinned to a specific device where the
159 // operation is executed.
160 //
161 // The problem is, suppose that you call backward() inside of a worker
162 // thread. By property (1), we're supposed to block until the nested task
163 // finishes. However, by property (2), this worker thread is on the
164 // hook for processing the tasks assigned to it; we better not block,
165 // because then all of our backward executions (including the one we
166 // just started) will deadlock!
167 //
// We maintain a pool of threads waiting for work to do. When a reentrant
// backwards call occurs, the current thread blocks and a thread from the pool
// is woken up to complete the blocking tasks and any other tasks that would
// have been assigned to that worker. If there are no threads available, a new
// thread is spawned. The new thread will continue processing tasks from the
// same ReadyQueue as the parent worker.
174 //
175 // When the GraphTask is finished, the parent worker thread that is waiting on
176 // the task is notified and the current thread returns to the pool.
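//
// As a rough illustration (not engine code; assumes the C++ autograd
// frontend), a reentrant call arises when a Node's backward itself runs
// autograd, e.g.:
//
//   // running inside Engine::thread_main on a worker thread:
//   torch::Tensor y = some_differentiable_fn(x);
//   auto grads = torch::autograd::grad({y}, {x});  // nested Engine::execute
//
// The worker blocks on the nested call; a pooled (or newly spawned) thread
// keeps draining the shared ReadyQueue so both GraphTasks can finish.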
177
178 // Note [Streaming backwards]
179 // ~~~~~~~~~~~~~~~~~~~~~~~~~~
// On CUDA/privateuse1 devices the autograd engine's device operations are run
// on the same streams that ran the corresponding operations in forward. This
// requires automatically syncing the streams so that function A finishes
// producing its output before function B consumes it.
184 //
185 // This synchronization occurs when outputs are placed into input buffers.
186 // The functions corresponding to input buffer positions have metadata
187 // recording their streams from forward, and during backward this
188 // data is used to sync the producer's stream with the consumer's.
189 //
190 // When a CUDA/privateuse1 function is run either all its inputs were
191 // accumulated on the stream used to run the function OR the inputs are on
192 // different devices and the function is responsible for properly acquiring
193 // them.
194 //
195 // User-facing stream semantics of a backward() (or torch.autograd.grad())
196 // call with respect to surrounding ops are the same as for any other call.
197 // See "Stream semantics of backward passes" on
198 // https://pytorch.org/docs/stable/notes/cuda.html
199 //
200 // Internally, backward() runs ops (including leaf nodes) on side threads.
201 // And streams are thread local. So GraphTask achieves the above semantics by
202 // 1. remembering the current streams on all active CUDA/privateuse1 devices
203 // in the user-facing thread (aka, the thread that called execute() to
204 // launch the GraphTask)
205 // 2. remembering the "leaf streams" (streams each backward leaf node ran on)
206 // 3. during exec_post_processing, for each leaf stream, sync the remembered
207 // current streams (on the leaf stream's device) with that
208 // leaf stream.
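//
// As a minimal sketch of step 3 (illustrative only; the real logic lives in
// GraphTask::exec_post_processing below), the sync boils down to:
//
//   c10::Event event{leaf_stream.device_type()};
//   event.record(leaf_stream);            // mark the end of the leaf's work
//   caller_current_stream.wait(event);    // caller's ambient stream waits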
209
int NodeTask::getReentrantDepth() const {
211 std::shared_ptr<GraphTask> graph_task = base_.lock();
212 if (graph_task) {
213 return graph_task->reentrant_depth_;
214 } else {
215 // The graph task is no longer valid indicating an error. As a result, we
216 // try to move this to the front of the queue to ensure the autograd
217 // engine threads pick up this error soon.
218 return std::numeric_limits<int>::max();
219 }
220 }
221
CheckpointValidGuard::CheckpointValidGuard(
223 const std::shared_ptr<const GraphTask>& graph_task)
224 : prev_checkpoint_valid_state(checkpoint_valid) {
225 checkpoint_valid =
226 graph_task->can_checkpoint() && prev_checkpoint_valid_state;
227 }
228
CheckpointValidGuard::~CheckpointValidGuard() {
230 checkpoint_valid = prev_checkpoint_valid_state;
231 }
232
auto ReadyQueue::push(NodeTask item, bool incrementOutstandingTasks) -> void {
234 {
235 // Lock mutex for writing to heap_
236 std::lock_guard<std::mutex> lock(mutex_);
237 if (incrementOutstandingTasks) {
238 std::shared_ptr<GraphTask> graph_task = item.base_.lock();
239 TORCH_INTERNAL_ASSERT(graph_task, "GraphTask is no longer valid!");
240 ++graph_task->outstanding_tasks_;
241 }
242 heap_.push(std::move(item));
243 }
244 not_empty_.notify_one();
245 }
246
auto ReadyQueue::pushShutdownTask() -> void {
248 {
249 std::lock_guard<std::mutex> lock(mutex_);
250 heap_.push(NodeTask({}, nullptr, InputBuffer(0), true));
251 }
252 not_empty_.notify_one();
253 }
254
size_t ReadyQueue::size() const {
256 // Lock mutex for accesses to heap_
257 std::unique_lock<std::mutex> lock(mutex_);
258 return heap_.size();
259 }
260
auto ReadyQueue::pop() -> NodeTask {
262 // Lock mutex for accesses to heap_
263 std::unique_lock<std::mutex> lock(mutex_);
264 not_empty_.wait(lock, [this] { return !heap_.empty(); });
265 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
266 auto task = std::move(const_cast<NodeTask&>(heap_.top()));
267 heap_.pop();
268 return task;
269 }
270
bool ReadyQueue::empty() const {
272 // Lock mutex for accesses to heap_
273 std::unique_lock<std::mutex> lock(mutex_);
274 return heap_.empty();
275 }
276
Engine::Engine()
278 : max_recursion_depth_(MAX_DEPTH), non_reentrant_device_thread_count_(0) {}
279
Engine::~Engine() {
281 stop();
282 }
283
// Send shutdown tasks to all device_ready_queues_ if no backward tasks are
// running. Even though the ready queues should be empty, shutdown tasks have
// the highest priority.
void Engine::stop() {
288 if (stopped_) {
289 return;
290 }
291 stopped_ = true;
  // Under some conditions, autograd threads can hang on shutdown.
  // Do not wait for them to shut down indefinitely; rely on a timeout instead.
294 auto wait_duration_str = getenv("TORCH_AUTOGRAD_SHUTDOWN_WAIT_LIMIT");
295 auto wait_duration = wait_duration_str ? std::atof(wait_duration_str) : 10.0;
296 bool noBackward = true;
297 for (auto& queue : device_ready_queues_) {
298 noBackward = noBackward && queue->empty();
299 }
300 if (noBackward && wait_duration > 0.0f) {
301 for (auto& queue : device_ready_queues_) {
302 queue->pushShutdownTask();
303 }
304 // Do not wait for termination of global threads on Windows
305 // Because CRT terminates DLL threads before calling
306 // global object destructors
307 #if !defined(_WIN32) || defined(C10_USE_MSVC_STATIC_RUNTIME)
308
309 using namespace std::chrono_literals;
  // Set a deadline for how long it is OK to wait for device threads to shut down
311 auto wait_deadline =
312 std::chrono::steady_clock::now() + wait_duration * 1.0s;
313 std::unique_lock<std::mutex> lk(non_reentrant_device_thread_mutex_);
314 while (non_reentrant_device_thread_count_.load() != 0) {
315 if (non_reentrant_device_thread_condvar_.wait_until(lk, wait_deadline) ==
316 std::cv_status::timeout) {
317 break;
318 }
319 }
320 #endif
321 }
322 // Otherwise threads are leaked
323 }
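// For example (illustrative shell usage, not engine code), a deployment that
// tolerates a longer shutdown can raise the limit used above:
//
//   TORCH_AUTOGRAD_SHUTDOWN_WAIT_LIMIT=30 python train.py
//
// while a non-positive value skips pushing shutdown tasks and waiting
// altogether.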
324
void Engine::release_workers() {
326 std::unique_lock<std::mutex> lk(non_reentrant_device_thread_mutex_);
327 non_reentrant_device_thread_count_.store(0);
328 non_reentrant_device_thread_condvar_.notify_one();
329 }
330
void Engine::increment_non_reentrant_thread_count() {
332 std::unique_lock<std::mutex> lk(non_reentrant_device_thread_mutex_);
333 non_reentrant_device_thread_count_.fetch_add(1);
334 non_reentrant_device_thread_condvar_.notify_one();
335 }
336
void Engine::decrement_non_reentrant_thread_count() {
338 std::unique_lock<std::mutex> lk(non_reentrant_device_thread_mutex_);
339 non_reentrant_device_thread_count_.fetch_sub(1);
340 non_reentrant_device_thread_condvar_.notify_one();
341 }
342
void Engine::thread_init(
    int device,
    const std::shared_ptr<ReadyQueue>& ready_queue,
    bool should_increment) {
347 // pthread_setname_np restricts the name to 16 characters including
348 // the null byte.
349 std::string thread_name = "pt_autograd_" + std::to_string(device);
350 c10::setThreadName(thread_name);
351
352 c10::set_terminate_handler();
353 if (should_increment) {
354 increment_non_reentrant_thread_count();
355 }
356
357 at::init_num_threads();
358
359 // Note [Allocating GPUs to autograd threads]
360 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
361 // What's our strategy here? Originally, the autograd engine was written
362 // with only CUDA in mind. We allocate one thread to handle all CPU
363 // operations, and a thread per CUDA device.
364 //
365 // But what if we have OTHER devices? There are two plausible
366 // strategies:
367 //
368 // - We can allocate threads equal to max(num_cuda_devices, num_xla_devices,
369 // ...) and colocate cuda device 0 with xla device 0
370 // - We can allocate threads equal to sum(num_cuda_devices, num_xla_devices,
371 // ...) keeping everyone separate.
372 //
373 // We don't have any good reason to prefer one or the other, so we've
374 // arbitrarily picked to colocate devices. Maybe the other approach is
375 // better.
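// Concretely, with the colocating strategy a single worker thread (named
// "pt_autograd_1", see above) serves both cuda:1 and xla:1, since both map to
// device_ready_queues_[1].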
376 worker_device = device;
377
378 // initialize each device thread's thread local ready queue with the ready
379 // queue that is created before the thread initialization
380 init_local_ready_queue(ready_queue);
381
382 std::shared_ptr<GraphTask> graph_task = nullptr;
383 thread_main(graph_task);
384 if (should_increment) {
385 // Decrement the count during shutdown if we incremented earlier.
386 decrement_non_reentrant_thread_count();
387 }
388 }
389
GraphTaskGuard::GraphTaskGuard(std::shared_ptr<GraphTask> graph_task)
391 : last_graph_task_(std::move(current_graph_task)) {
392 current_graph_task = std::move(graph_task);
393 }
GraphTaskGuard::~GraphTaskGuard() {
395 restore_current_graph_task();
396 }
397
void GraphTaskGuard::restore_current_graph_task() {
399 current_graph_task = std::move(last_graph_task_);
400 }
401
// The current graph task's exec_info is used to trim unnecessary edges during
// node evaluation; see the `Node.task_should_compute_output()` function.
const std::unordered_map<Node*, GraphTask::ExecInfo>*
get_current_graph_task_exec_info() {
  return current_graph_task ? &current_graph_task->exec_info_ : nullptr;
}
408
const std::unordered_set<Node*>* get_current_graph_task_nodes_in_graph() {
  return current_graph_task ? &current_graph_task->nodes_in_graph_ : nullptr;
}
412
int get_current_graph_task_id() {
414 return current_graph_task ? current_graph_task->id_ : -1;
415 }
416
bool get_current_graph_task_keep_graph() {
418 return current_graph_task ? current_graph_task->keep_graph_ : true;
419 }
420
void add_node_to_current_graph_task_exec_info(Node* fn) {
422 current_graph_task->exec_info_[fn].needed_ = true;
423 }
424
425 // NB: The engine itself does not use the outputs of this function.
std::vector<Node*> get_current_graph_task_execution_order() {
427 std::shared_ptr<GraphTask> task = current_graph_task;
428 if (!task) {
429 return {};
430 }
431
  // We could potentially check if there is only a single device here, but
  // explicitly requiring this context doesn't seem bad either
434 TORCH_CHECK(
435 !c10::AutogradState::get_tls_state().get_multithreading_enabled(),
436 "get_current_graph_task_execution_order expects the current backward to be "
437 "executed with multithreading disabled, e.g. by running:\n\n"
438 ">>> with torch.autograd.set_multithreading_enabled(False):\n"
439 "... torch.autograd.grad(...)\n");
440
441 const bool check_exec_info = !task->exec_info_.empty();
442 std::vector<Node*> out{};
443 // Do a copy since we mutate it later
444 std::unordered_map<Node*, int> dependencies = task->dependencies_;
445
446 auto compare_seq_nr = [](Node* n1, Node* n2) {
447 return n1->sequence_nr() < n2->sequence_nr();
448 };
449 std::priority_queue<Node*, std::vector<Node*>, decltype(compare_seq_nr)> heap(
450 compare_seq_nr);
451
452 for (Node* ptr : task->graph_roots_) {
453 heap.push(ptr);
454 }
455
456 // Implementation notes:
  // - We need to count dependencies even though we have sequence_nr, because
  //   in the accumulate_grad case we cannot assume the outputs have a higher
  //   sequence_nr than the inputs
460 // - Don't need to check topological_nr because we have exec_info
461 while (!heap.empty()) {
462 Node* fn = heap.top();
463 heap.pop();
464
465 out.push_back(fn);
466 for (const auto& edge : fn->next_edges()) {
467 Node* next_ptr = edge.function.get();
468 if (!next_ptr) {
469 continue;
470 }
471 if (check_exec_info) {
472 auto it = task->exec_info_.find(next_ptr);
473 if (it == task->exec_info_.end() || !it->second.should_execute()) {
474 continue;
475 }
476 }
477 auto it = dependencies.find(edge.function.get());
478 TORCH_INTERNAL_ASSERT(it != dependencies.end());
479 if (--it->second == 0) {
480 dependencies.erase(it);
481 heap.push(next_ptr);
482 }
483 }
484 }
485 return out;
486 }
487
488 // NOTE: graph_tasks do not necessarily form a stack. Imagine this
489 // case:
490 //
491 // +----> Eval1
492 // Root
493 // +----> Eval2
494 //
495 // Once Root is executed, both Eval1 and Eval2 are added to the ready queue.
496 // Next, Eval1 is run and this causes the worker to enter thread_main again.
497 // Then, it pops the next task from the queue, but at this point it is Eval2.
498 // It enters thread_main once again, but now with graph_task of Eval2, which is
499 // completely unrelated to that of Eval1 (it's not a recursive call).
500 // It's all ok and is handled right now, but it should be accounted for
501 // in case this code is to be changed.
502 //
503 // thread_main is used by:
504 // 1). autograd threads for devices (i.e. CUDA, XLA)
505 // 2). the caller/owning thread of the backward call on CPU (sync mode)
// 3). Reentrant backward invoked by either 1) or 2)
507 // The exit conditions are different for the above three cases.
508 // For 1), we are spinning on running the thread_main on device autograd
509 // threads throughout the Engine lifetime, thread_main will get
510 // terminated during Engine destruction by pushing shutdown tasks
// For 2), the owning thread of the backward call drives thread_main
//         synchronously until the graph_task of that owning thread is
//         completed, then exits thread_main to continue executing the
//         caller's code.
// For 3), the reentrant backward that invokes thread_main, either from 1) or
//         2), will not spin; it exits as soon as its graph_task is completed,
//         notifying the owning thread as needed.
auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void {
  // When graph_task is nullptr, this is a long running thread that processes
  // tasks (ex: device threads). When graph_task is non-null (ex: reentrant
  // backwards, user thread), this function is expected to exit once that
  // graph_task completes.

  // local_ready_queue should already be initialized when we get into
  // thread_main
527 TORCH_INTERNAL_ASSERT(local_ready_queue != nullptr);
528 while (graph_task == nullptr || !graph_task->future_result_->completed()) {
529 // local_graph_task represents the graph_task we retrieve from the queue.
530 // The outer graph_task represents the overall graph_task we need to execute
531 // for reentrant execution.
532 std::shared_ptr<GraphTask> local_graph_task;
533 {
534 // Scope this block of execution since NodeTask is not needed after this
535 // block and can be deallocated (release any references to grad tensors
536 // as part of inputs_).
537 NodeTask task = local_ready_queue->pop();
      // This will only work if the worker is running a non-backward task
      // TODO: this needs to be fixed to work in all cases
540 if (task.isShutdownTask_) {
541 C10_LOG_API_USAGE_ONCE("torch.autograd.thread_shutdown");
542 break;
543 }
544
545 local_graph_task = task.base_.lock();
546 if (!local_graph_task) {
547 // GraphTask for function is no longer valid, skipping further
548 // execution.
549 continue;
550 }
551
552 set_device(worker_device);
553
554 if (task.fn_ && !local_graph_task->has_error_.load()) {
555 // Set the ThreadLocalState before calling the function.
556 // NB: The ThreadLocalStateGuard doesn't set the grad_mode because
557 // GraphTask always saves ThreadLocalState without grad_mode.
558 at::ThreadLocalStateGuard tls_guard(local_graph_task->thread_locals_);
559 c10::WarningUtils::WarningHandlerGuard warnings_guard(
560 &local_graph_task->warning_handler_);
561
562 try {
563 // The guard sets the thread_local current_graph_task on construction
564 // and restores it on exit. The current_graph_task variable helps
565 // queue_callback() to find the target GraphTask to append final
566 // callbacks.
567 GraphTaskGuard guard(local_graph_task);
568 NodeGuard ndguard(task.fn_);
569 {
570 RECORD_FUNCTION(
571 c10::str(
572 "autograd::engine::evaluate_function: ",
573 task.fn_.get()->name()),
574 c10::ArrayRef<const c10::IValue>());
575 evaluate_function(
576 local_graph_task,
577 task.fn_.get(),
578 task.inputs_,
579 local_graph_task->cpu_ready_queue_);
580 }
581 } catch (std::exception& e) {
582 // See Note [ Persisting PyErr state across autograd engine threads ]
583 thread_on_exception(local_graph_task, task.fn_, e);
584 }
585 }
586 }
587
588 // Decrement the outstanding tasks.
589 --local_graph_task->outstanding_tasks_;
590
591 // Check if we've completed execution.
592 if (local_graph_task->completed()) {
593 local_graph_task->mark_as_completed_and_run_post_processing();
594
595 auto base_owner = local_graph_task->owner_;
      // The current worker thread finished the graph_task, but the owning thread
597 // of the graph_task might be sleeping on pop() if it does not have work.
598 // So we need to send a dummy function task to the owning thread just to
599 // ensure that it's not sleeping, so that we can exit the thread_main.
600 // If it has work, it might see that graph_task->outstanding_tasks_ == 0
601 // before it gets to the task, but it's a no-op anyway.
602 //
603 // NB: This is not necessary if the current thread is the owning thread.
604 if (worker_device != base_owner) {
605 // Synchronize outstanding_tasks_ with queue mutex
606 std::atomic_thread_fence(std::memory_order_release);
607 ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
608 ->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));
609 }
610 }
611 }
612 }
613
// Reentrant call will reuse the graph_task's owner thread ready_queue for
// queueing tasks (NOTE: this is not true in the async_mode of the engine).
// While we could create a separate ready queue for each new reentrant thread,
// sharing the same cpu_ready_queue with the parent thread is a performance
// improvement, and device threads still have to do the same thing anyway.
void Engine::reentrant_thread_init() {
620 c10::set_terminate_handler();
621 at::init_num_threads();
622 auto tp_shared = thread_pool_shared_;
623 while (true) {
624 std::unique_lock<std::mutex> lk(tp_shared->mutex_);
625 ++thread_pool_shared_->num_workers_;
626 tp_shared->work_.wait(
627 lk, [&tp_shared] { return !tp_shared->graphtasks_queue_.empty(); });
628 --thread_pool_shared_->num_workers_;
629 auto task = tp_shared->graphtasks_queue_.front();
630 tp_shared->graphtasks_queue_.pop();
631 lk.unlock();
632 std::shared_ptr<GraphTask> graph_task = task.lock();
633 if (!graph_task) {
634 LOG(INFO) << "GraphTask has expired, skipping reentrant execution";
635 continue;
636 }
637 set_device(graph_task->owner_);
638 // set the local_ready_queue to the ready queue on the graph_task->owner_
639 // device
640 local_ready_queue =
641 ready_queue_by_index(graph_task->cpu_ready_queue_, graph_task->owner_);
642 total_depth = graph_task->reentrant_depth_;
643 thread_main(graph_task);
644 }
645 }
646
void Engine::thread_on_exception(
    const std::shared_ptr<GraphTask>& graph_task,
    const std::shared_ptr<Node>& fn,
    std::exception& e) {
651 graph_task->set_exception(std::current_exception(), fn);
652 }
653
654 namespace {
655 std::atomic<uint64_t> graph_task_id{0};
656 }
657
GraphTask::GraphTask(
    bool keep_graph,
    bool grad_mode,
    int reentrant_depth,
    std::shared_ptr<ReadyQueue> cpu_ready_queue,
    c10::SmallVector<Node*, 4> graph_roots,
    bool exit_on_error)
665 : keep_graph_(keep_graph),
666 graph_roots_(std::move(graph_roots)),
667 owner_(NO_DEVICE),
668 reentrant_depth_(reentrant_depth),
669 exit_on_error_(exit_on_error),
670 cpu_ready_queue_(std::move(cpu_ready_queue)),
671 future_result_(c10::make_intrusive<at::ivalue::Future>(
672 c10::ListType::create(c10::TensorType::get()))),
673 id_(graph_task_id.fetch_add(1, std::memory_order_relaxed)) {
674 thread_locals_.set_grad_mode(grad_mode);
675 }
676
bool GraphTask::completed() {
678 return outstanding_tasks_.load() == 0 ||
679 (exit_on_error_ && has_error_.load());
680 }
681
void GraphTask::mark_as_completed_and_run_post_processing() {
683 // Allow only one thread one attempt to process this logic.
684 if (future_completed_.exchange(true)) {
685 // Future is already marked complete, or being marked as such.
686 // In case the marking complete is only in progress, we add a
687 // wait() to guarantee the future is marked complete on exit.
688 future_result_->wait();
689 return;
690 }
691
692 try {
693 // Run post processing, before marking the future as complete.
694 // Drop lock prior to completing, to avoid holding across callbacks.
695 std::unique_lock<std::mutex> lock(mutex_);
696
697 exec_post_processing();
698 std::vector<Variable> vars = std::move(captured_vars_);
699
700 // Need to unlock before we call markCompleted to avoid holding locks
701 // when the callbacks are called.
702 lock.unlock();
703 future_result_->markCompleted(vars);
704 } catch (std::exception&) {
705 future_result_->setErrorIfNeeded(std::current_exception());
706 }
707 }
708
void GraphTask::exec_post_processing() {
710 if (!not_ready_.empty()) {
711 throw std::runtime_error("could not compute gradients for some functions");
712 }
713
714 // set the thread_local current_graph_task_ as more callbacks can be installed
715 // by existing final callbacks.
716 GraphTaskGuard guard(shared_from_this());
717 // Lock mutex during each iteration for accessing final_callbacks.size()
  // Unlocking is necessary, because the callback can register
  // more callbacks (or they can be registered from other threads
  // while it's waiting).
721 std::unique_lock<std::mutex> cb_lock(final_callbacks_lock_);
722
723 // caller_current_streams_ with nullopt entries removed
724 std::vector<c10::Stream> caller_current_streams_filtered;
725
726 // See Note [Streaming backwards].
727 // Syncs caller_current_stream with leaf streams, so final_callbacks may use
728 // any grad on its device's current stream.
729 if (!leaf_streams.empty()) {
730 for (const auto& leaf_stream : leaf_streams) {
731 // stash_current_cuda/privateuse1_streams() stashed streams for all device
732 // IDs that already had a CUDA/privateuse1 context before the GraphTask
733 // executed. For inactive devices, it stashed a std::nullopt. I don't
734 // expect GraphTask's backward pass ran leaf nodes on any new devices, so
735 // the stashed streams should be enough. If leaf_stream.device_index()
736 // happens to be for a new device, operator* on the std::nullopt should
737 // throw an error.
738 const auto caller_current_stream =
739 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
740 *caller_current_streams_[leaf_stream.device_index()];
741
742 if (caller_current_stream != leaf_stream) {
743 auto event = c10::Event{leaf_stream.device_type()};
744 event.record(leaf_stream);
745 caller_current_stream.wait(event);
746 }
747 }
748
749 caller_current_streams_filtered.reserve(caller_current_streams_.size());
750 for (const auto& opt_stream : caller_current_streams_) {
751 if (opt_stream.has_value()) {
752 caller_current_streams_filtered.push_back(*opt_stream);
753 }
754 }
755 }
756
757 {
758 // final_callbacks run on the per-device caller_current_streams (the ambient
759 // streams surrounding the user's call to backward()). This has two
760 // benefits:
761 // 1. caller_current_streams have been synced with leaf_streams, so
762 // callbacks may
763 // safely access any grad.
764 // 2. The callback's results can safely be used on (user-facing)
765 // caller_current_streams
766 // after backward().
767 c10::MultiStreamGuard g(caller_current_streams_filtered);
768
769 // Set the ThreadLocalState before calling the function.
770 // NB: The ThreadLocalStateGuard doesn't set the grad_mode because GraphTask
771 // always saves ThreadLocalState without grad_mode.
772 at::ThreadLocalStateGuard tls_guard(this->thread_locals_);
773
774 // WARNING: Don't use a range-for loop here because more callbacks may be
775 // added in between callback calls, so iterators may become invalidated.
776 // NOLINTNEXTLINE(modernize-loop-convert)
777 for (size_t i = 0; i < final_callbacks_.size(); ++i) {
778 cb_lock.unlock();
779 final_callbacks_[i]();
780 cb_lock.lock();
781 }
782 }
783 }
784
void GraphTask::set_exception_without_signal(const std::shared_ptr<Node>& fn) {
786 if (!has_error_.exchange(true)) {
787 if (AnomalyMode::is_enabled() && fn) {
788 fn->metadata()->print_stack(fn->name());
789 }
790 }
791 }
792
void GraphTask::set_exception(
    std::exception_ptr eptr,
    const std::shared_ptr<Node>& fn) {
796 set_exception_without_signal(fn);
797 if (!future_completed_.exchange(true)) {
798 future_result_->setError(std::move(eptr));
799 }
800 }
801
static variable_list call_pre_hooks(Node& fn, variable_list inputs) {
803 for (const auto& hook : fn.pre_hooks()) {
804 inputs = (*hook)(inputs);
805 }
806 return inputs;
807 }
808
static variable_list call_tensor_pre_hooks(Node& fn, variable_list inputs) {
810 for (const auto& hook : fn.tensor_pre_hooks()) {
811 inputs = (*hook)(inputs);
812 }
813 for (const auto& pair : fn.retains_grad_hooks()) {
814 inputs = (*pair.second)(inputs);
815 }
816 return inputs;
817 }
818
static variable_list call_post_hooks(
    Node& fn,
    variable_list outputs,
    const variable_list& inputs,
    const bool had_post_hooks) {
824 for (const auto& hook : fn.post_hooks()) {
825 if (had_post_hooks) {
826 outputs = (*hook)(outputs, inputs);
827 } else {
828 variable_list null_inputs;
829 outputs = (*hook)(outputs, null_inputs);
830 }
831 }
832 return outputs;
833 }
834
void set_device(int device) {
836 // NB: We MUST NOT construct the guard for device CPU,
837 // as in some settings we compile with cuda, but
838 // have lazy stubs for CUDA functionality (so actually
839 // attempting to setup a guard(CPU_DEVICE) will cause an
840 // error, because it will still query GetDevice).
841 //
842 // Don't use DeviceGuard here because its destructor may be called before the
843 // device is reset. This is fine because the device is thread local.
844 if (device != CPU_DEVICE) {
845 for (const auto i : c10::irange(static_cast<size_t>(
846 c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES))) {
847 auto* impl = c10::impl::device_guard_impl_registry[i].load();
848 if (impl && device < impl->deviceCount()) {
849 impl->setDevice(at::Device(
850 static_cast<c10::DeviceType>(i),
851 static_cast<c10::DeviceIndex>(device)));
852 }
853 }
854 }
855 worker_device = device;
856 }
857
void validate_outputs(
    const edge_list& edges,
    variable_list& grads,
    const std::function<std::string(const std::string&)>& format_error) {
862 if (grads.size() != edges.size()) {
863 std::stringstream ss;
864 ss << "invalid number of gradients - expected ";
865 ss << edges.size() << ", but got " << grads.size();
866 TORCH_CHECK(false, format_error(ss.str()));
867 }
868 for (const auto i : c10::irange(grads.size())) {
869 const auto& edge = edges[i];
870 if (!edge.is_valid())
871 continue;
872
873 const auto& metadata = edge.function->input_metadata(edge.input_nr);
874 auto& grad = grads[i];
875 if (!grad.defined()) {
876 // FIXME: TestJit.test_ge_optimized fails this assertion.
877 // std::stringstream ss;
878 // ss << "undefined gradient at index " << i;
879 // TORCH_CHECK(false, format_error(ss.str()));
880 continue;
881 }
882
883 grad = metadata.maybe_reduce(i, std::move(grad), format_error);
884
885 bool input_is_complex =
886 isComplexType(c10::typeMetaToScalarType(metadata.options().dtype()));
887 bool grad_is_complex = isComplexType(grad.scalar_type());
888
889 TORCH_CHECK(
890 isFloatingType(grad.scalar_type()) ||
891 (input_is_complex == grad_is_complex));
892 if (c10::typeMetaToScalarType(metadata.options().dtype()) !=
893 grad.scalar_type()) {
894 grad = grad.to(c10::typeMetaToScalarType(metadata.options().dtype()));
895 }
896 if (grad.dtype() != metadata.dtype()) {
897 std::stringstream ss;
898 ss << "invalid gradient at index " << i << " - expected dtype ";
899 ss << metadata.dtype() << " but got " << grad.dtype();
900 TORCH_CHECK(false, format_error(ss.str()));
901 }
902 if (grad.layout() != metadata.layout()) {
    // TODO: Currently we only support the (*, Sparse) combination for
    // (tensor.layout(), tensor.grad.layout()). In the future, there will be an
    // opportunity to support more combinations of layouts if they are
    // composable (e.g., operations like addition are well defined between
    // tensors of different layouts), provided all parts of autograd, like
    // AccumulateGrad, correctly handle this. We allow grad to be Strided
    // when metadata is SparseCsr
910 if (!grad.is_sparse() &&
911 !(grad.layout() == at::kStrided &&
912 (at::sparse_csr::is_sparse_compressed(metadata.layout()) ||
913 metadata.layout() == at::kSparse))) {
914 std::stringstream ss;
915 ss << "invalid gradient at index " << i << " - expected layout ";
916 ss << metadata.layout() << " but got " << grad.layout();
917 TORCH_CHECK(false, format_error(ss.str()));
918 }
919 }
920
921 if (grad.device() != metadata.device()) {
922 // quick hack for: https://github.com/pytorch/pytorch/issues/65016 but
923 // should be eventually removed
924 if (!(metadata.is_tensor_subclass() ||
925 grad.unsafeGetTensorImpl()->is_python_dispatch())) {
926 if (grad.dim() == 0) {
927 grad = grad.to(metadata.device());
928 } else {
929 std::stringstream ss;
930 ss << "invalid gradient at index " << i << " - expected device ";
931 ss << metadata.device() << " but got " << grad.device();
932 TORCH_CHECK(false, format_error(ss.str()));
933 }
934 }
935 }
936 // We should not build graph for Tensors that are not differentiable
937 TORCH_INTERNAL_ASSERT(isDifferentiableType(grad.scalar_type()));
938 }
939 }
940
static variable_list call_function(
    std::shared_ptr<GraphTask>& graph_task,
    Node* func,
    InputBuffer& inputBuffer) {
945 CheckpointValidGuard cpvguard(graph_task);
946 auto& fn = *func;
947 auto inputs =
948 call_tensor_pre_hooks(fn, InputBuffer::variables(std::move(inputBuffer)));
949 inputs = call_pre_hooks(fn, std::move(inputs));
950 if (!graph_task->keep_graph_) {
951 fn.will_release_variables();
952 }
953
954 const auto has_post_hooks = !fn.post_hooks().empty();
955 variable_list outputs;
956
957 if (has_post_hooks) {
958 // In functions/accumulate_grad.cpp, there is some logic to check the
959 // conditions under which the incoming gradient can be stolen directly
960 // (which elides a deep copy) instead of cloned. One of these conditions
961 // is that the incoming gradient's refcount must be 1 (nothing else is
962 // referencing the same data). Stashing inputs_copy here bumps the
963 // refcount, so if post hooks are employed, it's actually still ok for
964 // accumulate_grad.cpp to steal the gradient if the refcount is 2.
965 //
966 // "new_grad.use_count() <= 1 + !post_hooks().empty()" in
967 // accumulate_grad.cpp accounts for this, but also creates a silent
968 // dependency between engine.cpp (ie, this particular engine
969 // implementation) and accumulate_grad.cpp.
970 //
971 // If you change the logic here, make sure it's compatible with
972 // accumulate_grad.cpp.
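    //
    // For instance (illustrative arithmetic only): with a post hook present,
    // inputs_copy holds a second reference to the incoming grad, so
    // accumulate_grad.cpp sees new_grad.use_count() == 2 while
    // 1 + !post_hooks().empty() == 2, and the steal is still permitted.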
973 auto inputs_copy = inputs;
974 outputs = fn(std::move(inputs_copy));
975 } else {
976 outputs = fn(std::move(inputs));
977 }
978
979 validate_outputs(fn.next_edges(), outputs, [&](const std::string& msg) {
980 std::ostringstream ss;
981 ss << "Function " << fn.name() << " returned an " << msg;
982 return ss.str();
983 });
984
985 // NOLINTNEXTLINE(bugprone-use-after-move)
986 return call_post_hooks(fn, std::move(outputs), inputs, has_post_hooks);
987 }
988
void Engine::evaluate_function(
    std::shared_ptr<GraphTask>& graph_task,
    Node* func,
    InputBuffer& inputs,
    const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
994 // The InputBuffer::adds that supplied incoming grads took pains to
995 // ensure they're safe to consume in the context of the present
996 // func's stream (if applicable). So we guard onto that stream
997 // before working with the grads in any capacity.
998 auto opt_parent_stream = (*func).stream();
999 c10::OptionalStreamGuard parent_stream_guard{opt_parent_stream};
1000
1001 // If exec_info_ is not empty, we have to instrument the execution
1002 auto& exec_info_ = graph_task->exec_info_;
1003 if (!exec_info_.empty()) {
1004 auto& fn_info = exec_info_.at(func);
1005 variable_list new_inputs = inputs.buffer;
1006 if (!fn_info.needed_) {
      // We always want to call tensor pre-hooks, but want to avoid calling
      // them twice. needed_ = True indicates that we will call tensor
      // pre-hooks later.
1010 //
1011 // See NOTE [Hooks ordering] for more context.
1012 new_inputs = call_tensor_pre_hooks(
1013 *func, InputBuffer::variables(std::move(inputs)));
1014 }
1015 if (auto* capture_vec = fn_info.captures_.get()) {
1016 auto opt_parent_stream = (*func).stream();
1017 // Lock mutex for writing to graph_task->captured_vars_.
1018 std::lock_guard<std::mutex> lock(graph_task->mutex_);
1019 for (const auto& capture : *capture_vec) {
1020 auto& captured_grad = graph_task->captured_vars_[capture.output_idx_];
1021 captured_grad = new_inputs[capture.input_idx_];
1022 // NOTE [Deprecated capture hooks]
1023 for (const auto& hook :
1024 capture.DO_NOT_USE_DEPRECATED_get_capture_hooks()) {
1025 captured_grad = (*hook)(captured_grad);
1026 }
1027 if (opt_parent_stream) {
1028 // No need to take graph_task->mutex_ here, we already hold it
1029 graph_task->leaf_streams.emplace(*opt_parent_stream);
1030 }
1031 }
1032 }
1033 if (!fn_info.needed_) {
1034 // Skip execution if we don't need to execute the function.
1035 return;
1036 }
1037 }
1038
1039 auto outputs = call_function(graph_task, func, inputs);
1040
1041 auto& fn = *func;
1042 if (!graph_task->keep_graph_) {
1043 fn.release_variables();
1044 }
1045
1046 auto num_outputs = outputs.size();
1047 if (num_outputs == 0) { // Note: doesn't acquire the mutex
1048 // Records leaf stream (if applicable)
1049 // See Note [Streaming backwards]
1050 if (opt_parent_stream) {
1051 std::lock_guard<std::mutex> lock(graph_task->mutex_);
1052 graph_task->leaf_streams.emplace(*opt_parent_stream);
1053 }
1054 return;
1055 }
1056
1057 if (AnomalyMode::is_enabled() && AnomalyMode::should_check_nan()) {
1058 AutoGradMode grad_mode(false);
1059 for (const auto i : c10::irange(num_outputs)) {
1060 auto& output = outputs[i];
1061 at::OptionalDeviceGuard guard(device_of(output));
1062 if (output.defined() && isnan(output)._is_any_true().item<bool>()) {
1063 std::stringstream ss;
1064 ss << "Function '" << fn.name() << "' returned nan values in its " << i
1065 << "th output.";
1066 throw std::runtime_error(ss.str());
1067 }
1068 }
1069 }
1070
1071 // Lock mutex for the accesses to GraphTask dependencies_, not_ready_ and
1072 // cpu_ready_queue_ below
1073 std::lock_guard<std::mutex> lock(graph_task->mutex_);
1074 for (const auto i : c10::irange(num_outputs)) {
1075 auto& output = outputs[i];
1076 const auto& next = fn.next_edge(i);
1077
1078 if (!next.is_valid())
1079 continue;
1080
1081 // Check if the next function is ready to be computed
1082 bool is_ready = false;
1083 auto& dependencies = graph_task->dependencies_;
1084 auto it = dependencies.find(next.function.get());
1085
1086 if (it == dependencies.end()) {
1087 auto name = next.function->name();
1088 throw std::runtime_error(std::string("dependency not found for ") + name);
1089 } else if (--it->second == 0) {
1090 dependencies.erase(it);
1091 is_ready = true;
1092 }
1093
1094 auto& not_ready = graph_task->not_ready_;
1095 auto not_ready_it = not_ready.find(next.function.get());
1096 if (not_ready_it == not_ready.end()) {
1097 // Skip functions that aren't supposed to be executed
1098 if (!exec_info_.empty()) {
1099 auto it = exec_info_.find(next.function.get());
1100 if (it == exec_info_.end() || !it->second.should_execute()) {
1101 continue;
1102 }
1103 }
1104 // No buffers have been allocated for the function
1105 InputBuffer input_buffer(next.function->num_inputs());
1106
1107 // Accumulates into buffer
1108 auto opt_next_stream = next.function->stream();
1109 input_buffer.add(
1110 next.input_nr, std::move(output), opt_parent_stream, opt_next_stream);
1111
1112 if (is_ready) {
1113 auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
1114 queue->push(
1115 NodeTask(graph_task, next.function, std::move(input_buffer)));
1116 } else {
1117 not_ready.emplace(next.function.get(), std::move(input_buffer));
1118 }
1119 } else {
1120 // The function already has a buffer
1121 auto& input_buffer = not_ready_it->second;
1122
1123 // Accumulates into buffer
1124 auto opt_next_stream = next.function->stream();
1125 input_buffer.add(
1126 next.input_nr, std::move(output), opt_parent_stream, opt_next_stream);
1127 if (is_ready) {
1128 auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
1129 queue->push(
1130 NodeTask(graph_task, next.function, std::move(input_buffer)));
1131 not_ready.erase(not_ready_it);
1132 }
1133 }
1134 }
1135 }
1136
inline static uint64_t compute_min_topological_nr(const edge_list& outputs) {
  // Computes the minimum topological number among all the outputs
1139 if (outputs.empty()) {
1140 return 0;
1141 }
1142 auto min_topo_nr = std::numeric_limits<uint64_t>::max();
1143 for (auto& output_edge : outputs) {
1144 auto topo_nr = output_edge.function->topological_nr();
1145 min_topo_nr = (min_topo_nr < topo_nr) ? min_topo_nr : topo_nr;
1146 }
1147 return min_topo_nr;
1148 }
1149
auto Engine::compute_dependencies(
    Node* root,
    GraphTask& task,
    uint64_t min_topo_nr) -> void {
1154 // Computes the number of dependencies for each function which requires grad
1155 std::vector<Node*> queue{root};
1156 bool will_use_accelerator = false;
1157
1158 // Queue contains all nodes that will start propagating gradients.
1159 // We no longer have to expand functions that don't require grad.
1160 auto& dependencies = task.dependencies_;
1161 while (!queue.empty()) {
1162 auto fn = queue.back();
1163 queue.pop_back();
1164 if (fn->topological_nr() < min_topo_nr) {
1165 continue;
1166 }
1167 if (!will_use_accelerator) {
1168 will_use_accelerator = fn->stream().has_value();
1169 }
1170 for (const auto& edge : fn->next_edges()) {
1171 if (auto next_ptr = edge.function.get()) {
1172 dependencies[next_ptr] += 1;
1173 const bool was_inserted = task.nodes_in_graph_.insert(next_ptr).second;
1174 if (was_inserted)
1175 queue.push_back(next_ptr);
1176 }
1177 }
1178 }
1179
1180 if (will_use_accelerator) {
1181 // Collects current streams for devices where this process has a
1182 // context, so GraphTask::exec_post_processing can sync them with
1183 // leaf_streams.
1184 task.stash_current_streams();
1185 }
1186 }
1187
auto Engine::execute(
    const edge_list& root_edges,
    const variable_list& inputs,
    bool keep_graph,
    bool create_graph,
    bool accumulate_grad,
    const edge_list& outputs) -> variable_list {
1195 validate_outputs(
1196 root_edges,
1197 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
1198 const_cast<variable_list&>(inputs),
1199 [](const std::string& msg) { return msg; });
1200 if (accumulate_grad && create_graph) {
1201 TORCH_WARN_ONCE(
1202 "Using backward() with create_graph=True will create a reference cycle "
1203 "between the parameter and its gradient which can cause a memory leak. "
1204 "We recommend using autograd.grad when creating the graph to avoid this. "
1205 "If you have to use this function, make sure to reset the .grad fields of "
1206 "your parameters to None after use to break the cycle and avoid the leak.");
1207 }
1208
1209 // Allows us to assert no other threads are in backwards
1210 CompiledAutogradThreadingDebugCheck _thread_check;
1211 auto compiled_autograd = the_compiled_autograd.load();
1212 TORCH_INTERNAL_ASSERT(compiled_autograd != COMPILED_AUTOGRAD_POISON);
1213
1214 // accumulate_grad is true if and only if the frontend call was to
1215 // backward(), not grad(). grad() returns the sum of the gradients
1216 // w.r.t. the inputs and thus needs the inputs to be present.
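  //
  // For example (C++ frontend, illustrative only):
  //   y.backward();                                // accumulate_grad == true
  //   auto gx = torch::autograd::grad({y}, {x});   // accumulate_grad == false;
  //                                                // `outputs` holds x's edges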
1217 TORCH_CHECK_VALUE(
1218 accumulate_grad || !outputs.empty(), "grad requires non-empty inputs.");
1219
  // A fresh, first-time Engine::execute call should start on the CPU device:
  // initialize a new thread-local ready queue on CPU, or reuse the existing
  // one (if one is already allocated, e.g. for consecutive backward calls or
  // re-entrant backward calls), then memoize the local_ready_queue in the
  // GraphTask
1224 init_local_ready_queue();
1225 bool not_reentrant_backward_call = worker_device == NO_DEVICE;
1226
1227 // Store root nodes so we can traverse through the graph later
1228 // e.g., for get_current_graph_task_execution_order
1229 c10::SmallVector<Node*, 4> temp_roots{root_edges.size()};
1230 for (const auto i : c10::irange(root_edges.size())) {
1231 temp_roots[i] = root_edges[i].function.get();
1232 }
1233
1234 auto graph_task = std::make_shared<GraphTask>(
1235 /* keep_graph */ keep_graph,
1236 /* create_graph */ create_graph,
1237 /* depth */ not_reentrant_backward_call ? 0 : total_depth + 1,
1238 /* cpu_ready_queue */ local_ready_queue,
1239 /* graph_roots */ std::move(temp_roots));
1240
1241 // If we receive a single root, skip creating extra root node
1242 bool skip_dummy_node = root_edges.size() == 1 && compiled_autograd == nullptr;
1243 auto graph_root = skip_dummy_node
1244 ? root_edges.at(0).function
1245 : std::make_shared<GraphRoot>(root_edges, inputs);
1246
1247 auto min_topo_nr = compute_min_topological_nr(outputs);
1248 // Now compute the dependencies for all executable functions
1249 compute_dependencies(graph_root.get(), *graph_task, min_topo_nr);
1250
1251 if (!outputs.empty()) {
1252 graph_task->init_to_execute(
1253 *graph_root, outputs, accumulate_grad, min_topo_nr);
1254 }
1255
1256 if (compiled_autograd != nullptr) {
1257 // see [Note: Compiled Autograd]
1258 TORCH_CHECK(
1259 !create_graph, "compiled_autograd does not support create_graph");
1260 _thread_check.release();
1261 TORCH_CHECK(
1262 !AnomalyMode::is_enabled(),
1263 "compiled_autograd does not support AnomalyMode")
1264 return (*compiled_autograd)(
1265 graph_root, *graph_task, accumulate_grad, outputs);
1266 }
1267
1268 // Queue the root
1269 if (skip_dummy_node) {
1270 InputBuffer input_buffer(root_edges.at(0).function->num_inputs());
1271 auto input = inputs.at(0);
1272
1273 const auto input_stream = InputMetadata(input).stream();
1274 auto opt_next_stream = root_edges.at(0).function->stream();
1275 input_buffer.add(
1276 root_edges.at(0).input_nr,
1277 std::move(input),
1278 input_stream,
1279 opt_next_stream);
1280
1281 execute_with_graph_task(
1282 graph_task, std::move(graph_root), std::move(input_buffer));
1283 } else {
1284 execute_with_graph_task(
1285 graph_task, std::move(graph_root), InputBuffer(variable_list()));
1286 }
1287 // Avoid a refcount bump for the Future, since we check for refcount in
1288 // DistEngine (see TORCH_INTERNAL_ASSERT(futureGrads.use_count() == 1)
1289 // in dist_engine.cpp).
1290 auto& fut = graph_task->future_result_;
1291 fut->wait();
1292 graph_task->warning_handler_.replay_warnings();
1293 return fut->value().toTensorVector();
1294 }
1295
void Engine::initialize_device_threads_pool() {
1297 TORCH_CHECK(
1298 !in_bad_autograd_fork,
1299 "Unable to handle autograd's threading in combination with fork-based multiprocessing. "
1300 "See https://github.com/pytorch/pytorch/wiki/Autograd-and-Fork");
1301 c10::call_once(
1302 start_device_threads_flag_, &Engine::start_device_threads, this);
1303 }
1304
c10::intrusive_ptr<at::ivalue::Future> Engine::execute_with_graph_task(
    const std::shared_ptr<GraphTask>& graph_task,
    std::shared_ptr<Node> graph_root,
    InputBuffer&& input_buffer) {
1309 initialize_device_threads_pool();
1310 // Lock mutex for GraphTask.
1311 std::unique_lock<std::mutex> lock(graph_task->mutex_);
1312
1313 auto queue = ready_queue(graph_task->cpu_ready_queue_, input_buffer.device());
1314
  // worker_device == NO_DEVICE means it's a CPU thread that is trying to
  // drive the autograd engine with the corresponding GraphTask, and it's NOT
  // a re-entrant call
1317 if (worker_device == NO_DEVICE) {
    // We set the worker_device to CPU_DEVICE only if worker_device was
    // previously NO_DEVICE. Setting it to CPU afterwards allows us to detect
    // whether this is a re-entrant call or not.
1321 set_device(CPU_DEVICE);
1322
1323 // set the graph_task owner to the current device
1324 graph_task->owner_ = worker_device;
1325
1326 // Now that all the non-thread safe fields of the graph_task have been
1327 // populated, we can enqueue it.
1328 queue->push(
1329 NodeTask(graph_task, std::move(graph_root), std::move(input_buffer)));
1330
    // The owning thread starts to drive the engine execution for any CPU task
    // that was just pushed or will be added later from other worker threads
1333 lock.unlock();
1334 thread_main(graph_task);
1335 TORCH_INTERNAL_ASSERT(graph_task->future_result_->completed());
    // Reset the worker_device after the completion of the graph_task, so that
    // the initial state of the engine remains the same across every
    // backward() or grad() call. We don't need to reset local_ready_queue, as
    // we could possibly reuse it for new backward calls.
1340 worker_device = NO_DEVICE;
1341 } else {
1342 // If worker_device is any devices (i.e. CPU, CUDA): this is a re-entrant
1343 // backward call from that device.
1344 graph_task->owner_ = worker_device;
1345
1346 // Now that all the non-thread safe fields of the graph_task have been
1347 // populated, we can enqueue it.
1348 queue->push(
1349 NodeTask(graph_task, std::move(graph_root), std::move(input_buffer)));
1350
1351 if (current_depth >= max_recursion_depth_) {
1352 // See Note [Reentrant backwards]
1353 // If reached the max depth, switch to a different thread
1354 add_thread_pool_task(graph_task);
1355 } else {
1356 // Total depth needs to be updated only in this codepath, since it is
1357 // not used in the block above (when we call add_thread_pool_task).
1358 // In the codepath above, GraphTask.reentrant_depth_ is used to
1359 // bootstrap total_depth in the other thread.
1360 ++total_depth;
1361
1362 // Get back to work while we wait for our new graph_task to
1363 // complete!
1364 ++current_depth;
1365 lock.unlock();
1366 thread_main(graph_task);
1367 --current_depth;
1368 --total_depth;
1369
1370 // The graph task should have completed and the associated future should
1371 // be marked completed as well since 'thread_main' above is a call
1372 // blocking an autograd engine thread.
1373 TORCH_INTERNAL_ASSERT(graph_task->future_result_->completed());
1374 }
1375 }
1376 // graph_task_exec_post_processing is done when the Future is marked as
1377 // completed in mark_as_completed_and_run_post_processing.
1378 return graph_task->future_result_;
1379 }
1380
// Note that when Python is present, this base engine will be overridden with
// a PythonEngine. Because this typically happens before get_default_engine
// is called, this base engine will never be created.
Engine& Engine::get_base_engine() {
1385 static Engine engine;
1386 return engine;
1387 }
1388
1389 std::atomic<EngineStub> engine_stub(Engine::get_base_engine);
1390
void set_default_engine_stub(EngineStub stub) {
1392 engine_stub.store(stub);
1393 }
1394
1395 Engine& Engine::get_default_engine() {
1396 return engine_stub.load()();
1397 }
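
// Illustrative sketch (hypothetical `MyEngine`, not the actual PythonEngine
// registration code): a specialized engine can install itself through
// set_default_engine_stub so that later Engine::get_default_engine() calls
// return it instead of the base engine above:
//
//   struct MyEngine : public torch::autograd::Engine {
//     static torch::autograd::Engine& get_my_engine() {
//       static MyEngine engine;
//       return engine;
//     }
//   };
//   // typically done once, early at module-initialization time
//   torch::autograd::set_default_engine_stub(&MyEngine::get_my_engine);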
1398
1399 void Engine::set_compiled_autograd(Engine::compiled_autograd_fn fn) {
1400 if (the_compiled_autograd.load() == fn) {
1401 return;
1402 }
1403 auto prior = the_compiled_autograd.exchange(COMPILED_AUTOGRAD_POISON);
1404 TORCH_CHECK(
1405 num_threads_in_backwards.load() == 0 && prior != COMPILED_AUTOGRAD_POISON,
1406 "compiled_autograd.enable() requires no threads in backwards()");
1407 the_compiled_autograd.store(fn);
1408 }
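
// Sketch of the guard above (descriptive only, no additional behavior):
// exchanging COMPILED_AUTOGRAD_POISON in before the check means a concurrent
// set_compiled_autograd() call observes the poison value as `prior` and trips
// the TORCH_CHECK, while a backward pass that is already running is caught by
// num_threads_in_backwards != 0; `fn` is only stored once both checks pass.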
1409
1410 void Engine::queue_callback(std::function<void()> callback) {
1411 TORCH_CHECK(
1412 current_graph_task,
1413 "Final callbacks can only be installed during backward pass.");
1414
1415 std::lock_guard<std::mutex> lock(current_graph_task->final_callbacks_lock_);
1416 current_graph_task->final_callbacks_.emplace_back(std::move(callback));
1417 }
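
// Usage sketch (illustrative only): final callbacks may only be queued while a
// backward pass is running, e.g. from a hook or a Node's backward, and they
// run once the current GraphTask completes:
//
//   torch::autograd::Engine::get_default_engine().queue_callback([] {
//     // runs after the currently-executing backward pass finishes
//   });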
1418
1419 bool Engine::is_checkpoint_valid() {
1420 return checkpoint_valid;
1421 }
1422
1423 void Engine::init_local_ready_queue(std::shared_ptr<ReadyQueue> ready_queue) {
1424 if (ready_queue) {
1425     // if a ready_queue is provided by the caller, use the caller's ready_queue
1426     // to initialize local_ready_queue
1427 local_ready_queue = std::move(ready_queue);
1428 } else if (!local_ready_queue) {
1429     // otherwise, if local_ready_queue is not yet allocated, allocate a new ready_queue
1430 local_ready_queue = std::make_shared<ReadyQueue>();
1431 }
1432 }
1433
1434 // CPU ready queue is per GraphTask, but CUDA device ready queues are shared
1435 // across all graph tasks
1436 auto Engine::ready_queue(
1437 std::shared_ptr<ReadyQueue> cpu_ready_queue,
1438 at::Device device) -> std::shared_ptr<ReadyQueue> {
1439 bool multithreading_disabled =
1440 !c10::AutogradState::get_tls_state().get_multithreading_enabled();
1441 if (multithreading_disabled || should_run_in_cpu_ready_queue(device.type())) {
1442 // return the cpu ready queue passed in
1443 TORCH_INTERNAL_ASSERT(cpu_ready_queue);
1444 return cpu_ready_queue;
1445 } else {
1446 TORCH_INTERNAL_ASSERT(
1447 0 <= device.index() &&
1448 device.index() <
1449 static_cast<c10::DeviceIndex>(device_ready_queues_.size()));
1450 // See Note [Allocating GPUs to autograd threads]
1451 return device_ready_queues_.at(device.index());
1452 }
1453 }
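
// Routing sketch for ready_queue() above (illustrative devices): with
// multithreading enabled,
//
//   ready_queue(cpu_queue, at::Device(at::kCPU))      // -> cpu_queue (per GraphTask)
//   ready_queue(cpu_queue, at::Device(at::kLazy))     // -> cpu_queue (per GraphTask)
//   ready_queue(cpu_queue, at::Device(at::kCUDA, 1))  // -> device_ready_queues_[1] (shared)
//
// and with multithreading disabled (see c10::AutogradState), every device is
// routed to the per-GraphTask CPU queue, so the calling thread runs all nodes.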
1454
1455 auto Engine::ready_queue_by_index(
1456 std::shared_ptr<ReadyQueue> cpu_ready_queue,
1457 int device_index) -> std::shared_ptr<ReadyQueue> {
1458 if (device_index == CPU_DEVICE) {
1459 // return the cpu ready queue passed in
1460 TORCH_INTERNAL_ASSERT(cpu_ready_queue);
1461 return cpu_ready_queue;
1462 } else {
1463 TORCH_INTERNAL_ASSERT(
1464 0 <= device_index &&
1465 device_index <
1466 static_cast<c10::DeviceIndex>(device_ready_queues_.size()));
1467 // See Note [Allocating GPUs to autograd threads]
1468 // NB: This function would become obsolete if we truly allocated a CPU
1469     // thread per device, rather than colocating them.
1470 return device_ready_queues_.at(device_index);
1471 }
1472 }
1473
1474 auto Engine::start_device_threads() -> void {
1475 // First always initialize the thread pool for re-entrant threads
1476 thread_pool_shared_ = std::make_shared<ThreadPoolShared>();
1477
1478 // Second, create special threads for each non-CPU device
1479 // See Note [Allocating GPUs to autograd threads]
1480 c10::DeviceIndex num_devices = 0;
1481 for (const auto& impl_atomic : c10::impl::device_guard_impl_registry) {
1482 auto* impl = impl_atomic.load();
1483     // Only record the number of devices for device types that don't run on
1484     // the CPU ready queue.
1485 if (impl && !should_run_in_cpu_ready_queue(impl->type())) {
1486 num_devices = std::max(num_devices, impl->deviceCount());
1487 }
1488 }
1489
1490   // If there are no devices other than the CPU, there is no need to create worker threads
1491 if (num_devices == 0) {
1492 return;
1493 }
1494
1495 // Since we're about to create threads, forking is not possible anymore
1496 track_bad_autograd_forks();
1497
1498 // allocate one thread for every GPU device (but colocate GPUs of different
1499   // types), and pre-allocate the device_ready_queues_ to ensure safe reads
1500   // from it.
1501 device_ready_queues_ = std::vector<std::shared_ptr<ReadyQueue>>(num_devices);
1502 for (auto& queue : device_ready_queues_) {
1503 queue = std::make_shared<ReadyQueue>();
1504 }
1505
1506 for (const auto i : c10::irange(num_devices)) {
1507 std::thread t(&Engine::thread_init, this, i, device_ready_queues_[i], true);
1508 t.detach();
1509 }
1510 // Wait for the threads to start
1511 {
1512 std::unique_lock<std::mutex> lk(non_reentrant_device_thread_mutex_);
1513 while (non_reentrant_device_thread_count_.load() !=
1514 static_cast<uint32_t>(num_devices)) {
1515 non_reentrant_device_thread_condvar_.wait(lk);
1516 }
1517 }
1518 }
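
// Worked example for the allocation above (illustrative counts): with 4 CUDA
// devices and 2 PrivateUse1 devices registered, num_devices = max(4, 2) = 4,
// so four device threads and four ready queues are created, and
// device_ready_queues_[i] serves every non-CPU device with index i (e.g. both
// cuda:1 and privateuse1:1), as described in
// Note [Allocating GPUs to autograd threads].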
1519
1520 void Engine::add_thread_pool_task(const std::weak_ptr<GraphTask>& graph_task) {
1521 std::unique_lock<std::mutex> lck(thread_pool_shared_->mutex_);
1522 // There may already be some items on the graphtasks_queue_ added by other
1523 // threads but not enough workers to get to the new task that will be
1524 // added
1525 bool create_thread =
1526 (thread_pool_shared_->num_workers_ <=
1527 thread_pool_shared_->graphtasks_queue_.size());
1528 thread_pool_shared_->graphtasks_queue_.push(graph_task);
1529 // Don't need to be holding the lock while actually creating the thread
1530 lck.unlock();
1531 if (create_thread) {
1532 // If we're creating a new thread, forking is not allowed anymore
1533 track_bad_autograd_forks();
1534 std::thread t(&Engine::reentrant_thread_init, this);
1535 t.detach();
1536 }
1537   // This works even if a new thread is created, because wait() will test the
1538 // predicate before waiting
1539 thread_pool_shared_->work_.notify_one();
1540 }
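
// Worked example for the create_thread heuristic above (illustrative counts):
// with num_workers_ == 2 and two graph tasks already queued, 2 <= 2 holds, so
// pushing a third task spawns a new reentrant worker thread; with 3 workers
// and 2 queued tasks, 3 <= 2 is false, and the new task simply waits for an
// existing worker woken by work_.notify_one().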
1541
1542 // Remembers the current streams on all devices for which a context has been created.
1543 // This function assumes the accelerator device is available.
1544 void GraphTask::stash_current_streams() {
1545 const auto accelerator = at::getAccelerator(true).value();
1546 const auto guard = c10::impl::VirtualGuardImpl{accelerator};
1547 auto num_devices = guard.deviceCount();
1548 caller_current_streams_.resize(num_devices);
1549 if (num_devices > 0) {
1550 for (c10::DeviceIndex idx = 0; idx < num_devices; idx++) {
1551 if (at::globalContext().getAcceleratorHooksInterface().hasPrimaryContext(
1552 idx)) {
1553 caller_current_streams_[idx] = guard.getStream({accelerator, idx});
1554 } else {
1555 caller_current_streams_[idx] = std::nullopt;
1556 }
1557 }
1558 }
1559 }
1560
1561 void GraphTask::init_to_execute(
1562 Node& graph_root,
1563 const edge_list& outputs,
1564 bool accumulate_grad,
1565 uint64_t min_topo_nr) {
1566   // Populates exec_info so that nodes which should be executed have
1567   // `exec_info[node].needed_ = true`. Only nodes that have a path to any edge in
1568   // `outputs` should be executed. The pseudocode below populates exec_info using
1569   // recursion, but the actual code does this iteratively. Refer to the
1570   // numbering to see how the actual code corresponds. A difference to note is
1571   // that in the iterative version, when you are working with the current Node,
1572   // you are responsible for updating your parent's is_needed after all your
1573 // children have been updated.
1574 //
1575 // is_needed = {fn: True for fn in outputs} # (0)
1576 // seen = {}
1577 // def compute_is_needed(fn):
1578 // for next_edge in fn.next_edges:
1579 // child_fn = next_edge.fn
1580 // if child_fn in seen and is_needed[child_fn]: # (1)
1581 // is_needed[fn] = true
1582 // else:
1583 // seen.add(child_fn)
1584 // if compute_is_needed(child_fn):
1585 // is_needed[fn] = true # (2)
1586 // # (3) exit for-loop
1587 // return is_needed[fn]
1588 // compute_is_needed(graph_root)
1589 //
1590 // NB: you might be wondering why we don't populate `seen` with outputs. We
1591 // cannot because in the case where two outputs lie on the same path, we still
1592 // need to explore past the first output or we would miss the nodes that are
1593 // required to compute the second output.
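  //
  // Concrete illustration of the NB above (hypothetical graph): with edges
  // root -> X -> M -> Y and outputs = {X, Y}, pre-seeding `seen` with X would
  // stop the traversal when X is reached from root, so M would never be marked
  // as needed even though it is required to propagate gradients from X to Y.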
1594 int output_idx = 0;
1595 for (auto& output_edge : outputs) {
1596 // (0) `is_needed` above corresponds to `exec_info_[fn].needed_`
1597 Node* output = output_edge.function.get();
1598 auto& info = exec_info_[output];
1599 if (accumulate_grad) {
1600 // if called through `.backward()` we directly set `needed_` for all the
1601 // outputs to true
1602 info.needed_ = true;
1603 } else {
1604       // otherwise it is `.grad()` and we set exec_info[fn].captures_ instead.
1605       // In terms of populating the rest of exec_info though, you can basically
1606       // think of this as the same as setting `needed_` to true directly.
1607 if (!info.captures_) {
1608 info.captures_ = std::make_unique<std::vector<ExecInfo::Capture>>();
1609 }
1610 info.captures_->emplace_back(output_edge.input_nr, output_idx++);
1611 }
1612 }
1613 captured_vars_.resize(output_idx);
1614
1615 struct Frame {
1616 Frame(Node* fn) : fn_(fn) {}
1617 Node* fn_{};
1618 size_t next_next_fn_{};
1619
1620 Node* get_next_fn() {
1621 const auto& next = fn_->next_edges();
1622 auto num_next = next.size();
1623 while (next_next_fn_ < num_next) {
1624 auto fn = next[next_next_fn_++].function.get();
1625 if (fn)
1626 return fn;
1627 }
1628 return nullptr;
1629 }
1630 };
1631
1632 auto nodeShouldExecute = [this](Node* fn) {
1633 auto it = exec_info_.find(fn);
1634 return it != exec_info_.end() && it->second.should_execute();
1635 };
1636
1637 std::vector<Frame> stack;
1638 std::unordered_set<Node*> seen;
1639 stack.emplace_back(&graph_root);
1640 exec_info_.emplace(stack.back().fn_, ExecInfo());
1641
1642 while (!stack.empty()) {
1643 auto& frame = stack.back();
1644 const auto fn = frame.fn_;
1645
1646 Node* child_fn = nullptr;
1647 while ((child_fn = frame.get_next_fn()) && !seen.emplace(child_fn).second) {
1648 // (1) next child exists AND has already been seen
1649 if (nodeShouldExecute(child_fn)) {
1650 exec_info_[fn].needed_ = true;
1651 }
1652 }
1653
1654 if (child_fn) {
1655 // (2) next child exists but has not been seen
1656 if (child_fn->topological_nr() < min_topo_nr) {
1657         // a child created before the first output means this child cannot have
1658         // an edge to an output
1659 continue;
1660 }
1661 stack.emplace_back(child_fn);
1662 } else {
1663       // (3) no next child exists for `fn`, which means its `needed_` has already
1664       // been finalized. Pop the stack and update the parent.
1665 stack.pop_back();
1666 if (nodeShouldExecute(fn) && !stack.empty()) {
1667 exec_info_[stack.back().fn_].needed_ = true;
1668 }
1669 }
1670 }
1671 }
1672
1673 } // namespace torch::autograd
1674