1 #include <ATen/Utils.h>
2 #include <ATen/xpu/XPUGeneratorImpl.h>
3 #include <c10/core/StreamGuard.h>
4 #include <c10/util/CallOnce.h>
5 #include <c10/xpu/XPUFunctions.h>
6
7 namespace at {
8 namespace xpu::detail {
9 namespace {
10
11 /*
12 * Currently, there is one generator pool containing XPU generator per device.
13 * Each generator is lazily initialized the first time generator is
14 * requested for a device.
15 */
16 c10::once_flag init_flag;
17 DeviceIndex num_gpus = -1;
18 std::deque<c10::once_flag> xpu_gens_init_flag;
19 std::vector<Generator> default_gens_xpu;
20
initXPUGenVector()21 void initXPUGenVector() {
22 num_gpus = device_count();
23 xpu_gens_init_flag.resize(num_gpus);
24 default_gens_xpu.resize(num_gpus);
25 }
26
check_device(DeviceIndex device)27 inline void check_device(DeviceIndex device) {
28 TORCH_CHECK(
29 device >= 0 && device < num_gpus,
30 "device is out of range, device is ",
31 static_cast<int16_t>(device),
32 ", total number of device is ",
33 static_cast<int16_t>(num_gpus),
34 ".");
35 }
36
37 } // anonymous namespace
38
39 // Get the default generator with a random seed for a specific xpu device.
getDefaultXPUGenerator(DeviceIndex device)40 const Generator& getDefaultXPUGenerator(DeviceIndex device) {
41 c10::call_once(init_flag, initXPUGenVector);
42 if (device == -1) {
43 device = c10::xpu::current_device();
44 }
45 check_device(device);
46 c10::call_once(xpu_gens_init_flag[device], [&]() {
47 default_gens_xpu[device] = make_generator<XPUGeneratorImpl>(device);
48 default_gens_xpu[device].seed();
49 });
50 return default_gens_xpu[device];
51 }
52
53 // Create a generator with a fixed seed for a specific xpu device.
createXPUGenerator(DeviceIndex device)54 Generator createXPUGenerator(DeviceIndex device) {
55 c10::call_once(init_flag, initXPUGenVector);
56 if (device == -1) {
57 device = c10::xpu::current_device();
58 }
59 check_device(device);
60 auto gen = make_generator<XPUGeneratorImpl>(device);
61 auto xpu_gen = check_generator<XPUGeneratorImpl>(gen);
62 xpu_gen->set_current_seed(default_rng_seed_val);
63 xpu_gen->set_philox_offset_per_thread(0);
64 return gen;
65 }
66
67 } // namespace xpu::detail
68
XPUGeneratorImpl(DeviceIndex device_index)69 XPUGeneratorImpl::XPUGeneratorImpl(DeviceIndex device_index)
70 : GeneratorImpl{
71 Device(DeviceType::XPU, device_index),
72 DispatchKeySet(c10::DispatchKey::XPU)} {}
73
set_current_seed(uint64_t seed)74 void XPUGeneratorImpl::set_current_seed(uint64_t seed) {
75 seed_ = seed;
76 set_philox_offset_per_thread(0);
77 }
78
set_offset(uint64_t offset)79 void XPUGeneratorImpl::set_offset(uint64_t offset) {
80 set_philox_offset_per_thread(offset);
81 }
82
get_offset() const83 uint64_t XPUGeneratorImpl::get_offset() const {
84 return philox_offset_per_thread_;
85 }
86
current_seed() const87 uint64_t XPUGeneratorImpl::current_seed() const {
88 return seed_;
89 }
90
seed()91 uint64_t XPUGeneratorImpl::seed() {
92 auto random = c10::detail::getNonDeterministicRandom(true);
93 this->set_current_seed(random);
94 return random;
95 }
96
get_state() const97 c10::intrusive_ptr<c10::TensorImpl> XPUGeneratorImpl::get_state() const {
98 // The RNG state comprises the seed, and an offset used for Philox.
99 static const size_t seed_size = sizeof(uint64_t);
100 static const size_t offset_size = sizeof(uint64_t);
101 static const size_t total_size = seed_size + offset_size;
102
103 // The internal state is returned as a CPU byte tensor.
104 auto state_tensor = at::detail::empty_cpu(
105 {static_cast<int64_t>(total_size)},
106 ScalarType::Byte,
107 std::nullopt,
108 std::nullopt,
109 std::nullopt,
110 std::nullopt);
111 auto rng_state = state_tensor.data_ptr<uint8_t>();
112 auto current_seed = this->current_seed();
113 auto offset = this->philox_offset_per_thread();
114 memcpy(rng_state, ¤t_seed, seed_size);
115 memcpy(rng_state + seed_size, &offset, offset_size);
116
117 return state_tensor.getIntrusivePtr();
118 }
119
set_state(const c10::TensorImpl & new_state)120 void XPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
121 static const size_t seed_size = sizeof(uint64_t);
122 static const size_t offset_size = sizeof(uint64_t);
123 static const size_t total_size = seed_size + offset_size;
124
125 at::detail::check_rng_state(new_state);
126 auto new_state_size = new_state.numel();
127 TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size");
128
129 uint64_t input_seed;
130 auto new_rng_state = new_state.data_dtype_initialized<uint8_t>();
131 memcpy(&input_seed, new_rng_state, seed_size);
132 this->set_current_seed(input_seed);
133 uint64_t philox_offset;
134 memcpy(&philox_offset, new_rng_state + seed_size, offset_size);
135 this->set_philox_offset_per_thread(philox_offset);
136 }
137
set_philox_offset_per_thread(uint64_t offset)138 void XPUGeneratorImpl::set_philox_offset_per_thread(uint64_t offset) {
139 TORCH_CHECK(offset % 4 == 0, "offset must be a multiple of 4");
140 philox_offset_per_thread_ = offset;
141 }
142
philox_offset_per_thread() const143 uint64_t XPUGeneratorImpl::philox_offset_per_thread() const {
144 return philox_offset_per_thread_;
145 }
146
philox_engine_inputs(uint64_t increment)147 std::pair<uint64_t, uint64_t> XPUGeneratorImpl::philox_engine_inputs(
148 uint64_t increment) {
149 increment = ((increment + 3) / 4) * 4;
150 TORCH_INTERNAL_ASSERT(this->philox_offset_per_thread_ % 4 == 0);
151 uint64_t offset = this->philox_offset_per_thread_;
152 this->philox_offset_per_thread_ += increment;
153 return std::make_pair(this->seed_, offset);
154 }
155
device_type()156 DeviceType XPUGeneratorImpl::device_type() {
157 return DeviceType::XPU;
158 }
159
clone() const160 std::shared_ptr<XPUGeneratorImpl> XPUGeneratorImpl::clone() const {
161 return std::shared_ptr<XPUGeneratorImpl>(this->clone_impl());
162 }
163
clone_impl() const164 XPUGeneratorImpl* XPUGeneratorImpl::clone_impl() const {
165 auto gen = new XPUGeneratorImpl(this->device().index());
166 gen->set_current_seed(this->seed_);
167 gen->set_philox_offset_per_thread(this->philox_offset_per_thread_);
168 return gen;
169 }
170
171 } // namespace at
172