1 #pragma once 2 3 #include <ATen/core/Generator.h> 4 #include <c10/util/intrusive_ptr.h> 5 6 namespace at { 7 8 using GeneratorFuncType = std::function<at::Generator(c10::DeviceIndex)>; 9 10 std::optional<GeneratorFuncType>& GetGeneratorPrivate(); 11 12 class TORCH_API _GeneratorRegister { 13 public: 14 explicit _GeneratorRegister(const GeneratorFuncType& func); 15 }; 16 17 TORCH_API at::Generator GetGeneratorForPrivateuse1( 18 c10::DeviceIndex device_index); 19 20 /** 21 * This is used to register Generator to PyTorch for `privateuse1` key. 22 * 23 * Usage: REGISTER_GENERATOR_PRIVATEUSE1(MakeGeneratorForPrivateuse1) 24 * 25 * class CustomGeneratorImpl : public c10::GeneratorImpl { 26 * CustomGeneratorImpl(DeviceIndex device_index = -1); 27 * explicit ~CustomGeneratorImpl() override = default; 28 * ... 29 * }; 30 * 31 * at::Generator MakeGeneratorForPrivateuse1(c10::DeviceIndex id) { 32 * return at::make_generator<CustomGeneratorImpl>(id); 33 * } 34 */ 35 36 #define REGISTER_GENERATOR_PRIVATEUSE1(GeneratorPrivate) \ 37 static auto temp##GeneratorPrivate = at::_GeneratorRegister(GeneratorPrivate); 38 39 } // namespace at 40