xref: /aosp_15_r20/external/pytorch/aten/src/ATen/CPUGeneratorImpl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/Generator.h>
4 #include <ATen/core/MT19937RNGEngine.h>
5 #include <c10/core/GeneratorImpl.h>
6 #include <optional>
7 
8 namespace at {
9 
10 struct TORCH_API CPUGeneratorImpl : public c10::GeneratorImpl {
11   // Constructors
12   CPUGeneratorImpl(uint64_t seed_in = default_rng_seed_val);
13   ~CPUGeneratorImpl() override = default;
14 
15   // CPUGeneratorImpl methods
16   std::shared_ptr<CPUGeneratorImpl> clone() const;
17   void set_current_seed(uint64_t seed) override;
18   void set_offset(uint64_t offset) override;
19   uint64_t get_offset() const override;
20   uint64_t current_seed() const override;
21   uint64_t seed() override;
22   void set_state(const c10::TensorImpl& new_state) override;
23   c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
24   static c10::DeviceType device_type();
25   uint32_t random();
26   uint64_t random64();
27   std::optional<float> next_float_normal_sample();
28   std::optional<double> next_double_normal_sample();
29   void set_next_float_normal_sample(std::optional<float> randn);
30   void set_next_double_normal_sample(std::optional<double> randn);
31   at::mt19937 engine();
32   void set_engine(at::mt19937 engine);
33 
34  private:
35   CPUGeneratorImpl* clone_impl() const override;
36   at::mt19937 engine_;
37   std::optional<float> next_float_normal_sample_;
38   std::optional<double> next_double_normal_sample_;
39 };
40 
41 namespace detail {
42 
43 TORCH_API const Generator& getDefaultCPUGenerator();
44 TORCH_API Generator
45 createCPUGenerator(uint64_t seed_val = default_rng_seed_val);
46 
47 } // namespace detail
48 
49 } // namespace at
50