1 #pragma once 2 3 #include <condition_variable> 4 #include <mutex> 5 #include <unordered_map> 6 7 #include <torch/csrc/distributed/c10d/Store.hpp> 8 9 namespace c10d { 10 11 class TORCH_API HashStore : public Store { 12 public: 13 ~HashStore() override = default; 14 15 void set(const std::string& key, const std::vector<uint8_t>& data) override; 16 17 std::vector<uint8_t> compareSet( 18 const std::string& key, 19 const std::vector<uint8_t>& expectedValue, 20 const std::vector<uint8_t>& desiredValue) override; 21 22 std::vector<uint8_t> get(const std::string& key) override; 23 wait(const std::vector<std::string> & keys)24 void wait(const std::vector<std::string>& keys) override { 25 wait(keys, timeout_); 26 } 27 28 void wait( 29 const std::vector<std::string>& keys, 30 const std::chrono::milliseconds& timeout) override; 31 32 int64_t add(const std::string& key, int64_t value) override; 33 34 int64_t getNumKeys() override; 35 36 bool check(const std::vector<std::string>& keys) override; 37 38 bool deleteKey(const std::string& key) override; 39 40 void append(const std::string& key, const std::vector<uint8_t>& value) 41 override; 42 43 std::vector<std::vector<uint8_t>> multiGet( 44 const std::vector<std::string>& keys) override; 45 46 void multiSet( 47 const std::vector<std::string>& keys, 48 const std::vector<std::vector<uint8_t>>& values) override; 49 50 // Returns true if this store support append, multiGet and multiSet 51 bool hasExtendedApi() const override; 52 53 protected: 54 std::unordered_map<std::string, std::vector<uint8_t>> map_; 55 std::mutex m_; 56 std::condition_variable cv_; 57 }; 58 59 } // namespace c10d 60