// xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/HashStore.hpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include <torch/csrc/distributed/c10d/Store.hpp>

9 namespace c10d {
10 
11 class TORCH_API HashStore : public Store {
12  public:
13   ~HashStore() override = default;
14 
15   void set(const std::string& key, const std::vector<uint8_t>& data) override;
16 
17   std::vector<uint8_t> compareSet(
18       const std::string& key,
19       const std::vector<uint8_t>& expectedValue,
20       const std::vector<uint8_t>& desiredValue) override;
21 
22   std::vector<uint8_t> get(const std::string& key) override;
23 
wait(const std::vector<std::string> & keys)24   void wait(const std::vector<std::string>& keys) override {
25     wait(keys, timeout_);
26   }
27 
28   void wait(
29       const std::vector<std::string>& keys,
30       const std::chrono::milliseconds& timeout) override;
31 
32   int64_t add(const std::string& key, int64_t value) override;
33 
34   int64_t getNumKeys() override;
35 
36   bool check(const std::vector<std::string>& keys) override;
37 
38   bool deleteKey(const std::string& key) override;
39 
40   void append(const std::string& key, const std::vector<uint8_t>& value)
41       override;
42 
43   std::vector<std::vector<uint8_t>> multiGet(
44       const std::vector<std::string>& keys) override;
45 
46   void multiSet(
47       const std::vector<std::string>& keys,
48       const std::vector<std::vector<uint8_t>>& values) override;
49 
50   // Returns true if this store support append, multiGet and multiSet
51   bool hasExtendedApi() const override;
52 
53  protected:
54   std::unordered_map<std::string, std::vector<uint8_t>> map_;
55   std::mutex m_;
56   std::condition_variable cv_;
57 };
58 
59 } // namespace c10d
60