xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/GeneratorForPrivateuseone.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <ATen/core/Generator.h>
#include <c10/util/intrusive_ptr.h>

#include <functional>
#include <optional>
5 
6 namespace at {
7 
8 using GeneratorFuncType = std::function<at::Generator(c10::DeviceIndex)>;
9 
10 std::optional<GeneratorFuncType>& GetGeneratorPrivate();
11 
12 class TORCH_API _GeneratorRegister {
13  public:
14   explicit _GeneratorRegister(const GeneratorFuncType& func);
15 };
16 
17 TORCH_API at::Generator GetGeneratorForPrivateuse1(
18     c10::DeviceIndex device_index);
19 
/**
 * Registers a Generator factory with PyTorch for the `privateuse1` dispatch
 * key.
 *
 * Usage: REGISTER_GENERATOR_PRIVATEUSE1(MakeGeneratorForPrivateuse1)
 *
 * The macro argument must be a plain (unqualified) identifier, because it is
 * token-pasted into the name of a static registration variable.
 *
 * class CustomGeneratorImpl : public c10::GeneratorImpl {
 *  public:
 *   explicit CustomGeneratorImpl(c10::DeviceIndex device_index = -1);
 *   ~CustomGeneratorImpl() override = default;
 *   ...
 * };
 *
 * at::Generator MakeGeneratorForPrivateuse1(c10::DeviceIndex id) {
 *   return at::make_generator<CustomGeneratorImpl>(id);
 * }
 */
35 
// Defines a static `at::_GeneratorRegister` named `temp<GeneratorPrivate>`.
// Initializing it hands `GeneratorPrivate` to the _GeneratorRegister
// constructor, which registers it as the privateuse1 generator factory.
// `GeneratorPrivate` must be an unqualified identifier, since it is
// token-pasted into the variable name. The expansion already ends with `;`.
#define REGISTER_GENERATOR_PRIVATEUSE1(GeneratorPrivate) \
  static at::_GeneratorRegister temp##GeneratorPrivate{GeneratorPrivate};
38 
39 } // namespace at
40