xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/utils/warnings.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <c10/util/Exception.h>
3 
4 #include <mutex>
5 #include <vector>
6 
7 namespace torch::autograd::utils {
8 
9 // Warning handler for multi-threaded contexts. Gather warnings from
10 // all threads into a single queue, then process together at the end
11 // in the main thread.
12 class DelayWarningHandler : public at::WarningHandler {
13  public:
14   ~DelayWarningHandler() override = default;
15   void replay_warnings();
16 
17  private:
18   void process(const c10::Warning& warning) override;
19 
20   std::vector<c10::Warning> warnings_;
21   std::mutex mutex_;
22 };
23 
24 } // namespace torch::autograd::utils
25