xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/Backoff.hpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <chrono>
4 #include <random>
5 #include <thread>
6 
7 #include <c10/macros/Macros.h>
8 
9 namespace c10d {
10 
11 class TORCH_API Backoff {
12  public:
13   virtual ~Backoff() = default;
14 
15   virtual std::chrono::milliseconds nextBackoff() = 0;
16   virtual void reset() = 0;
17 
sleepBackoff()18   void sleepBackoff() {
19     std::this_thread::sleep_for(nextBackoff());
20   }
21 };
22 
23 class TORCH_API ExponentialBackoffWithJitter : public Backoff {
24  public:
25   ExponentialBackoffWithJitter();
26 
27   std::chrono::milliseconds nextBackoff() override;
28   void reset() override;
29 
30  public:
31   std::chrono::milliseconds initialInterval{500};
32   double randomizationFactor{0.5};
33   double multiplier{1.5};
34   std::chrono::milliseconds maxInterval{60000};
35 
36  private:
37   std::mt19937 gen_;
38   std::chrono::milliseconds currentInterval_{0};
39 };
40 
41 class TORCH_API FixedBackoff : public Backoff {
42  public:
43   FixedBackoff(std::chrono::milliseconds interval);
44 
45   std::chrono::milliseconds nextBackoff() override;
46   void reset() override;
47 
48  private:
49   std::chrono::milliseconds interval_;
50 };
51 
52 } // namespace c10d
53