xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/GeneratorForPrivateuseone.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <mutex>
2 #include <ATen/core/GeneratorForPrivateuseone.h>
3 
4 namespace at {
5 
6 static std::mutex _generator_mutex_lock;
7 
GetGeneratorPrivate()8 std::optional<GeneratorFuncType>& GetGeneratorPrivate() {
9   static std::optional<GeneratorFuncType> generator_privateuse1 = std::nullopt;
10   return generator_privateuse1;
11 }
12 
_GeneratorRegister(const GeneratorFuncType & func)13 _GeneratorRegister::_GeneratorRegister(const GeneratorFuncType& func) {
14   std::lock_guard<std::mutex> lock(_generator_mutex_lock);
15   TORCH_CHECK(
16       !GetGeneratorPrivate().has_value(),
17       "Only can register a generator to the PrivateUse1 dispatch key once!");
18 
19   auto& m_generator = GetGeneratorPrivate();
20   m_generator = func;
21 }
22 
GetGeneratorForPrivateuse1(c10::DeviceIndex device_index)23 at::Generator GetGeneratorForPrivateuse1(c10::DeviceIndex device_index) {
24   TORCH_CHECK(
25       GetGeneratorPrivate().has_value(),
26       "Please register a generator to the PrivateUse1 dispatch key, \
27       using the REGISTER_GENERATOR_PRIVATEUSE1 macro.");
28 
29   // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
30   return GetGeneratorPrivate().value()(device_index);
31 }
32 
33 } // namespace at
34