1 #include <mutex> 2 #include <ATen/core/GeneratorForPrivateuseone.h> 3 4 namespace at { 5 6 static std::mutex _generator_mutex_lock; 7 GetGeneratorPrivate()8std::optional<GeneratorFuncType>& GetGeneratorPrivate() { 9 static std::optional<GeneratorFuncType> generator_privateuse1 = std::nullopt; 10 return generator_privateuse1; 11 } 12 _GeneratorRegister(const GeneratorFuncType & func)13_GeneratorRegister::_GeneratorRegister(const GeneratorFuncType& func) { 14 std::lock_guard<std::mutex> lock(_generator_mutex_lock); 15 TORCH_CHECK( 16 !GetGeneratorPrivate().has_value(), 17 "Only can register a generator to the PrivateUse1 dispatch key once!"); 18 19 auto& m_generator = GetGeneratorPrivate(); 20 m_generator = func; 21 } 22 GetGeneratorForPrivateuse1(c10::DeviceIndex device_index)23at::Generator GetGeneratorForPrivateuse1(c10::DeviceIndex device_index) { 24 TORCH_CHECK( 25 GetGeneratorPrivate().has_value(), 26 "Please register a generator to the PrivateUse1 dispatch key, \ 27 using the REGISTER_GENERATOR_PRIVATEUSE1 macro."); 28 29 // NOLINTNEXTLINE(bugprone-unchecked-optional-access) 30 return GetGeneratorPrivate().value()(device_index); 31 } 32 33 } // namespace at 34