xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/Backoff.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/c10d/Backoff.hpp>
2 
3 #include <exception>
4 #include <stdexcept>
5 
6 namespace c10d {
namespace {
// Zero-length interval. ExponentialBackoffWithJitter uses it as the
// "schedule not started" sentinel for currentInterval_ (set by reset(),
// checked at the top of nextBackoff()).
constexpr std::chrono::milliseconds kZeroInterval{0};

// Produces a seed for the jitter random engine from std::random_device.
//
// The value is returned in random_device's own unsigned result_type. The
// previous int32_t return type forced an implementation-defined
// (pre-C++20) narrowing of values above INT32_MAX, only for the result to be
// converted back to an unsigned seed by the engine constructor.
std::random_device::result_type randSeed() {
  std::random_device rd;
  return rd();
}
} // namespace
15 
ExponentialBackoffWithJitter()16 ExponentialBackoffWithJitter::ExponentialBackoffWithJitter()
17     : gen_(randSeed()) {}
18 
nextBackoff()19 std::chrono::milliseconds ExponentialBackoffWithJitter::nextBackoff() {
20   if (initialInterval == kZeroInterval) {
21     throw std::out_of_range(
22         "ExponentialBackoffWithJitter requires non-zero initial interval");
23   }
24   if (initialInterval > maxInterval) {
25     throw std::out_of_range(
26         "ExponentialBackoffWithJitter requires initialInterval <= maxInterval");
27   }
28   if (randomizationFactor >= 1 || randomizationFactor < 0) {
29     throw std::out_of_range(
30         "ExponentialBackoffWithJitter requires randomization factor (0,1]");
31   }
32   if (multiplier < 1.0) {
33     throw std::out_of_range(
34         "ExponentialBackoffWithJitter requires multiplier >=1");
35   }
36 
37   // detect initial setup
38   if (currentInterval_ == kZeroInterval) {
39     currentInterval_ = initialInterval;
40   }
41 
42   // sample current interval
43   std::chrono::milliseconds randomization{static_cast<int64_t>(
44       randomizationFactor * static_cast<double>(currentInterval_.count()))};
45   std::chrono::milliseconds minSampleInterval =
46       currentInterval_ - randomization;
47   std::chrono::milliseconds maxSampleInterval =
48       currentInterval_ + randomization;
49 
50   std::uniform_int_distribution<> dist(
51       minSampleInterval.count(), maxSampleInterval.count());
52   std::chrono::milliseconds backoffInterval{dist(gen_)};
53 
54   // update current interval
55   currentInterval_ = std::chrono::milliseconds(static_cast<int64_t>(
56       static_cast<double>(currentInterval_.count()) * multiplier));
57 
58   if (currentInterval_ > maxInterval) {
59     currentInterval_ = maxInterval;
60   }
61 
62   return backoffInterval;
63 }
64 
reset()65 void ExponentialBackoffWithJitter::reset() {
66   currentInterval_ = kZeroInterval;
67 }
68 
FixedBackoff(std::chrono::milliseconds interval)69 FixedBackoff::FixedBackoff(std::chrono::milliseconds interval)
70     : interval_(interval) {}
71 
nextBackoff()72 std::chrono::milliseconds FixedBackoff::nextBackoff() {
73   return interval_;
74 }
75 
reset()76 void FixedBackoff::reset() {}
77 } // namespace c10d
78