#pragma once

// Engine implements backpropagation from output variables and their gradients
// to "root" variables (variables created by the user with requires_grad=True).

#include <ATen/Tensor.h>
#include <ATen/ThreadLocalState.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/anomaly_mode.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/graph_task.h>
#include <torch/csrc/autograd/input_buffer.h>
#include <torch/csrc/autograd/saved_variable_hooks.h>
#include <torch/csrc/autograd/utils/warnings.h>

#include <c10/util/CallOnce.h>

#include <exception>
#include <functional>
#include <memory>
#include <queue>
#include <utility>
#include <vector>

namespace torch::autograd {
struct ReadyQueue;
}

namespace torch::autograd {

// Maximum reentrant backward depth before switching to a new thread.
// This limit is based on TSAN's deadlock detector, which fails if a program
// holds more than 65 locks in one thread at once. As we hold a mutex in each
// of our custom C++ autograd Nodes, we would like to avoid TSAN complaints
// about this when doing reentrant backwards.
// For reference, see https://github.com/google/sanitizers/issues/950
static constexpr int MAX_DEPTH = 60;

void set_device(int device);
TORCH_API void validate_outputs(
    const edge_list& edges,
    variable_list& grads,
    const std::function<std::string(const std::string&)>& format_error);

struct NodeTask {
  std::weak_ptr<GraphTask> base_;
  std::shared_ptr<Node> fn_;
  // This buffer serves as an implicit "addition" node for all of the
  // gradients flowing here. Once all the dependencies are finished, we
  // use the contents of this buffer to run the function.
  InputBuffer inputs_;
  // When a worker receives a task with isShutdownTask = true, it will
  // immediately exit. The engine sends a shutdown task to every queue upon
  // its destruction.
  bool isShutdownTask_;

  int getReentrantDepth() const;

  NodeTask(
      std::weak_ptr<GraphTask> base,
      std::shared_ptr<Node> fn,
      InputBuffer inputs,
      bool isShutdownTask = false)
      : base_(std::move(base)),
        fn_(std::move(fn)),
        inputs_(std::move(inputs)),
        isShutdownTask_(isShutdownTask) {}
};

// Guard that sets and restores checkpoint_valid
class CheckpointValidGuard {
 public:
  explicit CheckpointValidGuard(
      const std::shared_ptr<const GraphTask>& graph_task);
  ~CheckpointValidGuard();

 private:
  bool prev_checkpoint_valid_state;
};

struct ReadyQueue {
 private:
  // Returns true when t2 should be (weakly) BEFORE t1 in the queue.
  // Shutdown tasks come first, then NodeTasks with an empty fn_ come next.
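  //
  // Illustrative example (not part of the implementation): given three queued
  // tasks
  //   A: isShutdownTask_ == true
  //   B: fn_ == nullptr (a dummy/sentinel task)
  //   C: fn_ != nullptr, reentrant depth 0, sequence_nr() == 5
  // pop() returns A, then B, then C. Among ordinary tasks, a greater reentrant
  // depth is popped first; at equal depth, the task whose Node has the higher
  // sequence_nr() (i.e. the Node created later during the forward pass) is
  // popped first, giving a roughly reverse-topological execution order.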
  struct CompareNodeTaskTime {
    bool operator()(NodeTask const& t1, NodeTask const& t2) {
      // NOLINTNEXTLINE(bugprone-branch-clone)
      if (t2.isShutdownTask_) {
        return true;
      } else if (!t1.fn_ || t1.isShutdownTask_) {
        return false;
      } else if (!t2.fn_) {
        return true;
      } else if (t1.getReentrantDepth() == t2.getReentrantDepth()) {
        return t1.fn_->sequence_nr() < t2.fn_->sequence_nr();
      } else {
        return t1.getReentrantDepth() < t2.getReentrantDepth();
      }
    }
  };

  // To notify threads waiting on the ReadyQueue of available tasks on the
  // heap_
  std::condition_variable not_empty_;
  // To protect reads and writes to heap_
  mutable std::mutex mutex_;

  std::priority_queue<NodeTask, std::vector<NodeTask>, CompareNodeTaskTime>
      heap_;

 public:
  // incrementOutstandingTasks indicates whether or not we should increment
  // 'outstanding_tasks_' for the associated GraphTask. This should almost
  // always be true and is only set to false in certain cases (see docs for
  // DistEngine.execute_graph_task_until_ready_queue_empty)
  void push(NodeTask item, bool incrementOutstandingTasks = true);
  void pushShutdownTask();
  NodeTask pop();
  bool empty() const;
  size_t size() const;
};

// A single instance of this struct should be created through the whole process
// lifetime. The worker thread creation logic and Engine's destructor rely on
// this.
struct TORCH_API Engine {
  /// Returns a reference to a static `Engine` instance.
  static Engine& get_default_engine();

  static Engine& get_base_engine();

  // compiled_autograd needs to live in a different .so file so that it
  // can have python symbols, so we add a layer of indirection
  // see [Note: Compiled Autograd]
  typedef variable_list (*compiled_autograd_fn)(
      const std::shared_ptr<Node>& graph_root,
      GraphTask& graph_task,
      bool accumulate_grad,
      const edge_list& outputs);
  static void set_compiled_autograd(compiled_autograd_fn fn);

  Engine(const Engine&) = delete;
  Engine(Engine&&) = delete;
  virtual ~Engine();

  // Given a list of (Node, input number) pairs computes the value of the graph
  // by following next_edge references.
  virtual variable_list execute(
      const edge_list& roots,
      const variable_list& inputs,
      bool keep_graph,
      bool create_graph,
      bool accumulate_grad,
      const edge_list& outputs = {});
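  // Illustrative sketch only (simplified; the actual call sites live in the
  // autograd entry points, e.g. torch/csrc/autograd/autograd.cpp). A
  // `loss.backward()`-style pass conceptually reduces to something like
  //
  //   edge_list roots = {Edge(loss.grad_fn(), loss.output_nr())};
  //   variable_list grads = {at::ones_like(loss)};
  //   Engine::get_default_engine().execute(
  //       roots, grads,
  //       /*keep_graph=*/false,
  //       /*create_graph=*/false,
  //       /*accumulate_grad=*/true,  // backward(): accumulate into .grad
  //       /*outputs=*/{});           // autograd::grad(): pass input edges here
  //
  // `loss` is a placeholder; error handling and graph validation are omitted.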
  // Given a pre-populated GraphTask and GraphRoot, computes the backward pass
  // for the graph.
  //
  // NB: This API should only be used by internal autograd-specific machinery
  // and shouldn't be exposed to users in any way.
  virtual c10::intrusive_ptr<at::ivalue::Future> execute_with_graph_task(
      const std::shared_ptr<GraphTask>& graph_task,
      std::shared_ptr<Node> graph_root,
      InputBuffer&& input_buffer);

  virtual std::unique_ptr<AnomalyMetadata> make_anomaly_metadata() {
    return std::make_unique<AnomalyMetadata>();
  }

  virtual std::unique_ptr<SavedVariableHooks>
  get_default_saved_variable_hooks() {
    return nullptr;
  }

  // We pass cpu_ready_queue to evaluate_function, so that it knows
  // the correct ready queue to push to after a NodeTask is ready
  void evaluate_function(
      std::shared_ptr<GraphTask>& graph_task,
      Node* func,
      InputBuffer& inputs,
      const std::shared_ptr<ReadyQueue>& cpu_ready_queue);

  void initialize_device_threads_pool();
  virtual void thread_on_exception(
      const std::shared_ptr<GraphTask>& graph_task,
      const std::shared_ptr<Node>& fn,
      std::exception& e);

  void queue_callback(std::function<void()> callback);

  bool is_checkpoint_valid();

  // Should be called after fork to notify that worker threads are gone
  void release_workers();

  // Must be called by subclasses before destructing to avoid a
  // data-race-on-vptr.
  void stop();

  // Initializes a device thread for the autograd engine.
  virtual void thread_init(
      int device,
      const std::shared_ptr<ReadyQueue>& ready_queue,
      bool should_increment = true);

 protected:
  Engine();
  void compute_dependencies(Node* root, GraphTask& task, uint64_t min_topo_nr);

  // initialize the thread local ready queue with the ready queue that is
  // created elsewhere (i.e. thread_init, Engine::execute, etc.), or create a
  // new ready queue if ready_queue is not provided.
  void init_local_ready_queue(
      std::shared_ptr<ReadyQueue> ready_queue = nullptr);

  std::shared_ptr<ReadyQueue> ready_queue(
      std::shared_ptr<ReadyQueue> cpu_ready_queue,
      at::Device device);
  std::shared_ptr<ReadyQueue> ready_queue_by_index(
      std::shared_ptr<ReadyQueue> cpu_ready_queue,
      int device_index);
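  // Rough mental model only (the actual routing logic lives in engine.cpp):
  // these helpers pick the caller-provided CPU queue for CPU work and the
  // per-device worker queue for accelerator work, roughly
  //
  //   auto queue = device.is_cpu()
  //       ? cpu_ready_queue
  //       : device_ready_queues_.at(device.index());
  //   queue->push(NodeTask(graph_task, std::move(fn), std::move(input_buffer)));
  //
  // which is why evaluate_function() above takes cpu_ready_queue explicitly:
  // device queues are shared by the Engine, while CPU work uses the ready
  // queue associated with the current backward call.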
  // Start device threads (CUDA, XLA, etc.) in the Engine;
  // note that this does NOT start the CPU thread.
  void start_device_threads();
  void increment_non_reentrant_thread_count();
  void decrement_non_reentrant_thread_count();
  virtual void thread_main(const std::shared_ptr<GraphTask>& task);
  void reentrant_thread_init();
  void add_thread_pool_task(const std::weak_ptr<GraphTask>& graph_task);

  // Ensures device_ready_queues_ are initialized only once
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  c10::once_flag start_device_threads_flag_;
  // Safe to read device_ready_queues_ without synchronization after
  // initialization
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<std::shared_ptr<ReadyQueue>> device_ready_queues_;

  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<std::function<void()>> final_callbacks_;
  // To protect reads and writes to final_callbacks_
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::mutex post_callbacks_lock_;

  // How many nested reentrant calls are allowed until a new thread is used
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  int max_recursion_depth_;

  struct ThreadPoolShared {
    // Data structures used by the threads for executing reentrant backwards
    // tasks. See Note [Reentrant backwards]
    // Number of available threads for processing new GraphTasks.
    unsigned int num_workers_{0};
    // The threads will wait on work_ to be notified of GraphTasks
    std::condition_variable work_;
    // To protect reads and writes to graphtasks_queue_ and num_workers_
    // and for synchronizing creating new threads when needed
    std::mutex mutex_;
    // Workers will process the GraphTasks added to this queue. A GraphTask is
    // allocated inside Engine::execute and lives for the duration of execute
    std::queue<std::weak_ptr<GraphTask>> graphtasks_queue_;

    ThreadPoolShared() = default;
  };

  // Temporary workaround until shutting down threads is done.
  // We need shared ownership of all these objects because the threads are
  // leaked when Engine shuts down, so there may be threads waiting on work_
  // for the graphtasks_queue_ to be nonempty.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::shared_ptr<ThreadPoolShared> thread_pool_shared_;

 private:
  // Number of non-reentrant threads
  std::atomic<uint32_t> non_reentrant_device_thread_count_;
  // Destructor will wait for non-reentrant threads to finish
  std::condition_variable non_reentrant_device_thread_condvar_;
  std::mutex non_reentrant_device_thread_mutex_;
  // stop() must be called before the destruction path goes down to the base
  // class, in order to avoid a data-race-on-vptr. Use this boolean to guard
  // whether stop() has already been called, so we can call this in every
  // destructor of the class hierarchy.
  bool stopped_{false};
};

// allow python_engine to override the default engine when it loads
using EngineStub = Engine& (*)();
TORCH_API void set_default_engine_stub(EngineStub stub);

} // namespace torch::autograd
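
// Illustrative only: a hypothetical subclass (modeled on how the Python
// engine hooks in) would override the virtual hooks, call stop() in its own
// destructor, and register itself through set_default_engine_stub, e.g.
//
//   struct MyEngine : torch::autograd::Engine {
//     static torch::autograd::Engine& get_instance() {
//       static MyEngine engine;
//       return engine;
//     }
//     ~MyEngine() override {
//       stop();  // must run before ~Engine() to avoid a data-race-on-vptr
//     }
//   };
//
//   // at library load time:
//   torch::autograd::set_default_engine_stub(&MyEngine::get_instance);
//
// After registration, Engine::get_default_engine() dispatches to the
// overriding instance. `MyEngine` and `get_instance` are placeholder names.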