1 #pragma once 2 #include <c10/util/Exception.h> 3 4 #include <mutex> 5 #include <vector> 6 7 namespace torch::autograd::utils { 8 9 // Warning handler for multi-threaded contexts. Gather warnings from 10 // all threads into a single queue, then process together at the end 11 // in the main thread. 12 class DelayWarningHandler : public at::WarningHandler { 13 public: 14 ~DelayWarningHandler() override = default; 15 void replay_warnings(); 16 17 private: 18 void process(const c10::Warning& warning) override; 19 20 std::vector<c10::Warning> warnings_; 21 std::mutex mutex_; 22 }; 23 24 } // namespace torch::autograd::utils 25