xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/Store.hpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <chrono>
#include <cstdint>
#include <functional>
#include <optional>
#include <string>
#include <vector>

#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <torch/custom_class.h>
10 
11 namespace c10d {
12 
// Signature of the callback accepted by Store::watchKey(). The callback is
// handed the key's old value followed by its new value, each wrapped in
// std::optional.
using WatchKeyCallback = std::function<void(
    std::optional<std::string> oldValue,
    std::optional<std::string> newValue)>;
17 
18 class TORCH_API Store : public torch::CustomClassHolder {
19  public:
20   static constexpr std::chrono::milliseconds kDefaultTimeout =
21       std::chrono::seconds(300);
22   static constexpr std::chrono::milliseconds kNoTimeout =
23       std::chrono::milliseconds::zero();
24 
Store()25   Store() : timeout_(kDefaultTimeout) {}
26 
Store(const std::chrono::milliseconds & timeout)27   explicit Store(const std::chrono::milliseconds& timeout)
28       : timeout_(timeout) {}
29 
30   Store(const Store&) = default;
31   Store(Store&&) noexcept = default;
32 
33   ~Store() override = default;
34 
35   void set(const std::string& key, const std::string& value);
36 
37   virtual void set(
38       const std::string& key,
39       const std::vector<uint8_t>& value) = 0;
40 
41   std::string compareSet(
42       const std::string& key,
43       const std::string& currentValue,
44       const std::string& newValue);
45 
compareSet(const std::string & key,const std::vector<uint8_t> & currentValue,const std::vector<uint8_t> & newValue)46   virtual std::vector<uint8_t> compareSet(
47       const std::string& key,
48       const std::vector<uint8_t>& currentValue,
49       const std::vector<uint8_t>& newValue) {
50     TORCH_INTERNAL_ASSERT(false, "Not implemented.");
51   }
52 
53   std::string get_to_str(const std::string& key);
54 
55   virtual std::vector<uint8_t> get(const std::string& key) = 0;
56 
57   virtual int64_t add(const std::string& key, int64_t value) = 0;
58 
59   virtual bool deleteKey(const std::string& key) = 0;
60 
61   virtual bool check(const std::vector<std::string>& keys) = 0;
62 
63   virtual int64_t getNumKeys() = 0;
64 
65   virtual void wait(const std::vector<std::string>& keys) = 0;
66 
67   virtual void wait(
68       const std::vector<std::string>& keys,
69       const std::chrono::milliseconds& timeout) = 0;
70 
71   virtual const std::chrono::milliseconds& getTimeout() const noexcept;
72 
73   virtual void setTimeout(const std::chrono::milliseconds& timeout);
74 
75   // watchKey() is deprecated and no longer supported.
watchKey(const std::string &,WatchKeyCallback)76   virtual void watchKey(
77       const std::string& /* unused */,
78       WatchKeyCallback /* unused */) {
79     TORCH_CHECK(false, "watchKey is deprecated, no implementation support it.");
80   }
81 
82   virtual void append(
83       const std::string& key,
84       const std::vector<uint8_t>& value);
85 
86   virtual std::vector<std::vector<uint8_t>> multiGet(
87       const std::vector<std::string>& keys);
88 
89   virtual void multiSet(
90       const std::vector<std::string>& keys,
91       const std::vector<std::vector<uint8_t>>& values);
92 
93   // Returns true if this store support append, multiGet and multiSet
94   virtual bool hasExtendedApi() const;
95 
96  protected:
97   std::chrono::milliseconds timeout_;
98 };
99 
100 /*
101 StoreTimeoutGuard is a RAII guard that will set the store timeout and restore it
102 when it returns.
103 */
104 class StoreTimeoutGuard {
105  public:
StoreTimeoutGuard(Store & store,const std::chrono::milliseconds & timeout)106   explicit StoreTimeoutGuard(
107       Store& store,
108       const std::chrono::milliseconds& timeout)
109       : store_(store), oldTimeout_(store.getTimeout()) {
110     store.setTimeout(timeout);
111   }
112 
~StoreTimeoutGuard()113   ~StoreTimeoutGuard() {
114     store_.setTimeout(oldTimeout_);
115   }
116 
117   /* Disabling copy and move semantics */
118   StoreTimeoutGuard(const StoreTimeoutGuard&) = delete;
119   StoreTimeoutGuard& operator=(const StoreTimeoutGuard&) = delete;
120   StoreTimeoutGuard(StoreTimeoutGuard&&) = delete;
121   StoreTimeoutGuard& operator=(StoreTimeoutGuard&&) = delete;
122 
123  private:
124   Store& store_;
125   std::chrono::milliseconds oldTimeout_{};
126 };
127 
128 } // namespace c10d
129