xref: /aosp_15_r20/external/pytorch/aten/src/ATen/xpu/XPUGeneratorImpl.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/Utils.h>
2 #include <ATen/xpu/XPUGeneratorImpl.h>
3 #include <c10/core/StreamGuard.h>
4 #include <c10/util/CallOnce.h>
5 #include <c10/xpu/XPUFunctions.h>
6 
7 namespace at {
8 namespace xpu::detail {
9 namespace {
10 
11 /*
12  * Currently, there is one generator pool containing XPU generator per device.
13  * Each generator is lazily initialized the first time generator is
14  * requested for a device.
15  */
16 c10::once_flag init_flag;
17 DeviceIndex num_gpus = -1;
18 std::deque<c10::once_flag> xpu_gens_init_flag;
19 std::vector<Generator> default_gens_xpu;
20 
initXPUGenVector()21 void initXPUGenVector() {
22   num_gpus = device_count();
23   xpu_gens_init_flag.resize(num_gpus);
24   default_gens_xpu.resize(num_gpus);
25 }
26 
check_device(DeviceIndex device)27 inline void check_device(DeviceIndex device) {
28   TORCH_CHECK(
29       device >= 0 && device < num_gpus,
30       "device is out of range, device is ",
31       static_cast<int16_t>(device),
32       ", total number of device is ",
33       static_cast<int16_t>(num_gpus),
34       ".");
35 }
36 
37 } // anonymous namespace
38 
39 // Get the default generator with a random seed for a specific xpu device.
getDefaultXPUGenerator(DeviceIndex device)40 const Generator& getDefaultXPUGenerator(DeviceIndex device) {
41   c10::call_once(init_flag, initXPUGenVector);
42   if (device == -1) {
43     device = c10::xpu::current_device();
44   }
45   check_device(device);
46   c10::call_once(xpu_gens_init_flag[device], [&]() {
47     default_gens_xpu[device] = make_generator<XPUGeneratorImpl>(device);
48     default_gens_xpu[device].seed();
49   });
50   return default_gens_xpu[device];
51 }
52 
53 // Create a generator with a fixed seed for a specific xpu device.
createXPUGenerator(DeviceIndex device)54 Generator createXPUGenerator(DeviceIndex device) {
55   c10::call_once(init_flag, initXPUGenVector);
56   if (device == -1) {
57     device = c10::xpu::current_device();
58   }
59   check_device(device);
60   auto gen = make_generator<XPUGeneratorImpl>(device);
61   auto xpu_gen = check_generator<XPUGeneratorImpl>(gen);
62   xpu_gen->set_current_seed(default_rng_seed_val);
63   xpu_gen->set_philox_offset_per_thread(0);
64   return gen;
65 }
66 
67 } // namespace xpu::detail
68 
XPUGeneratorImpl(DeviceIndex device_index)69 XPUGeneratorImpl::XPUGeneratorImpl(DeviceIndex device_index)
70     : GeneratorImpl{
71           Device(DeviceType::XPU, device_index),
72           DispatchKeySet(c10::DispatchKey::XPU)} {}
73 
set_current_seed(uint64_t seed)74 void XPUGeneratorImpl::set_current_seed(uint64_t seed) {
75   seed_ = seed;
76   set_philox_offset_per_thread(0);
77 }
78 
set_offset(uint64_t offset)79 void XPUGeneratorImpl::set_offset(uint64_t offset) {
80   set_philox_offset_per_thread(offset);
81 }
82 
get_offset() const83 uint64_t XPUGeneratorImpl::get_offset() const {
84   return philox_offset_per_thread_;
85 }
86 
current_seed() const87 uint64_t XPUGeneratorImpl::current_seed() const {
88   return seed_;
89 }
90 
seed()91 uint64_t XPUGeneratorImpl::seed() {
92   auto random = c10::detail::getNonDeterministicRandom(true);
93   this->set_current_seed(random);
94   return random;
95 }
96 
get_state() const97 c10::intrusive_ptr<c10::TensorImpl> XPUGeneratorImpl::get_state() const {
98   // The RNG state comprises the seed, and an offset used for Philox.
99   static const size_t seed_size = sizeof(uint64_t);
100   static const size_t offset_size = sizeof(uint64_t);
101   static const size_t total_size = seed_size + offset_size;
102 
103   // The internal state is returned as a CPU byte tensor.
104   auto state_tensor = at::detail::empty_cpu(
105       {static_cast<int64_t>(total_size)},
106       ScalarType::Byte,
107       std::nullopt,
108       std::nullopt,
109       std::nullopt,
110       std::nullopt);
111   auto rng_state = state_tensor.data_ptr<uint8_t>();
112   auto current_seed = this->current_seed();
113   auto offset = this->philox_offset_per_thread();
114   memcpy(rng_state, &current_seed, seed_size);
115   memcpy(rng_state + seed_size, &offset, offset_size);
116 
117   return state_tensor.getIntrusivePtr();
118 }
119 
set_state(const c10::TensorImpl & new_state)120 void XPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
121   static const size_t seed_size = sizeof(uint64_t);
122   static const size_t offset_size = sizeof(uint64_t);
123   static const size_t total_size = seed_size + offset_size;
124 
125   at::detail::check_rng_state(new_state);
126   auto new_state_size = new_state.numel();
127   TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size");
128 
129   uint64_t input_seed;
130   auto new_rng_state = new_state.data_dtype_initialized<uint8_t>();
131   memcpy(&input_seed, new_rng_state, seed_size);
132   this->set_current_seed(input_seed);
133   uint64_t philox_offset;
134   memcpy(&philox_offset, new_rng_state + seed_size, offset_size);
135   this->set_philox_offset_per_thread(philox_offset);
136 }
137 
set_philox_offset_per_thread(uint64_t offset)138 void XPUGeneratorImpl::set_philox_offset_per_thread(uint64_t offset) {
139   TORCH_CHECK(offset % 4 == 0, "offset must be a multiple of 4");
140   philox_offset_per_thread_ = offset;
141 }
142 
philox_offset_per_thread() const143 uint64_t XPUGeneratorImpl::philox_offset_per_thread() const {
144   return philox_offset_per_thread_;
145 }
146 
philox_engine_inputs(uint64_t increment)147 std::pair<uint64_t, uint64_t> XPUGeneratorImpl::philox_engine_inputs(
148     uint64_t increment) {
149   increment = ((increment + 3) / 4) * 4;
150   TORCH_INTERNAL_ASSERT(this->philox_offset_per_thread_ % 4 == 0);
151   uint64_t offset = this->philox_offset_per_thread_;
152   this->philox_offset_per_thread_ += increment;
153   return std::make_pair(this->seed_, offset);
154 }
155 
device_type()156 DeviceType XPUGeneratorImpl::device_type() {
157   return DeviceType::XPU;
158 }
159 
clone() const160 std::shared_ptr<XPUGeneratorImpl> XPUGeneratorImpl::clone() const {
161   return std::shared_ptr<XPUGeneratorImpl>(this->clone_impl());
162 }
163 
clone_impl() const164 XPUGeneratorImpl* XPUGeneratorImpl::clone_impl() const {
165   auto gen = new XPUGeneratorImpl(this->device().index());
166   gen->set_current_seed(this->seed_);
167   gen->set_philox_offset_per_thread(this->philox_offset_per_thread_);
168   return gen;
169 }
170 
171 } // namespace at
172