1 #include <torch/csrc/distributed/c10d/Backoff.hpp> 2 3 #include <exception> 4 #include <stdexcept> 5 6 namespace c10d { 7 namespace { 8 constexpr std::chrono::milliseconds kZeroInterval{0}; 9 randSeed()10int32_t randSeed() { 11 std::random_device rd; 12 return rd(); 13 } 14 } // namespace 15 ExponentialBackoffWithJitter()16ExponentialBackoffWithJitter::ExponentialBackoffWithJitter() 17 : gen_(randSeed()) {} 18 nextBackoff()19std::chrono::milliseconds ExponentialBackoffWithJitter::nextBackoff() { 20 if (initialInterval == kZeroInterval) { 21 throw std::out_of_range( 22 "ExponentialBackoffWithJitter requires non-zero initial interval"); 23 } 24 if (initialInterval > maxInterval) { 25 throw std::out_of_range( 26 "ExponentialBackoffWithJitter requires initialInterval <= maxInterval"); 27 } 28 if (randomizationFactor >= 1 || randomizationFactor < 0) { 29 throw std::out_of_range( 30 "ExponentialBackoffWithJitter requires randomization factor (0,1]"); 31 } 32 if (multiplier < 1.0) { 33 throw std::out_of_range( 34 "ExponentialBackoffWithJitter requires multiplier >=1"); 35 } 36 37 // detect initial setup 38 if (currentInterval_ == kZeroInterval) { 39 currentInterval_ = initialInterval; 40 } 41 42 // sample current interval 43 std::chrono::milliseconds randomization{static_cast<int64_t>( 44 randomizationFactor * static_cast<double>(currentInterval_.count()))}; 45 std::chrono::milliseconds minSampleInterval = 46 currentInterval_ - randomization; 47 std::chrono::milliseconds maxSampleInterval = 48 currentInterval_ + randomization; 49 50 std::uniform_int_distribution<> dist( 51 minSampleInterval.count(), maxSampleInterval.count()); 52 std::chrono::milliseconds backoffInterval{dist(gen_)}; 53 54 // update current interval 55 currentInterval_ = std::chrono::milliseconds(static_cast<int64_t>( 56 static_cast<double>(currentInterval_.count()) * multiplier)); 57 58 if (currentInterval_ > maxInterval) { 59 currentInterval_ = maxInterval; 60 } 61 62 return backoffInterval; 63 } 64 reset()65void ExponentialBackoffWithJitter::reset() { 66 currentInterval_ = kZeroInterval; 67 } 68 FixedBackoff(std::chrono::milliseconds interval)69FixedBackoff::FixedBackoff(std::chrono::milliseconds interval) 70 : interval_(interval) {} 71 nextBackoff()72std::chrono::milliseconds FixedBackoff::nextBackoff() { 73 return interval_; 74 } 75 reset()76void FixedBackoff::reset() {} 77 } // namespace c10d 78