1 #pragma once 2 3 #include <ATen/core/Generator.h> 4 #include <ATen/core/MT19937RNGEngine.h> 5 #include <c10/core/GeneratorImpl.h> 6 #include <optional> 7 8 namespace at { 9 10 struct TORCH_API CPUGeneratorImpl : public c10::GeneratorImpl { 11 // Constructors 12 CPUGeneratorImpl(uint64_t seed_in = default_rng_seed_val); 13 ~CPUGeneratorImpl() override = default; 14 15 // CPUGeneratorImpl methods 16 std::shared_ptr<CPUGeneratorImpl> clone() const; 17 void set_current_seed(uint64_t seed) override; 18 void set_offset(uint64_t offset) override; 19 uint64_t get_offset() const override; 20 uint64_t current_seed() const override; 21 uint64_t seed() override; 22 void set_state(const c10::TensorImpl& new_state) override; 23 c10::intrusive_ptr<c10::TensorImpl> get_state() const override; 24 static c10::DeviceType device_type(); 25 uint32_t random(); 26 uint64_t random64(); 27 std::optional<float> next_float_normal_sample(); 28 std::optional<double> next_double_normal_sample(); 29 void set_next_float_normal_sample(std::optional<float> randn); 30 void set_next_double_normal_sample(std::optional<double> randn); 31 at::mt19937 engine(); 32 void set_engine(at::mt19937 engine); 33 34 private: 35 CPUGeneratorImpl* clone_impl() const override; 36 at::mt19937 engine_; 37 std::optional<float> next_float_normal_sample_; 38 std::optional<double> next_double_normal_sample_; 39 }; 40 41 namespace detail { 42 43 TORCH_API const Generator& getDefaultCPUGenerator(); 44 TORCH_API Generator 45 createCPUGenerator(uint64_t seed_val = default_rng_seed_val); 46 47 } // namespace detail 48 49 } // namespace at 50