xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/engine.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/autograd/engine.h>
2 
3 #include <torch/csrc/autograd/anomaly_mode.h>
4 #include <torch/csrc/autograd/autograd.h>
5 #include <torch/csrc/autograd/function.h>
6 #include <torch/csrc/autograd/functions/basic_ops.h>
7 #include <torch/csrc/autograd/grad_mode.h>
8 #include <torch/csrc/autograd/variable.h>
9 #include <torch/csrc/dynamo/compiled_autograd.h>
10 
11 #include <ATen/DeviceAccelerator.h>
12 #include <ATen/DeviceGuard.h>
13 #include <ATen/ExpandUtils.h>
14 #include <ATen/Parallel.h>
15 #include <ATen/SparseCsrTensorUtils.h>
16 #include <ATen/detail/CUDAHooksInterface.h>
17 #include <ATen/detail/PrivateUse1HooksInterface.h>
18 
19 #ifndef AT_PER_OPERATOR_HEADERS
20 #include <ATen/Functions.h>
21 #else
22 #include <ATen/ops/isnan.h>
23 #endif
24 
25 #include <c10/core/DeviceGuard.h>
26 #include <c10/core/Event.h>
27 #include <c10/core/Stream.h>
28 #include <c10/core/StreamGuard.h>
29 #include <c10/util/AbortHandler.h>
30 #include <c10/util/Exception.h>
31 #include <c10/util/ThreadLocal.h>
32 #include <c10/util/irange.h>
33 #include <c10/util/thread_name.h>
34 
35 #include <atomic>
36 #include <chrono>
37 #include <cstdint>
38 #include <functional>
39 #include <memory>
40 #include <mutex>
41 #include <optional>
42 #include <string>
43 #include <thread>
44 #include <unordered_set>
45 #include <utility>
46 
47 namespace torch::autograd {
48 
49 namespace {
50 static bool in_bad_autograd_fork =
51     false; // True for children forked after engine's thread pool init
52 
53 // Called in the forked child if engine's thread pool has already been
54 // initialized
55 static void forked_autograd_child() {
56   in_bad_autograd_fork = true;
57 }
58 
59 // Should be called before unsafe for forks (thread pool) calls
60 static void track_bad_autograd_forks() {
61 #if !defined(WIN32)
62   static c10::once_flag flag;
63   c10::call_once(
64       flag, [&] { pthread_atfork(nullptr, nullptr, forked_autograd_child); });
65 #endif
66 }
67 
68 inline bool should_run_in_cpu_ready_queue(c10::DeviceType device) {
69   if (device == c10::kCPU || device == c10::kMeta || device == c10::kLazy) {
70     return true;
71   } else {
72     return false;
73   }
74 }
75 
76 std::atomic<Engine::compiled_autograd_fn> the_compiled_autograd = nullptr;
77 #define COMPILED_AUTOGRAD_POISON \
78   reinterpret_cast<Engine::compiled_autograd_fn>(1)
79 std::atomic<int32_t> num_threads_in_backwards;
80 struct CompiledAutogradThreadingDebugCheck {
81   CompiledAutogradThreadingDebugCheck() {
82     num_threads_in_backwards++;
83   }
84   ~CompiledAutogradThreadingDebugCheck() {
85     release();
86   }
87   void release() {
88     if (std::exchange(incremented, false)) {
89       num_threads_in_backwards--;
90     }
91   }
92 
93  private:
94   bool incremented{true};
95 };
96 
97 } // namespace
98 
99 // Threads spawned by the engine are assigned a 'worker_device' specifying
100 // what device they process work for. This variable is initialized at:
101 // 1. thread creation time for CUDA, XLA device threads, as they are
102 //    spinning threads waiting for work on their device.
103 // 2. before the graph task execution for CPU threads, as for each
104 //    backward call we use the caller thread to drive engine execution.
105 // This is used when handling reentrant backwards calls;
106 // See Note [Reentrant backwards]
107 static thread_local int worker_device = NO_DEVICE;
108 
109 // This variable is true if ALL invocations in the stack of re-entrant engine
110 // invocations are imperative backwards. This special variable is needed for the
111 // gradient checkpointing feature only.
112 static thread_local bool checkpoint_valid = true;
113 
114 // Number of nested reentrant backwards calls currently on this thread
115 static thread_local int current_depth = 0;
116 
117 // For all device threads (i.e. CUDA, XLA), total_depth represents the total
118 //   nested reentrant backwards depth over all device threads.
120 // For CPU devices, it is the total depth associated with the original backward
121 // call.
122 static thread_local int total_depth = 0;
123 
124 // The current GraphTask being executed by this thread. This helps
125 // queue_callback() to find the target GraphTask to append final callbacks.
126 C10_DEFINE_TLS_static(std::shared_ptr<GraphTask>, tls_current_graph_task);
127 #define current_graph_task (tls_current_graph_task.get())
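// As a hedged illustration (not exercised in this file): a final callback
// queued during a backward pass, e.g.
//   torch::autograd::Engine::get_default_engine().queue_callback(
//       [] { /* runs in exec_post_processing, after all nodes finish */ });
// is appended to the GraphTask found through this thread-local pointer.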
128 
129 // Every autograd worker thread is associated with a ready queue, which
130 // specifies the stream of work for this thread to do. This shared_ptr is a
131 // thread_local pointer to each thread's ready_queue, and it should be
132 // initialized via the Engine::init_local_ready_queue() call in each
133 // corresponding thread before execution.
134 //
135 // The CUDA, XLA threads are shared among all invocations of backwards via
136 // device_ready_queues_, while the caller thread is dedicated to processing work
137 // for devices returning true in should_run_in_cpu_ready_queue (most notably the
138 // CPU device). So any given graph task maintains its own cpu_ready_queue_ where
139 // you should send work for it to be done.
140 //
141 // For reentrant backward calls, if we spawn a new thread from the current
142 // thread because we reached the maximum depth, the new thread will just reuse
143 // the same ReadyQueue as the parent thread as a performance improvement.
144 // See Note [Reentrant backwards] for more details.
145 C10_DEFINE_TLS_static(std::shared_ptr<ReadyQueue>, tls_local_ready_queue);
146 #define local_ready_queue (tls_local_ready_queue.get())
147 
148 // Note [Reentrant backwards]
149 // ~~~~~~~~~~~~~~~~~~~~~~~~~~
150 // To understand the reentrant backwards problem, we have to notice two
151 // aspects of how the autograd engine is implemented today:
152 //
153 //  1. When you call Engine::execute(), you want to block until
154 //  differentiation finishes so that you can get the final result variables
155 //  of the backwards pass.
156 //
157 //  2. The engine operates by having a single worker thread per work queue,
158 //  and every work queue is pinned to a specific device where the
159 //  operation is executed.
160 //
161 // The problem is, suppose that you call backward() inside of a worker
162 // thread.  By property (1), we're supposed to block until the nested task
163 // finishes.  However, by property (2), this worker thread is on the
164 // hook for processing the tasks assigned to it; we better not block,
165 // because then all of our backward executions (including the one we
166 // just started) will deadlock!
167 //
168 // We maintain a pool of threads waiting for work to do.
169 // When a reentrant backwards call occurs, the current thread blocks
170 // and a thread from the pool is woken up to complete the blocking tasks and
171 // any other tasks that would have been assigned to that worker. If there are no
172 // threads available, a new thread is spawned. The new thread will continue
173 // processing tasks from the same ReadyQueue as the parent worker.
174 //
175 // When the GraphTask is finished, the parent worker thread that is waiting on
176 // the task is notified and the current thread returns to the pool.
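//
// A minimal sketch of how a reentrant call arises (assuming the custom
// Function API from torch/csrc/autograd/custom_function.h; MyOp is a
// hypothetical example, not part of this file):
//
//   struct MyOp : public torch::autograd::Function<MyOp> {
//     static at::Tensor forward(
//         torch::autograd::AutogradContext* ctx, const at::Tensor& x) {
//       ctx->save_for_backward({x});
//       return x * x;
//     }
//     static torch::autograd::variable_list backward(
//         torch::autograd::AutogradContext* ctx,
//         torch::autograd::variable_list grad_out) {
//       auto x = ctx->get_saved_variables()[0];
//       auto y = x.detach().requires_grad_(true);
//       // Differentiating here re-enters the engine from a worker thread,
//       // which is exactly the reentrant case described above.
//       auto g = torch::autograd::grad({y * y}, {y}, {grad_out[0]})[0];
//       return {g};
//     }
//   };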
177 
178 // Note [Streaming backwards]
179 // ~~~~~~~~~~~~~~~~~~~~~~~~~~
180 // On CUDA/privateuse1 devices the autograd engine's device operations are run
181 // on the same stream that ran them in forward. This requires automatically
182 // syncing the streams so that function A finishes producing its
183 // output before function B consumes it.
184 //
185 // This synchronization occurs when outputs are placed into input buffers.
186 // The functions corresponding to input buffer positions have metadata
187 // recording their streams from forward, and during backward this
188 // data is used to sync the producer's stream with the consumer's.
189 //
190 // When a CUDA/privateuse1 function is run, either all its inputs were
191 // accumulated on the stream used to run the function OR the inputs are on
192 // different devices and the function is responsible for properly acquiring
193 // them.
194 //
195 // User-facing stream semantics of a backward() (or torch.autograd.grad())
196 // call with respect to surrounding ops are the same as for any other call.
197 // See "Stream semantics of backward passes" on
198 // https://pytorch.org/docs/stable/notes/cuda.html
199 //
200 // Internally, backward() runs ops (including leaf nodes) on side threads.
201 // And streams are thread local. So GraphTask achieves the above semantics by
202 //  1. remembering the current streams on all active CUDA/privateuse1 devices
203 //     in the user-facing thread (aka, the thread that called execute() to
204 //     launch the GraphTask)
205 //  2. remembering the "leaf streams" (streams each backward leaf node ran on)
206 //  3. during exec_post_processing, for each leaf stream, sync the remembered
207 //     current streams (on the leaf stream's device) with that
208 //     leaf stream.
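//
// A hedged illustration of the user-facing semantics (CUDA C++ API assumed;
// `model_forward` and `input` are hypothetical, not part of this file):
//
//   auto side = c10::cuda::getStreamFromPool();
//   {
//     c10::cuda::CUDAStreamGuard guard(side);  // forward ops run on `side`
//     auto loss = model_forward(input).sum();
//     loss.backward();  // backward device ops also run on `side` (steps 1-3)
//   }
//   // After backward() returns, gradients are safe to consume on the streams
//   // that were current when backward() was called.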
209 
210 int NodeTask::getReentrantDepth() const {
211   std::shared_ptr<GraphTask> graph_task = base_.lock();
212   if (graph_task) {
213     return graph_task->reentrant_depth_;
214   } else {
215     // The graph task is no longer valid indicating an error. As a result, we
216     // try to move this to the front of the queue to ensure the autograd
217     // engine threads pick up this error soon.
218     return std::numeric_limits<int>::max();
219   }
220 }
221 
222 CheckpointValidGuard::CheckpointValidGuard(
223     const std::shared_ptr<const GraphTask>& graph_task)
224     : prev_checkpoint_valid_state(checkpoint_valid) {
225   checkpoint_valid =
226       graph_task->can_checkpoint() && prev_checkpoint_valid_state;
227 }
228 
229 CheckpointValidGuard::~CheckpointValidGuard() {
230   checkpoint_valid = prev_checkpoint_valid_state;
231 }
232 
233 auto ReadyQueue::push(NodeTask item, bool incrementOutstandingTasks) -> void {
234   {
235     // Lock mutex for writing to heap_
236     std::lock_guard<std::mutex> lock(mutex_);
237     if (incrementOutstandingTasks) {
238       std::shared_ptr<GraphTask> graph_task = item.base_.lock();
239       TORCH_INTERNAL_ASSERT(graph_task, "GraphTask is no longer valid!");
240       ++graph_task->outstanding_tasks_;
241     }
242     heap_.push(std::move(item));
243   }
244   not_empty_.notify_one();
245 }
246 
247 auto ReadyQueue::pushShutdownTask() -> void {
248   {
249     std::lock_guard<std::mutex> lock(mutex_);
250     heap_.push(NodeTask({}, nullptr, InputBuffer(0), true));
251   }
252   not_empty_.notify_one();
253 }
254 
255 size_t ReadyQueue::size() const {
256   // Lock mutex for accesses to heap_
257   std::unique_lock<std::mutex> lock(mutex_);
258   return heap_.size();
259 }
260 
261 auto ReadyQueue::pop() -> NodeTask {
262   // Lock mutex for accesses to heap_
263   std::unique_lock<std::mutex> lock(mutex_);
264   not_empty_.wait(lock, [this] { return !heap_.empty(); });
265   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
266   auto task = std::move(const_cast<NodeTask&>(heap_.top()));
267   heap_.pop();
268   return task;
269 }
270 
271 bool ReadyQueue::empty() const {
272   // Lock mutex for accesses to heap_
273   std::unique_lock<std::mutex> lock(mutex_);
274   return heap_.empty();
275 }
276 
277 Engine::Engine()
278     : max_recursion_depth_(MAX_DEPTH), non_reentrant_device_thread_count_(0) {}
279 
280 Engine::~Engine() {
281   stop();
282 }
283 
284 // Send shutdown tasks to all device_ready_queues_ if no backward tasks are
285 // running. Even though the ready queues should be empty, shutdown tasks have
286 // the highest priority.
287 void Engine::stop() {
288   if (stopped_) {
289     return;
290   }
291   stopped_ = true;
292   // Under some conditions, autograd threads can hang on shutdown.
293   // Do not wait for them to shut down indefinitely; rely on the timeout instead.
294   auto wait_duration_str = getenv("TORCH_AUTOGRAD_SHUTDOWN_WAIT_LIMIT");
295   auto wait_duration = wait_duration_str ? std::atof(wait_duration_str) : 10.0;
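  // For example (an assumption about deployment, not enforced here): launching
  // the process with TORCH_AUTOGRAD_SHUTDOWN_WAIT_LIMIT=5 caps the wait below
  // at ~5 seconds, while a value <= 0 skips pushing shutdown tasks entirely.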
296   bool noBackward = true;
297   for (auto& queue : device_ready_queues_) {
298     noBackward = noBackward && queue->empty();
299   }
300   if (noBackward && wait_duration > 0.0f) {
301     for (auto& queue : device_ready_queues_) {
302       queue->pushShutdownTask();
303     }
304     // Do not wait for termination of global threads on Windows
305     // Because CRT terminates DLL threads before calling
306     // global object destructors
307 #if !defined(_WIN32) || defined(C10_USE_MSVC_STATIC_RUNTIME)
308 
309     using namespace std::chrono_literals;
310     // Set a deadline for how long it is OK to wait for device threads to shut down
311     auto wait_deadline =
312         std::chrono::steady_clock::now() + wait_duration * 1.0s;
313     std::unique_lock<std::mutex> lk(non_reentrant_device_thread_mutex_);
314     while (non_reentrant_device_thread_count_.load() != 0) {
315       if (non_reentrant_device_thread_condvar_.wait_until(lk, wait_deadline) ==
316           std::cv_status::timeout) {
317         break;
318       }
319     }
320 #endif
321   }
322   // Otherwise threads are leaked
323 }
324 
325 void Engine::release_workers() {
326   std::unique_lock<std::mutex> lk(non_reentrant_device_thread_mutex_);
327   non_reentrant_device_thread_count_.store(0);
328   non_reentrant_device_thread_condvar_.notify_one();
329 }
330 
331 void Engine::increment_non_reentrant_thread_count() {
332   std::unique_lock<std::mutex> lk(non_reentrant_device_thread_mutex_);
333   non_reentrant_device_thread_count_.fetch_add(1);
334   non_reentrant_device_thread_condvar_.notify_one();
335 }
336 
337 void Engine::decrement_non_reentrant_thread_count() {
338   std::unique_lock<std::mutex> lk(non_reentrant_device_thread_mutex_);
339   non_reentrant_device_thread_count_.fetch_sub(1);
340   non_reentrant_device_thread_condvar_.notify_one();
341 }
342 
343 void Engine::thread_init(
344     int device,
345     const std::shared_ptr<ReadyQueue>& ready_queue,
346     bool should_increment) {
347   // pthread_setname_np restricts the name to 16 characters including
348   // the null byte.
349   std::string thread_name = "pt_autograd_" + std::to_string(device);
350   c10::setThreadName(thread_name);
351 
352   c10::set_terminate_handler();
353   if (should_increment) {
354     increment_non_reentrant_thread_count();
355   }
356 
357   at::init_num_threads();
358 
359   // Note [Allocating GPUs to autograd threads]
360   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
361   // What's our strategy here?  Originally, the autograd engine was written
362   // with only CUDA in mind.  We allocate one thread to handle all CPU
363   // operations, and a thread per CUDA device.
364   //
365   // But what if we have OTHER devices?  There are two plausible
366   // strategies:
367   //
368   //  - We can allocate threads equal to max(num_cuda_devices, num_xla_devices,
369   //    ...) and colocate cuda device 0 with xla device 0
370   //  - We can allocate threads equal to sum(num_cuda_devices, num_xla_devices,
371   //    ...) keeping everyone separate.
372   //
373   // We don't have any good reason to prefer one or the other, so we've
374   // arbitrarily picked to colocate devices.  Maybe the other approach is
375   // better.
376   worker_device = device;
377 
378   // initialize each device thread's thread local ready queue with the ready
379   // queue that is created before the thread initialization
380   init_local_ready_queue(ready_queue);
381 
382   std::shared_ptr<GraphTask> graph_task = nullptr;
383   thread_main(graph_task);
384   if (should_increment) {
385     // Decrement the count during shutdown if we incremented earlier.
386     decrement_non_reentrant_thread_count();
387   }
388 }
389 
390 GraphTaskGuard::GraphTaskGuard(std::shared_ptr<GraphTask> graph_task)
391     : last_graph_task_(std::move(current_graph_task)) {
392   current_graph_task = std::move(graph_task);
393 }
394 GraphTaskGuard::~GraphTaskGuard() {
395   restore_current_graph_task();
396 }
397 
398 void GraphTaskGuard::restore_current_graph_task() {
399   current_graph_task = std::move(last_graph_task_);
400 }
401 
402 // The current graph task's exec_info is being used to trim unnecessary edges
403 // during node evaluation; see the `Node.task_should_compute_output()` function.
404 const std::unordered_map<Node*, GraphTask::ExecInfo>*
405 get_current_graph_task_exec_info() {
406   return current_graph_task ? &current_graph_task->exec_info_ : nullptr;
407 }
408 
409 const std::unordered_set<Node*>* get_current_graph_task_nodes_in_graph() {
410   return current_graph_task ? &current_graph_task->nodes_in_graph_ : nullptr;
411 }
412 
413 int get_current_graph_task_id() {
414   return current_graph_task ? current_graph_task->id_ : -1;
415 }
416 
417 bool get_current_graph_task_keep_graph() {
418   return current_graph_task ? current_graph_task->keep_graph_ : true;
419 }
420 
421 void add_node_to_current_graph_task_exec_info(Node* fn) {
422   current_graph_task->exec_info_[fn].needed_ = true;
423 }
424 
425 // NB: The engine itself does not use the outputs of this function.
426 std::vector<Node*> get_current_graph_task_execution_order() {
427   std::shared_ptr<GraphTask> task = current_graph_task;
428   if (!task) {
429     return {};
430   }
431 
432   // We could potentially check if there is only a single device here,
433   // but explicitly requiring this context doesn't seem bad either.
434   TORCH_CHECK(
435       !c10::AutogradState::get_tls_state().get_multithreading_enabled(),
436       "get_current_graph_task_execution_order expects the current backward to be "
437       "executed with multithreading disabled, e.g. by running:\n\n"
438       ">>> with torch.autograd.set_multithreading_enabled(False):\n"
439       "...     torch.autograd.grad(...)\n");
440 
441   const bool check_exec_info = !task->exec_info_.empty();
442   std::vector<Node*> out{};
443   // Do a copy since we mutate it later
444   std::unordered_map<Node*, int> dependencies = task->dependencies_;
445 
446   auto compare_seq_nr = [](Node* n1, Node* n2) {
447     return n1->sequence_nr() < n2->sequence_nr();
448   };
449   std::priority_queue<Node*, std::vector<Node*>, decltype(compare_seq_nr)> heap(
450       compare_seq_nr);
451 
452   for (Node* ptr : task->graph_roots_) {
453     heap.push(ptr);
454   }
455 
456   // Implementation notes:
457   // - We need to count dependencies even though we have sequence_nr, because
458   //   in the accumulate_grad case we cannot assume the outputs have a higher
459   //   sequence_nr than the inputs
460   // - Don't need to check topological_nr because we have exec_info
461   while (!heap.empty()) {
462     Node* fn = heap.top();
463     heap.pop();
464 
465     out.push_back(fn);
466     for (const auto& edge : fn->next_edges()) {
467       Node* next_ptr = edge.function.get();
468       if (!next_ptr) {
469         continue;
470       }
471       if (check_exec_info) {
472         auto it = task->exec_info_.find(next_ptr);
473         if (it == task->exec_info_.end() || !it->second.should_execute()) {
474           continue;
475         }
476       }
477       auto it = dependencies.find(edge.function.get());
478       TORCH_INTERNAL_ASSERT(it != dependencies.end());
479       if (--it->second == 0) {
480         dependencies.erase(it);
481         heap.push(next_ptr);
482       }
483     }
484   }
485   return out;
486 }
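// A hedged usage sketch for the function above (assumes it is called from C++
// code running inside a backward pass, e.g. from a hook, with multithreading
// disabled as the TORCH_CHECK requires):
//
//   auto order = torch::autograd::get_current_graph_task_execution_order();
//   for (torch::autograd::Node* node : order) {
//     LOG(INFO) << node->name() << " seq=" << node->sequence_nr();
//   }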
487 
488 // NOTE: graph_tasks do not necessarily form a stack. Imagine this
489 // case:
490 //
491 //    +----> Eval1
492 //  Root
493 //    +----> Eval2
494 //
495 // Once Root is executed, both Eval1 and Eval2 are added to the ready queue.
496 // Next, Eval1 is run and this causes the worker to enter thread_main again.
497 // Then, it pops the next task from the queue, but at this point it is Eval2.
498 // It enters thread_main once again, but now with graph_task of Eval2, which is
499 // completely unrelated to that of Eval1 (it's not a recursive call).
500 // It's all ok and is handled right now, but it should be accounted for
501 // in case this code is to be changed.
502 //
503 // thread_main is used by:
504 // 1). autograd threads for devices (i.e. CUDA, XLA)
505 // 2). the caller/owning thread of the backward call on CPU (sync mode)
506 // 3). Reentrant backward invoked by either 1) or 2)
507 // The exit conditions are different for the above three cases.
508 // For 1), we keep spinning in thread_main on the device autograd threads
509 //         throughout the Engine lifetime; thread_main is terminated during
510 //         Engine destruction by pushing shutdown tasks.
511 // For 2), the owning thread of the backward call drives thread_main
512 //         synchronously until the graph_task of that owning thread is
513 //         completed, then exits thread_main to continue executing the
514 //         caller's code.
515 // For 3), the reentrant backward that invokes thread_main, either from 1) or
516 //         2), does not spin and exits once its graph_task is completed,
517 //         notifying the owning thread as needed.
519 auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void {
520   // When graph_task is nullptr, this is a long running thread that processes
521   // tasks (ex: device threads). When graph_task is non-null (ex: reentrant
522   // backwards, user thread), this function is expected to exit once that
523   // graph_task completes.
524 
525   // local_ready_queue should already have been initialized when we get into
526   // thread_main.
527   TORCH_INTERNAL_ASSERT(local_ready_queue != nullptr);
528   while (graph_task == nullptr || !graph_task->future_result_->completed()) {
529     // local_graph_task represents the graph_task we retrieve from the queue.
530     // The outer graph_task represents the overall graph_task we need to execute
531     // for reentrant execution.
532     std::shared_ptr<GraphTask> local_graph_task;
533     {
534       // Scope this block of execution since NodeTask is not needed after this
535       // block and can be deallocated (release any references to grad tensors
536       // as part of inputs_).
537       NodeTask task = local_ready_queue->pop();
538       // This will only work if the worker is running a non-backward task.
539       // TODO: This needs to be fixed to work in all cases.
540       if (task.isShutdownTask_) {
541         C10_LOG_API_USAGE_ONCE("torch.autograd.thread_shutdown");
542         break;
543       }
544 
545       local_graph_task = task.base_.lock();
546       if (!local_graph_task) {
547         // GraphTask for function is no longer valid, skipping further
548         // execution.
549         continue;
550       }
551 
552       set_device(worker_device);
553 
554       if (task.fn_ && !local_graph_task->has_error_.load()) {
555         // Set the ThreadLocalState before calling the function.
556         // NB: The ThreadLocalStateGuard doesn't set the grad_mode because
557         // GraphTask always saves ThreadLocalState without grad_mode.
558         at::ThreadLocalStateGuard tls_guard(local_graph_task->thread_locals_);
559         c10::WarningUtils::WarningHandlerGuard warnings_guard(
560             &local_graph_task->warning_handler_);
561 
562         try {
563           // The guard sets the thread_local current_graph_task on construction
564           // and restores it on exit. The current_graph_task variable helps
565           // queue_callback() to find the target GraphTask to append final
566           // callbacks.
567           GraphTaskGuard guard(local_graph_task);
568           NodeGuard ndguard(task.fn_);
569           {
570             RECORD_FUNCTION(
571                 c10::str(
572                     "autograd::engine::evaluate_function: ",
573                     task.fn_.get()->name()),
574                 c10::ArrayRef<const c10::IValue>());
575             evaluate_function(
576                 local_graph_task,
577                 task.fn_.get(),
578                 task.inputs_,
579                 local_graph_task->cpu_ready_queue_);
580           }
581         } catch (std::exception& e) {
582           // See Note [ Persisting PyErr state across autograd engine threads ]
583           thread_on_exception(local_graph_task, task.fn_, e);
584         }
585       }
586     }
587 
588     // Decrement the outstanding tasks.
589     --local_graph_task->outstanding_tasks_;
590 
591     // Check if we've completed execution.
592     if (local_graph_task->completed()) {
593       local_graph_task->mark_as_completed_and_run_post_processing();
594 
595       auto base_owner = local_graph_task->owner_;
596       // The current worker thread finishes the graph_task, but the owning thread
597       // of the graph_task might be sleeping on pop() if it does not have work.
598       // So we need to send a dummy function task to the owning thread just to
599       // ensure that it's not sleeping, so that we can exit the thread_main.
600       // If it has work, it might see that graph_task->outstanding_tasks_ == 0
601       // before it gets to the task, but it's a no-op anyway.
602       //
603       // NB: This is not necessary if the current thread is the owning thread.
604       if (worker_device != base_owner) {
605         // Synchronize outstanding_tasks_ with queue mutex
606         std::atomic_thread_fence(std::memory_order_release);
607         ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
608             ->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));
609       }
610     }
611   }
612 }
613 
614 // A reentrant call will reuse the graph_task owner thread's ready_queue for
615 // queueing tasks (NOTE: this is not true in the async_mode of the engine).
616 // We could create a separate ready queue for each new reentrant thread, but
617 // sharing the same cpu_ready_queue with the parent thread is a performance
618 // improvement, and the CUDA threads still have to do the same thing.
619 void Engine::reentrant_thread_init() {
620   c10::set_terminate_handler();
621   at::init_num_threads();
622   auto tp_shared = thread_pool_shared_;
623   while (true) {
624     std::unique_lock<std::mutex> lk(tp_shared->mutex_);
625     ++thread_pool_shared_->num_workers_;
626     tp_shared->work_.wait(
627         lk, [&tp_shared] { return !tp_shared->graphtasks_queue_.empty(); });
628     --thread_pool_shared_->num_workers_;
629     auto task = tp_shared->graphtasks_queue_.front();
630     tp_shared->graphtasks_queue_.pop();
631     lk.unlock();
632     std::shared_ptr<GraphTask> graph_task = task.lock();
633     if (!graph_task) {
634       LOG(INFO) << "GraphTask has expired, skipping reentrant execution";
635       continue;
636     }
637     set_device(graph_task->owner_);
638     // set the local_ready_queue to the ready queue on the graph_task->owner_
639     // device
640     local_ready_queue =
641         ready_queue_by_index(graph_task->cpu_ready_queue_, graph_task->owner_);
642     total_depth = graph_task->reentrant_depth_;
643     thread_main(graph_task);
644   }
645 }
646 
647 void Engine::thread_on_exception(
648     const std::shared_ptr<GraphTask>& graph_task,
649     const std::shared_ptr<Node>& fn,
650     std::exception& e) {
651   graph_task->set_exception(std::current_exception(), fn);
652 }
653 
654 namespace {
655 std::atomic<uint64_t> graph_task_id{0};
656 }
657 
658 GraphTask::GraphTask(
659     bool keep_graph,
660     bool grad_mode,
661     int reentrant_depth,
662     std::shared_ptr<ReadyQueue> cpu_ready_queue,
663     c10::SmallVector<Node*, 4> graph_roots,
664     bool exit_on_error)
665     : keep_graph_(keep_graph),
666       graph_roots_(std::move(graph_roots)),
667       owner_(NO_DEVICE),
668       reentrant_depth_(reentrant_depth),
669       exit_on_error_(exit_on_error),
670       cpu_ready_queue_(std::move(cpu_ready_queue)),
671       future_result_(c10::make_intrusive<at::ivalue::Future>(
672           c10::ListType::create(c10::TensorType::get()))),
673       id_(graph_task_id.fetch_add(1, std::memory_order_relaxed)) {
674   thread_locals_.set_grad_mode(grad_mode);
675 }
676 
677 bool GraphTask::completed() {
678   return outstanding_tasks_.load() == 0 ||
679       (exit_on_error_ && has_error_.load());
680 }
681 
682 void GraphTask::mark_as_completed_and_run_post_processing() {
683   // Allow only one thread one attempt to process this logic.
684   if (future_completed_.exchange(true)) {
685     // Future is already marked complete, or being marked as such.
686     // In case the marking complete is only in progress, we add a
687     // wait() to guarantee the future is marked complete on exit.
688     future_result_->wait();
689     return;
690   }
691 
692   try {
693     // Run post processing, before marking the future as complete.
694     // Drop lock prior to completing, to avoid holding across callbacks.
695     std::unique_lock<std::mutex> lock(mutex_);
696 
697     exec_post_processing();
698     std::vector<Variable> vars = std::move(captured_vars_);
699 
700     // Need to unlock before we call markCompleted to avoid holding locks
701     // when the callbacks are called.
702     lock.unlock();
703     future_result_->markCompleted(vars);
704   } catch (std::exception&) {
705     future_result_->setErrorIfNeeded(std::current_exception());
706   }
707 }
708 
709 void GraphTask::exec_post_processing() {
710   if (!not_ready_.empty()) {
711     throw std::runtime_error("could not compute gradients for some functions");
712   }
713 
714   // set the thread_local current_graph_task_ as more callbacks can be installed
715   // by existing final callbacks.
716   GraphTaskGuard guard(shared_from_this());
717   // Lock mutex during each iteration for accessing final_callbacks.size()
718   // Unlocking is necessary, because the callback can register
719   // more callbacks (or they can be registered from other threads
720   // while it's waiting.
721   std::unique_lock<std::mutex> cb_lock(final_callbacks_lock_);
722 
723   // caller_current_streams_ with nullopt entries removed
724   std::vector<c10::Stream> caller_current_streams_filtered;
725 
726   // See Note [Streaming backwards].
727   // Syncs caller_current_stream with leaf streams, so final_callbacks may use
728   // any grad on its device's current stream.
729   if (!leaf_streams.empty()) {
730     for (const auto& leaf_stream : leaf_streams) {
731       // stash_current_cuda/privateuse1_streams() stashed streams for all device
732       // IDs that already had a CUDA/privateuse1 context before the GraphTask
733       // executed. For inactive devices, it stashed a std::nullopt. I don't
734       // expect GraphTask's backward pass ran leaf nodes on any new devices, so
735       // the stashed streams should be enough. If leaf_stream.device_index()
736       // happens to be for a new device, operator* on the std::nullopt should
737       // throw an error.
738       const auto caller_current_stream =
739           // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
740           *caller_current_streams_[leaf_stream.device_index()];
741 
742       if (caller_current_stream != leaf_stream) {
743         auto event = c10::Event{leaf_stream.device_type()};
744         event.record(leaf_stream);
745         caller_current_stream.wait(event);
746       }
747     }
748 
749     caller_current_streams_filtered.reserve(caller_current_streams_.size());
750     for (const auto& opt_stream : caller_current_streams_) {
751       if (opt_stream.has_value()) {
752         caller_current_streams_filtered.push_back(*opt_stream);
753       }
754     }
755   }
756 
757   {
758     // final_callbacks run on the per-device caller_current_streams (the ambient
759     // streams surrounding the user's call to backward()). This has two
760     // benefits:
761     //  1. caller_current_streams have been synced with leaf_streams, so
762     //     callbacks may safely access any grad.
763     //  2. The callback's results can safely be used on (user-facing)
764     //     caller_current_streams after backward().
767     c10::MultiStreamGuard g(caller_current_streams_filtered);
768 
769     // Set the ThreadLocalState before calling the function.
770     // NB: The ThreadLocalStateGuard doesn't set the grad_mode because GraphTask
771     // always saves ThreadLocalState without grad_mode.
772     at::ThreadLocalStateGuard tls_guard(this->thread_locals_);
773 
774     // WARNING: Don't use a range-for loop here because more callbacks may be
775     // added in between callback calls, so iterators may become invalidated.
776     // NOLINTNEXTLINE(modernize-loop-convert)
777     for (size_t i = 0; i < final_callbacks_.size(); ++i) {
778       cb_lock.unlock();
779       final_callbacks_[i]();
780       cb_lock.lock();
781     }
782   }
783 }
784 
785 void GraphTask::set_exception_without_signal(const std::shared_ptr<Node>& fn) {
786   if (!has_error_.exchange(true)) {
787     if (AnomalyMode::is_enabled() && fn) {
788       fn->metadata()->print_stack(fn->name());
789     }
790   }
791 }
792 
793 void GraphTask::set_exception(
794     std::exception_ptr eptr,
795     const std::shared_ptr<Node>& fn) {
796   set_exception_without_signal(fn);
797   if (!future_completed_.exchange(true)) {
798     future_result_->setError(std::move(eptr));
799   }
800 }
801 
802 static variable_list call_pre_hooks(Node& fn, variable_list inputs) {
803   for (const auto& hook : fn.pre_hooks()) {
804     inputs = (*hook)(inputs);
805   }
806   return inputs;
807 }
808 
809 static variable_list call_tensor_pre_hooks(Node& fn, variable_list inputs) {
810   for (const auto& hook : fn.tensor_pre_hooks()) {
811     inputs = (*hook)(inputs);
812   }
813   for (const auto& pair : fn.retains_grad_hooks()) {
814     inputs = (*pair.second)(inputs);
815   }
816   return inputs;
817 }
818 
819 static variable_list call_post_hooks(
820     Node& fn,
821     variable_list outputs,
822     const variable_list& inputs,
823     const bool had_post_hooks) {
824   for (const auto& hook : fn.post_hooks()) {
825     if (had_post_hooks) {
826       outputs = (*hook)(outputs, inputs);
827     } else {
828       variable_list null_inputs;
829       outputs = (*hook)(outputs, null_inputs);
830     }
831   }
832   return outputs;
833 }
834 
835 void set_device(int device) {
836   // NB: We MUST NOT construct the guard for device CPU,
837   // as in some settings we compile with cuda, but
838   // have lazy stubs for CUDA functionality (so actually
839   // attempting to setup a guard(CPU_DEVICE) will cause an
840   // error, because it will still query GetDevice).
841   //
842   // Don't use DeviceGuard here because its destructor may be called before the
843   // device is reset. This is fine because the device is thread local.
844   if (device != CPU_DEVICE) {
845     for (const auto i : c10::irange(static_cast<size_t>(
846              c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES))) {
847       auto* impl = c10::impl::device_guard_impl_registry[i].load();
848       if (impl && device < impl->deviceCount()) {
849         impl->setDevice(at::Device(
850             static_cast<c10::DeviceType>(i),
851             static_cast<c10::DeviceIndex>(device)));
852       }
853     }
854   }
855   worker_device = device;
856 }
857 
858 void validate_outputs(
859     const edge_list& edges,
860     variable_list& grads,
861     const std::function<std::string(const std::string&)>& format_error) {
862   if (grads.size() != edges.size()) {
863     std::stringstream ss;
864     ss << "invalid number of gradients - expected ";
865     ss << edges.size() << ", but got " << grads.size();
866     TORCH_CHECK(false, format_error(ss.str()));
867   }
868   for (const auto i : c10::irange(grads.size())) {
869     const auto& edge = edges[i];
870     if (!edge.is_valid())
871       continue;
872 
873     const auto& metadata = edge.function->input_metadata(edge.input_nr);
874     auto& grad = grads[i];
875     if (!grad.defined()) {
876       // FIXME: TestJit.test_ge_optimized fails this assertion.
877       // std::stringstream ss;
878       // ss << "undefined gradient at index " << i;
879       // TORCH_CHECK(false, format_error(ss.str()));
880       continue;
881     }
882 
883     grad = metadata.maybe_reduce(i, std::move(grad), format_error);
884 
885     bool input_is_complex =
886         isComplexType(c10::typeMetaToScalarType(metadata.options().dtype()));
887     bool grad_is_complex = isComplexType(grad.scalar_type());
888 
889     TORCH_CHECK(
890         isFloatingType(grad.scalar_type()) ||
891         (input_is_complex == grad_is_complex));
892     if (c10::typeMetaToScalarType(metadata.options().dtype()) !=
893         grad.scalar_type()) {
894       grad = grad.to(c10::typeMetaToScalarType(metadata.options().dtype()));
895     }
896     if (grad.dtype() != metadata.dtype()) {
897       std::stringstream ss;
898       ss << "invalid gradient at index " << i << " - expected dtype ";
899       ss << metadata.dtype() << " but got " << grad.dtype();
900       TORCH_CHECK(false, format_error(ss.str()));
901     }
902     if (grad.layout() != metadata.layout()) {
903       // TODO: Currently we only support the (*, Sparse) combination for
904       // (tensor.layout(), tensor.grad.layout()). In the future, there will be
905       // an opportunity to support more combinations of layouts if they are
906       // composable (e.g., operations like addition are well defined between
907       // tensors of different layouts) and all parts of autograd, like
908       // AccumulateGrad, correctly handle this. We allow grad to be Strided
909       // when metadata is SparseCsr.
910       if (!grad.is_sparse() &&
911           !(grad.layout() == at::kStrided &&
912             (at::sparse_csr::is_sparse_compressed(metadata.layout()) ||
913              metadata.layout() == at::kSparse))) {
914         std::stringstream ss;
915         ss << "invalid gradient at index " << i << " - expected layout ";
916         ss << metadata.layout() << " but got " << grad.layout();
917         TORCH_CHECK(false, format_error(ss.str()));
918       }
919     }
920 
921     if (grad.device() != metadata.device()) {
922       // quick hack for: https://github.com/pytorch/pytorch/issues/65016 but
923       // should be eventually removed
924       if (!(metadata.is_tensor_subclass() ||
925             grad.unsafeGetTensorImpl()->is_python_dispatch())) {
926         if (grad.dim() == 0) {
927           grad = grad.to(metadata.device());
928         } else {
929           std::stringstream ss;
930           ss << "invalid gradient at index " << i << " - expected device ";
931           ss << metadata.device() << " but got " << grad.device();
932           TORCH_CHECK(false, format_error(ss.str()));
933         }
934       }
935     }
936     // We should not build graph for Tensors that are not differentiable
937     TORCH_INTERNAL_ASSERT(isDifferentiableType(grad.scalar_type()));
938   }
939 }
940 
941 static variable_list call_function(
942     std::shared_ptr<GraphTask>& graph_task,
943     Node* func,
944     InputBuffer& inputBuffer) {
945   CheckpointValidGuard cpvguard(graph_task);
946   auto& fn = *func;
947   auto inputs =
948       call_tensor_pre_hooks(fn, InputBuffer::variables(std::move(inputBuffer)));
949   inputs = call_pre_hooks(fn, std::move(inputs));
950   if (!graph_task->keep_graph_) {
951     fn.will_release_variables();
952   }
953 
954   const auto has_post_hooks = !fn.post_hooks().empty();
955   variable_list outputs;
956 
957   if (has_post_hooks) {
958     // In functions/accumulate_grad.cpp, there is some logic to check the
959     // conditions under which the incoming gradient can be stolen directly
960     // (which elides a deep copy) instead of cloned. One of these conditions
961     // is that the incoming gradient's refcount must be 1 (nothing else is
962     // referencing the same data).  Stashing inputs_copy here bumps the
963     // refcount, so if post hooks are employed, it's actually still ok for
964     // accumulate_grad.cpp to steal the gradient if the refcount is 2.
965     //
966     // "new_grad.use_count() <= 1 + !post_hooks().empty()" in
967     // accumulate_grad.cpp accounts for this, but also creates a silent
968     // dependency between engine.cpp (ie, this particular engine
969     // implementation) and accumulate_grad.cpp.
970     //
971     // If you change the logic here, make sure it's compatible with
972     // accumulate_grad.cpp.
973     auto inputs_copy = inputs;
974     outputs = fn(std::move(inputs_copy));
975   } else {
976     outputs = fn(std::move(inputs));
977   }
978 
979   validate_outputs(fn.next_edges(), outputs, [&](const std::string& msg) {
980     std::ostringstream ss;
981     ss << "Function " << fn.name() << " returned an " << msg;
982     return ss.str();
983   });
984 
985   // NOLINTNEXTLINE(bugprone-use-after-move)
986   return call_post_hooks(fn, std::move(outputs), inputs, has_post_hooks);
987 }
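// A small worked example of the refcount reasoning above (illustrative, not a
// separate code path): with one post hook installed, `inputs_copy` holds a
// second reference to the incoming grad, so inside AccumulateGrad the tensor's
// use_count() is 2 and the check `new_grad.use_count() <= 1 +
// !post_hooks().empty()` evaluates as 2 <= 2, still allowing the grad to be
// stolen rather than cloned.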
988 
989 void Engine::evaluate_function(
990     std::shared_ptr<GraphTask>& graph_task,
991     Node* func,
992     InputBuffer& inputs,
993     const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
994   // The InputBuffer::adds that supplied incoming grads took pains to
995   // ensure they're safe to consume in the context of the present
996   // func's stream (if applicable). So we guard onto that stream
997   // before working with the grads in any capacity.
998   auto opt_parent_stream = (*func).stream();
999   c10::OptionalStreamGuard parent_stream_guard{opt_parent_stream};
1000 
1001   // If exec_info_ is not empty, we have to instrument the execution
1002   auto& exec_info_ = graph_task->exec_info_;
1003   if (!exec_info_.empty()) {
1004     auto& fn_info = exec_info_.at(func);
1005     variable_list new_inputs = inputs.buffer;
1006     if (!fn_info.needed_) {
1007       // We always want to call tensor pre-hooks, but want to avoid calling
1008       // them twice. needed_ = True indicates that we will call tensor
1009       // pre-hooks later.
1010       //
1011       // See NOTE [Hooks ordering] for more context.
1012       new_inputs = call_tensor_pre_hooks(
1013           *func, InputBuffer::variables(std::move(inputs)));
1014     }
1015     if (auto* capture_vec = fn_info.captures_.get()) {
1016       auto opt_parent_stream = (*func).stream();
1017       // Lock mutex for writing to graph_task->captured_vars_.
1018       std::lock_guard<std::mutex> lock(graph_task->mutex_);
1019       for (const auto& capture : *capture_vec) {
1020         auto& captured_grad = graph_task->captured_vars_[capture.output_idx_];
1021         captured_grad = new_inputs[capture.input_idx_];
1022         // NOTE [Deprecated capture hooks]
1023         for (const auto& hook :
1024              capture.DO_NOT_USE_DEPRECATED_get_capture_hooks()) {
1025           captured_grad = (*hook)(captured_grad);
1026         }
1027         if (opt_parent_stream) {
1028           // No need to take graph_task->mutex_ here, we already hold it
1029           graph_task->leaf_streams.emplace(*opt_parent_stream);
1030         }
1031       }
1032     }
1033     if (!fn_info.needed_) {
1034       // Skip execution if we don't need to execute the function.
1035       return;
1036     }
1037   }
1038 
1039   auto outputs = call_function(graph_task, func, inputs);
1040 
1041   auto& fn = *func;
1042   if (!graph_task->keep_graph_) {
1043     fn.release_variables();
1044   }
1045 
1046   auto num_outputs = outputs.size();
1047   if (num_outputs == 0) { // Note: doesn't acquire the mutex
1048     // Records leaf stream (if applicable)
1049     // See Note [Streaming backwards]
1050     if (opt_parent_stream) {
1051       std::lock_guard<std::mutex> lock(graph_task->mutex_);
1052       graph_task->leaf_streams.emplace(*opt_parent_stream);
1053     }
1054     return;
1055   }
1056 
1057   if (AnomalyMode::is_enabled() && AnomalyMode::should_check_nan()) {
1058     AutoGradMode grad_mode(false);
1059     for (const auto i : c10::irange(num_outputs)) {
1060       auto& output = outputs[i];
1061       at::OptionalDeviceGuard guard(device_of(output));
1062       if (output.defined() && isnan(output)._is_any_true().item<bool>()) {
1063         std::stringstream ss;
1064         ss << "Function '" << fn.name() << "' returned nan values in its " << i
1065            << "th output.";
1066         throw std::runtime_error(ss.str());
1067       }
1068     }
1069   }
1070 
1071   // Lock mutex for the accesses to GraphTask dependencies_, not_ready_ and
1072   // cpu_ready_queue_ below
1073   std::lock_guard<std::mutex> lock(graph_task->mutex_);
1074   for (const auto i : c10::irange(num_outputs)) {
1075     auto& output = outputs[i];
1076     const auto& next = fn.next_edge(i);
1077 
1078     if (!next.is_valid())
1079       continue;
1080 
1081     // Check if the next function is ready to be computed
1082     bool is_ready = false;
1083     auto& dependencies = graph_task->dependencies_;
1084     auto it = dependencies.find(next.function.get());
1085 
1086     if (it == dependencies.end()) {
1087       auto name = next.function->name();
1088       throw std::runtime_error(std::string("dependency not found for ") + name);
1089     } else if (--it->second == 0) {
1090       dependencies.erase(it);
1091       is_ready = true;
1092     }
1093 
1094     auto& not_ready = graph_task->not_ready_;
1095     auto not_ready_it = not_ready.find(next.function.get());
1096     if (not_ready_it == not_ready.end()) {
1097       // Skip functions that aren't supposed to be executed
1098       if (!exec_info_.empty()) {
1099         auto it = exec_info_.find(next.function.get());
1100         if (it == exec_info_.end() || !it->second.should_execute()) {
1101           continue;
1102         }
1103       }
1104       // No buffers have been allocated for the function
1105       InputBuffer input_buffer(next.function->num_inputs());
1106 
1107       // Accumulates into buffer
1108       auto opt_next_stream = next.function->stream();
1109       input_buffer.add(
1110           next.input_nr, std::move(output), opt_parent_stream, opt_next_stream);
1111 
1112       if (is_ready) {
1113         auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
1114         queue->push(
1115             NodeTask(graph_task, next.function, std::move(input_buffer)));
1116       } else {
1117         not_ready.emplace(next.function.get(), std::move(input_buffer));
1118       }
1119     } else {
1120       // The function already has a buffer
1121       auto& input_buffer = not_ready_it->second;
1122 
1123       // Accumulates into buffer
1124       auto opt_next_stream = next.function->stream();
1125       input_buffer.add(
1126           next.input_nr, std::move(output), opt_parent_stream, opt_next_stream);
1127       if (is_ready) {
1128         auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
1129         queue->push(
1130             NodeTask(graph_task, next.function, std::move(input_buffer)));
1131         not_ready.erase(not_ready_it);
1132       }
1133     }
1134   }
1135 }
1136 
1137 inline static uint64_t compute_min_topological_nr(const edge_list& outputs) {
1138   // Computes the minimum topological number among all the outputs
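  // (For example, if the output edges point at functions with topological
  // numbers 7, 3, and 9, this returns 3; compute_dependencies can then skip
  // any node whose topological_nr is below 3.)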
1139   if (outputs.empty()) {
1140     return 0;
1141   }
1142   auto min_topo_nr = std::numeric_limits<uint64_t>::max();
1143   for (auto& output_edge : outputs) {
1144     auto topo_nr = output_edge.function->topological_nr();
1145     min_topo_nr = (min_topo_nr < topo_nr) ? min_topo_nr : topo_nr;
1146   }
1147   return min_topo_nr;
1148 }
1149 
1150 auto Engine::compute_dependencies(
1151     Node* root,
1152     GraphTask& task,
1153     uint64_t min_topo_nr) -> void {
1154   // Computes the number of dependencies for each function which requires grad
1155   std::vector<Node*> queue{root};
1156   bool will_use_accelerator = false;
1157 
1158   // Queue contains all nodes that will start propagating gradients.
1159   // We no longer have to expand functions that don't require grad.
1160   auto& dependencies = task.dependencies_;
1161   while (!queue.empty()) {
1162     auto fn = queue.back();
1163     queue.pop_back();
1164     if (fn->topological_nr() < min_topo_nr) {
1165       continue;
1166     }
1167     if (!will_use_accelerator) {
1168       will_use_accelerator = fn->stream().has_value();
1169     }
1170     for (const auto& edge : fn->next_edges()) {
1171       if (auto next_ptr = edge.function.get()) {
1172         dependencies[next_ptr] += 1;
1173         const bool was_inserted = task.nodes_in_graph_.insert(next_ptr).second;
1174         if (was_inserted)
1175           queue.push_back(next_ptr);
1176       }
1177     }
1178   }
1179 
1180   if (will_use_accelerator) {
1181     // Collects current streams for devices where this process has a
1182     // context, so GraphTask::exec_post_processing can sync them with
1183     // leaf_streams.
1184     task.stash_current_streams();
1185   }
1186 }
1187 
1188 auto Engine::execute(
1189     const edge_list& root_edges,
1190     const variable_list& inputs,
1191     bool keep_graph,
1192     bool create_graph,
1193     bool accumulate_grad,
1194     const edge_list& outputs) -> variable_list {
1195   validate_outputs(
1196       root_edges,
1197       // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
1198       const_cast<variable_list&>(inputs),
1199       [](const std::string& msg) { return msg; });
1200   if (accumulate_grad && create_graph) {
1201     TORCH_WARN_ONCE(
1202         "Using backward() with create_graph=True will create a reference cycle "
1203         "between the parameter and its gradient which can cause a memory leak. "
1204         "We recommend using autograd.grad when creating the graph to avoid this. "
1205         "If you have to use this function, make sure to reset the .grad fields of "
1206         "your parameters to None after use to break the cycle and avoid the leak.");
1207   }
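  // A hedged sketch of the recommended alternative from the warning above
  // (the C++ counterpart of autograd.grad; `loss` and `param` are hypothetical
  // tensors, not part of this function):
  //
  //   auto grads = torch::autograd::grad(
  //       /*outputs=*/{loss},
  //       /*inputs=*/{param},
  //       /*grad_outputs=*/{},
  //       /*retain_graph=*/std::nullopt,
  //       /*create_graph=*/true);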
1208 
1209   // Allows us to assert no other threads are in backwards
1210   CompiledAutogradThreadingDebugCheck _thread_check;
1211   auto compiled_autograd = the_compiled_autograd.load();
1212   TORCH_INTERNAL_ASSERT(compiled_autograd != COMPILED_AUTOGRAD_POISON);
1213 
1214   // accumulate_grad is true if and only if the frontend call was to
1215   // backward(), not grad(). grad() returns the sum of the gradients
1216   // w.r.t. the inputs and thus needs the inputs to be present.
1217   TORCH_CHECK_VALUE(
1218       accumulate_grad || !outputs.empty(), "grad requires non-empty inputs.");
1219 
1220   // A fresh, first-time Engine::execute call should start on the CPU device,
1221   // initialize a new thread local ready queue on CPU or reuse the existing one
1222   // (if there is one allocated already, i.e. consecutive backward calls,
1223   // re-entrant backward calls), then memoize the local_ready_queue in GraphTask
1224   init_local_ready_queue();
1225   bool not_reentrant_backward_call = worker_device == NO_DEVICE;
1226 
1227   // Store root nodes so we can traverse through the graph later
1228   // e.g., for get_current_graph_task_execution_order
1229   c10::SmallVector<Node*, 4> temp_roots{root_edges.size()};
1230   for (const auto i : c10::irange(root_edges.size())) {
1231     temp_roots[i] = root_edges[i].function.get();
1232   }
1233 
1234   auto graph_task = std::make_shared<GraphTask>(
1235       /* keep_graph */ keep_graph,
1236       /* create_graph */ create_graph,
1237       /* depth */ not_reentrant_backward_call ? 0 : total_depth + 1,
1238       /* cpu_ready_queue */ local_ready_queue,
1239       /* graph_roots */ std::move(temp_roots));
1240 
1241   // If we receive a single root, skip creating an extra root node
1242   bool skip_dummy_node = root_edges.size() == 1 && compiled_autograd == nullptr;
1243   auto graph_root = skip_dummy_node
1244       ? root_edges.at(0).function
1245       : std::make_shared<GraphRoot>(root_edges, inputs);
1246 
1247   auto min_topo_nr = compute_min_topological_nr(outputs);
1248   // Now compute the dependencies for all executable functions
1249   compute_dependencies(graph_root.get(), *graph_task, min_topo_nr);
1250 
1251   if (!outputs.empty()) {
1252     graph_task->init_to_execute(
1253         *graph_root, outputs, accumulate_grad, min_topo_nr);
1254   }
1255 
1256   if (compiled_autograd != nullptr) {
1257     // see [Note: Compiled Autograd]
1258     TORCH_CHECK(
1259         !create_graph, "compiled_autograd does not support create_graph");
1260     _thread_check.release();
1261     TORCH_CHECK(
1262         !AnomalyMode::is_enabled(),
1263         "compiled_autograd does not support AnomalyMode")
1264     return (*compiled_autograd)(
1265         graph_root, *graph_task, accumulate_grad, outputs);
1266   }
1267 
1268   // Queue the root
1269   if (skip_dummy_node) {
1270     InputBuffer input_buffer(root_edges.at(0).function->num_inputs());
1271     auto input = inputs.at(0);
1272 
1273     const auto input_stream = InputMetadata(input).stream();
1274     auto opt_next_stream = root_edges.at(0).function->stream();
1275     input_buffer.add(
1276         root_edges.at(0).input_nr,
1277         std::move(input),
1278         input_stream,
1279         opt_next_stream);
1280 
1281     execute_with_graph_task(
1282         graph_task, std::move(graph_root), std::move(input_buffer));
1283   } else {
1284     execute_with_graph_task(
1285         graph_task, std::move(graph_root), InputBuffer(variable_list()));
1286   }
1287   // Avoid a refcount bump for the Future, since we check for refcount in
1288   // DistEngine (see TORCH_INTERNAL_ASSERT(futureGrads.use_count() == 1)
1289   // in dist_engine.cpp).
1290   auto& fut = graph_task->future_result_;
1291   fut->wait();
1292   graph_task->warning_handler_.replay_warnings();
1293   return fut->value().toTensorVector();
1294 }
1295 
1296 void Engine::initialize_device_threads_pool() {
1297   TORCH_CHECK(
1298       !in_bad_autograd_fork,
1299       "Unable to handle autograd's threading in combination with fork-based multiprocessing. "
1300       "See https://github.com/pytorch/pytorch/wiki/Autograd-and-Fork");
1301   c10::call_once(
1302       start_device_threads_flag_, &Engine::start_device_threads, this);
1303 }
1304 
1305 c10::intrusive_ptr<at::ivalue::Future> Engine::execute_with_graph_task(
1306     const std::shared_ptr<GraphTask>& graph_task,
1307     std::shared_ptr<Node> graph_root,
1308     InputBuffer&& input_buffer) {
1309   initialize_device_threads_pool();
1310   // Lock mutex for GraphTask.
1311   std::unique_lock<std::mutex> lock(graph_task->mutex_);
1312 
1313   auto queue = ready_queue(graph_task->cpu_ready_queue_, input_buffer.device());
1314 
1315   // worker_device == NO_DEVICE means a CPU thread is trying to drive the
1316   // autograd engine with the corresponding GraphTask, and it is NOT a re-entrant call
1317   if (worker_device == NO_DEVICE) {
1318     // We set the worker_device to CPU_DEVICE only if worker_device was
1319     // previously NO_DEVICE. Setting it to CPU afterwards allows us to detect
1320     // whether this is a re-entrant call or not.
1321     set_device(CPU_DEVICE);
1322 
1323     // set the graph_task owner to the current device
1324     graph_task->owner_ = worker_device;
1325 
1326     // Now that all the non-thread safe fields of the graph_task have been
1327     // populated, we can enqueue it.
1328     queue->push(
1329         NodeTask(graph_task, std::move(graph_root), std::move(input_buffer)));
1330 
1331     // The owning thread starts to drive the engine execution for any CPU task
1332     // that was just pushed or will be added later by other worker threads.
1333     lock.unlock();
1334     thread_main(graph_task);
1335     TORCH_INTERNAL_ASSERT(graph_task->future_result_->completed());
1336     // Reset the worker_device after the completion of the graph_task so that
1337     // the initial state of the engine remains the same across every
1338     // backward() or grad() call. We don't need to reset local_ready_queue, as
1339     // we may reuse it for new backward calls.
1340     worker_device = NO_DEVICE;
1341   } else {
1342     // If worker_device is any device (e.g. CPU, CUDA): this is a re-entrant
1343     //    backward call from that device.
1344     graph_task->owner_ = worker_device;
1345 
1346     // Now that all the non-thread safe fields of the graph_task have been
1347     // populated, we can enqueue it.
1348     queue->push(
1349         NodeTask(graph_task, std::move(graph_root), std::move(input_buffer)));
1350 
1351     if (current_depth >= max_recursion_depth_) {
1352       // See Note [Reentrant backwards]
1353       // If reached the max depth, switch to a different thread
1354       add_thread_pool_task(graph_task);
1355     } else {
1356       // Total depth needs to be updated only in this codepath, since it is
1357       // not used in the block above (when we call add_thread_pool_task).
1358       // In the codepath above, GraphTask.reentrant_depth_ is used to
1359       // bootstrap total_depth in the other thread.
1360       ++total_depth;
1361 
1362       // Get back to work while we wait for our new graph_task to
1363       // complete!
1364       ++current_depth;
1365       lock.unlock();
1366       thread_main(graph_task);
1367       --current_depth;
1368       --total_depth;
1369 
1370       // The graph task should have completed, and the associated future should
1371       // be marked completed as well, since 'thread_main' above is a blocking
1372       // call on an autograd engine thread.
1373       TORCH_INTERNAL_ASSERT(graph_task->future_result_->completed());
1374     }
1375   }
1376   // graph_task_exec_post_processing is done when the Future is marked as
1377   // completed in mark_as_completed_and_run_post_processing.
1378   return graph_task->future_result_;
1379 }
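// The else-branch above is the re-entrant case: backward was started from a
// thread that is already running a backward pass, e.g. a custom autograd
// Function whose backward() itself calls .backward() or torch::autograd::grad().
// Each nested call bumps current_depth, and once it exceeds
// max_recursion_depth_ the nested graph task is handed to the re-entrant
// thread pool instead of growing this thread's stack further.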
1380 
1381 // Note that when Python is present, this base engine will be overridden
1382 // with a PythonEngine. Because this typically happens before get_default_engine
1383 // is called, this base engine will never be created.
1384 Engine& Engine::get_base_engine() {
1385   static Engine engine;
1386   return engine;
1387 }
1388 
1389 std::atomic<EngineStub> engine_stub(Engine::get_base_engine);
1390 
1391 void set_default_engine_stub(EngineStub stub) {
1392   engine_stub.store(stub);
1393 }
1394 
1395 Engine& Engine::get_default_engine() {
1396   return engine_stub.load()();
1397 }
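// The stub indirection is what lets the Python bindings substitute their
// PythonEngine: they call set_default_engine_stub() during initialization, so
// later calls like Engine::get_default_engine().execute(...) resolve to that
// engine rather than to the base engine above.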
1398 
1399 void Engine::set_compiled_autograd(Engine::compiled_autograd_fn fn) {
1400   if (the_compiled_autograd.load() == fn) {
1401     return;
1402   }
1403   auto prior = the_compiled_autograd.exchange(COMPILED_AUTOGRAD_POISON);
1404   TORCH_CHECK(
1405       num_threads_in_backwards.load() == 0 && prior != COMPILED_AUTOGRAD_POISON,
1406       "compiled_autograd.enable() requires no threads in backwards()");
1407   the_compiled_autograd.store(fn);
1408 }
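// The exchange-to-POISON / check / store sequence above makes toggling
// compiled autograd fail loudly rather than race: it errors if any backward
// pass is in flight (num_threads_in_backwards != 0) or if another thread is
// concurrently inside this function.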
1409 
1410 void Engine::queue_callback(std::function<void()> callback) {
1411   TORCH_CHECK(
1412       current_graph_task,
1413       "Final callbacks can only be installed during backward pass.");
1414 
1415   std::lock_guard<std::mutex> lock(current_graph_task->final_callbacks_lock_);
1416   current_graph_task->final_callbacks_.emplace_back(std::move(callback));
1417 }
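// A hedged usage sketch (hypothetical caller code, not part of this file):
// from code that runs while a backward pass is active, such as a hook or a
// custom Node, a callback can be deferred until the whole graph task is done:
//
//   Engine::get_default_engine().queue_callback([] {
//     // runs once the current backward pass has finished executing nodes
//   });
//
// Components such as the DDP reducer rely on this mechanism to schedule
// post-backward work.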
1418 
1419 bool Engine::is_checkpoint_valid() {
1420   return checkpoint_valid;
1421 }
1422 
1423 void Engine::init_local_ready_queue(std::shared_ptr<ReadyQueue> ready_queue) {
1424   if (ready_queue) {
1425     // if a ready_queue is provided by the caller, use it to initialize
1426     // local_ready_queue
1427     local_ready_queue = std::move(ready_queue);
1428   } else if (!local_ready_queue) {
1429     // otherwise, if local_ready_queue is not yet allocated, allocate a new ready_queue
1430     local_ready_queue = std::make_shared<ReadyQueue>();
1431   }
1432 }
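// Device worker threads pass their own device queue to this function when
// they start up, while the Engine::execute path calls it with no argument,
// creating (or reusing) a per-thread CPU ready queue.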
1433 
1434 // CPU ready queue is per GraphTask, but CUDA device ready queues are shared
1435 // across all graph tasks
1436 auto Engine::ready_queue(
1437     std::shared_ptr<ReadyQueue> cpu_ready_queue,
1438     at::Device device) -> std::shared_ptr<ReadyQueue> {
1439   bool multithreading_disabled =
1440       !c10::AutogradState::get_tls_state().get_multithreading_enabled();
1441   if (multithreading_disabled || should_run_in_cpu_ready_queue(device.type())) {
1442     // return the cpu ready queue passed in
1443     TORCH_INTERNAL_ASSERT(cpu_ready_queue);
1444     return cpu_ready_queue;
1445   } else {
1446     TORCH_INTERNAL_ASSERT(
1447         0 <= device.index() &&
1448         device.index() <
1449             static_cast<c10::DeviceIndex>(device_ready_queues_.size()));
1450     // See Note [Allocating GPUs to autograd threads]
1451     return device_ready_queues_.at(device.index());
1452   }
1453 }
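// Routing example: with multithreading enabled, work for a node whose input
// buffer lives on cuda:1 goes to device_ready_queues_[1] (shared across graph
// tasks), whereas CPU/meta/lazy work, and all work when multithreading is
// disabled, goes to the per-GraphTask cpu_ready_queue passed in.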
1454 
1455 auto Engine::ready_queue_by_index(
1456     std::shared_ptr<ReadyQueue> cpu_ready_queue,
1457     int device_index) -> std::shared_ptr<ReadyQueue> {
1458   if (device_index == CPU_DEVICE) {
1459     // return the cpu ready queue passed in
1460     TORCH_INTERNAL_ASSERT(cpu_ready_queue);
1461     return cpu_ready_queue;
1462   } else {
1463     TORCH_INTERNAL_ASSERT(
1464         0 <= device_index &&
1465         device_index <
1466             static_cast<c10::DeviceIndex>(device_ready_queues_.size()));
1467     // See Note [Allocating GPUs to autograd threads]
1468     // NB: This function would become obsolete if we truly allocated a CPU
1469     // thread per device, rather than colocate.
1470     return device_ready_queues_.at(device_index);
1471   }
1472 }
1473 
1474 auto Engine::start_device_threads() -> void {
1475   // First always initialize the thread pool for re-entrant threads
1476   thread_pool_shared_ = std::make_shared<ThreadPoolShared>();
1477 
1478   // Second, create special threads for each non-CPU device
1479   // See Note [Allocating GPUs to autograd threads]
1480   c10::DeviceIndex num_devices = 0;
1481   for (const auto& impl_atomic : c10::impl::device_guard_impl_registry) {
1482     auto* impl = impl_atomic.load();
1483     // Only record the number of devices for device types that don't run on
1484     // the CPU ready queue.
1485     if (impl && !should_run_in_cpu_ready_queue(impl->type())) {
1486       num_devices = std::max(num_devices, impl->deviceCount());
1487     }
1488   }
1489 
1490   // If there are no devices other than the CPU, there is no need to create worker threads
1491   if (num_devices == 0) {
1492     return;
1493   }
1494 
1495   // Since we're about to create threads, forking is not possible anymore
1496   track_bad_autograd_forks();
1497 
1498   // allocate one thread for every GPU device (but colocate GPUs of different
1499   // types), and pre-allocate the device_ready_queues_ to ensure safe reading
1500   // from it.
1501   device_ready_queues_ = std::vector<std::shared_ptr<ReadyQueue>>(num_devices);
1502   for (auto& queue : device_ready_queues_) {
1503     queue = std::make_shared<ReadyQueue>();
1504   }
1505 
1506   for (const auto i : c10::irange(num_devices)) {
1507     std::thread t(&Engine::thread_init, this, i, device_ready_queues_[i], true);
1508     t.detach();
1509   }
1510   // Wait for the threads to start
1511   {
1512     std::unique_lock<std::mutex> lk(non_reentrant_device_thread_mutex_);
1513     while (non_reentrant_device_thread_count_.load() !=
1514            static_cast<uint32_t>(num_devices)) {
1515       non_reentrant_device_thread_condvar_.wait(lk);
1516     }
1517   }
1518 }
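// Example of the colocation scheme: with two CUDA devices and two devices of
// some other accelerator type, num_devices is 2, and worker thread i drains
// device_ready_queues_[i], which serves device index i for every such type
// (see Note [Allocating GPUs to autograd threads]).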
1519 
1520 void Engine::add_thread_pool_task(const std::weak_ptr<GraphTask>& graph_task) {
1521   std::unique_lock<std::mutex> lck(thread_pool_shared_->mutex_);
1522   // There may already be some items on the graphtasks_queue_ added by other
1523   // threads, but not enough workers to get to the new task that will be
1524   // added.
1525   bool create_thread =
1526       (thread_pool_shared_->num_workers_ <=
1527        thread_pool_shared_->graphtasks_queue_.size());
1528   thread_pool_shared_->graphtasks_queue_.push(graph_task);
1529   // Don't need to be holding the lock while actually creating the thread
1530   lck.unlock();
1531   if (create_thread) {
1532     // If we're creating a new thread, forking is not allowed anymore
1533     track_bad_autograd_forks();
1534     std::thread t(&Engine::reentrant_thread_init, this);
1535     t.detach();
1536   }
1537   // This works even if new thread is created because wait() will test the
1538   // predicate before waiting
1539   thread_pool_shared_->work_.notify_one();
1540 }
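// Pool growth policy: a new re-entrant worker thread is spawned only when the
// number of workers is no greater than the number of tasks already waiting,
// i.e. when there may be no free worker to pick up the task being queued;
// otherwise the notify_one() above simply wakes an idle worker.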
1541 
1542 // Remembers the current streams on all devices for which a context has been created.
1543 // This function assumes an accelerator device is available.
1544 void GraphTask::stash_current_streams() {
1545   const auto accelerator = at::getAccelerator(true).value();
1546   const auto guard = c10::impl::VirtualGuardImpl{accelerator};
1547   auto num_devices = guard.deviceCount();
1548   caller_current_streams_.resize(num_devices);
1549   if (num_devices > 0) {
1550     for (c10::DeviceIndex idx = 0; idx < num_devices; idx++) {
1551       if (at::globalContext().getAcceleratorHooksInterface().hasPrimaryContext(
1552               idx)) {
1553         caller_current_streams_[idx] = guard.getStream({accelerator, idx});
1554       } else {
1555         caller_current_streams_[idx] = std::nullopt;
1556       }
1557     }
1558   }
1559 }
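// The streams recorded here represent what the caller was using when backward
// started; they are later consulted during the graph task's post-processing
// to synchronize the caller's streams with the streams the engine used to
// produce gradients, so work the caller subsequently enqueues observes
// completed grads.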
1560 
1561 void GraphTask::init_to_execute(
1562     Node& graph_root,
1563     const edge_list& outputs,
1564     bool accumulate_grad,
1565     uint64_t min_topo_nr) {
1566   // Populates exec_info so that nodes that should be executed have
1567   // `exec_info[node].needed_ = true`. Only nodes that have a path to any edge in
1568   // `outputs` should be executed. The pseudocode below populates exec_info using
1569   // recursion, but the actual code does this iteratively. Refer to the
1570   // numbering to see how the actual code corresponds. A difference to note is
1571   // that in the iterative version, when you are working with the current Node,
1572   // you are responsible for updating your parent's is_needed after all your
1573   // children have been updated.
1574   //
1575   // is_needed = {fn: True for fn in outputs}             # (0)
1576   // seen = set()
1577   // def compute_is_needed(fn):
1578   //   for next_edge in fn.next_edges:
1579   //     child_fn = next_edge.fn
1580   //     if child_fn in seen and is_needed[child_fn]:     # (1)
1581   //       is_needed[fn] = True
1582   //     else:
1583   //       seen.add(child_fn)
1584   //       if compute_is_needed(child_fn):
1585   //         is_needed[fn] = True                         # (2)
1586   //                                                      # (3) exit for-loop
1587   //   return is_needed[fn]
1588   // compute_is_needed(graph_root)
1589   //
1590   // NB: you might be wondering why we don't populate `seen` with outputs. We
1591   // cannot because in the case where two outputs lie on the same path, we still
1592   // need to explore past the first output or we would miss the nodes that are
1593   // required to compute the second output.
1594   int output_idx = 0;
1595   for (auto& output_edge : outputs) {
1596     // (0) `is_needed` above corresponds to `exec_info_[fn].needed_`
1597     Node* output = output_edge.function.get();
1598     auto& info = exec_info_[output];
1599     if (accumulate_grad) {
1600       // if called through `.backward()` we directly set `needed_` for all the
1601       // outputs to true
1602       info.needed_ = true;
1603     } else {
1604       // otherwise it is `.grad()` and we set exec_info[fn].captures_ instead.
1605       // In terms of populating the rest of exec_info, though, you can basically
1606       // think of this as the same as setting `needed_` to true directly.
1607       if (!info.captures_) {
1608         info.captures_ = std::make_unique<std::vector<ExecInfo::Capture>>();
1609       }
1610       info.captures_->emplace_back(output_edge.input_nr, output_idx++);
1611     }
1612   }
1613   captured_vars_.resize(output_idx);
1614 
1615   struct Frame {
1616     Frame(Node* fn) : fn_(fn) {}
1617     Node* fn_{};
1618     size_t next_next_fn_{};
1619 
1620     Node* get_next_fn() {
1621       const auto& next = fn_->next_edges();
1622       auto num_next = next.size();
1623       while (next_next_fn_ < num_next) {
1624         auto fn = next[next_next_fn_++].function.get();
1625         if (fn)
1626           return fn;
1627       }
1628       return nullptr;
1629     }
1630   };
1631 
1632   auto nodeShouldExecute = [this](Node* fn) {
1633     auto it = exec_info_.find(fn);
1634     return it != exec_info_.end() && it->second.should_execute();
1635   };
1636 
1637   std::vector<Frame> stack;
1638   std::unordered_set<Node*> seen;
1639   stack.emplace_back(&graph_root);
1640   exec_info_.emplace(stack.back().fn_, ExecInfo());
1641 
1642   while (!stack.empty()) {
1643     auto& frame = stack.back();
1644     const auto fn = frame.fn_;
1645 
1646     Node* child_fn = nullptr;
1647     while ((child_fn = frame.get_next_fn()) && !seen.emplace(child_fn).second) {
1648       // (1) next child exists AND has already been seen
1649       if (nodeShouldExecute(child_fn)) {
1650         exec_info_[fn].needed_ = true;
1651       }
1652     }
1653 
1654     if (child_fn) {
1655       // (2) next child exists but has not been seen
1656       if (child_fn->topological_nr() < min_topo_nr) {
1657         // a child created before the first output means this child cannot
1658         // have a path to any of the outputs
1659         continue;
1660       }
1661       stack.emplace_back(child_fn);
1662     } else {
1663       // (3) no next child exists for `fn`, which means its `needed_` has
1664       // already been finalized. Pop the stack and update the parent.
1665       stack.pop_back();
1666       if (nodeShouldExecute(fn) && !stack.empty()) {
1667         exec_info_[stack.back().fn_].needed_ = true;
1668       }
1669     }
1670   }
1671 }
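// Worked example (hypothetical graph, not from this file): for z = x * y with
// leaf tensors x and y, a grad(z, {x}) call yields one output edge pointing
// at x's AccumulateGrad node. That node receives a Capture entry, MulBackward
// gets needed_ = true because one of its children should execute, and y's
// AccumulateGrad is never marked, so it is skipped when the graph task runs.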
1672 
1673 } // namespace torch::autograd
1674