1 #pragma once 2 3 #include <chrono> 4 #include <cstdint> 5 #include <string> 6 #include <vector> 7 8 #include <c10/macros/Macros.h> 9 #include <torch/custom_class.h> 10 11 namespace c10d { 12 13 // callback function will be given arguments (std::optional<string> oldValue, 14 // std::optional<string> newValue) 15 using WatchKeyCallback = 16 std::function<void(std::optional<std::string>, std::optional<std::string>)>; 17 18 class TORCH_API Store : public torch::CustomClassHolder { 19 public: 20 static constexpr std::chrono::milliseconds kDefaultTimeout = 21 std::chrono::seconds(300); 22 static constexpr std::chrono::milliseconds kNoTimeout = 23 std::chrono::milliseconds::zero(); 24 Store()25 Store() : timeout_(kDefaultTimeout) {} 26 Store(const std::chrono::milliseconds & timeout)27 explicit Store(const std::chrono::milliseconds& timeout) 28 : timeout_(timeout) {} 29 30 Store(const Store&) = default; 31 Store(Store&&) noexcept = default; 32 33 ~Store() override = default; 34 35 void set(const std::string& key, const std::string& value); 36 37 virtual void set( 38 const std::string& key, 39 const std::vector<uint8_t>& value) = 0; 40 41 std::string compareSet( 42 const std::string& key, 43 const std::string& currentValue, 44 const std::string& newValue); 45 compareSet(const std::string & key,const std::vector<uint8_t> & currentValue,const std::vector<uint8_t> & newValue)46 virtual std::vector<uint8_t> compareSet( 47 const std::string& key, 48 const std::vector<uint8_t>& currentValue, 49 const std::vector<uint8_t>& newValue) { 50 TORCH_INTERNAL_ASSERT(false, "Not implemented."); 51 } 52 53 std::string get_to_str(const std::string& key); 54 55 virtual std::vector<uint8_t> get(const std::string& key) = 0; 56 57 virtual int64_t add(const std::string& key, int64_t value) = 0; 58 59 virtual bool deleteKey(const std::string& key) = 0; 60 61 virtual bool check(const std::vector<std::string>& keys) = 0; 62 63 virtual int64_t getNumKeys() = 0; 64 65 virtual void wait(const std::vector<std::string>& keys) = 0; 66 67 virtual void wait( 68 const std::vector<std::string>& keys, 69 const std::chrono::milliseconds& timeout) = 0; 70 71 virtual const std::chrono::milliseconds& getTimeout() const noexcept; 72 73 virtual void setTimeout(const std::chrono::milliseconds& timeout); 74 75 // watchKey() is deprecated and no longer supported. watchKey(const std::string &,WatchKeyCallback)76 virtual void watchKey( 77 const std::string& /* unused */, 78 WatchKeyCallback /* unused */) { 79 TORCH_CHECK(false, "watchKey is deprecated, no implementation support it."); 80 } 81 82 virtual void append( 83 const std::string& key, 84 const std::vector<uint8_t>& value); 85 86 virtual std::vector<std::vector<uint8_t>> multiGet( 87 const std::vector<std::string>& keys); 88 89 virtual void multiSet( 90 const std::vector<std::string>& keys, 91 const std::vector<std::vector<uint8_t>>& values); 92 93 // Returns true if this store support append, multiGet and multiSet 94 virtual bool hasExtendedApi() const; 95 96 protected: 97 std::chrono::milliseconds timeout_; 98 }; 99 100 /* 101 StoreTimeoutGuard is a RAII guard that will set the store timeout and restore it 102 when it returns. 103 */ 104 class StoreTimeoutGuard { 105 public: StoreTimeoutGuard(Store & store,const std::chrono::milliseconds & timeout)106 explicit StoreTimeoutGuard( 107 Store& store, 108 const std::chrono::milliseconds& timeout) 109 : store_(store), oldTimeout_(store.getTimeout()) { 110 store.setTimeout(timeout); 111 } 112 ~StoreTimeoutGuard()113 ~StoreTimeoutGuard() { 114 store_.setTimeout(oldTimeout_); 115 } 116 117 /* Disabling copy and move semantics */ 118 StoreTimeoutGuard(const StoreTimeoutGuard&) = delete; 119 StoreTimeoutGuard& operator=(const StoreTimeoutGuard&) = delete; 120 StoreTimeoutGuard(StoreTimeoutGuard&&) = delete; 121 StoreTimeoutGuard& operator=(StoreTimeoutGuard&&) = delete; 122 123 private: 124 Store& store_; 125 std::chrono::milliseconds oldTimeout_{}; 126 }; 127 128 } // namespace c10d 129