xref: /aosp_15_r20/external/pytorch/c10/core/GeneratorImpl.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/core/GeneratorImpl.h>
2 #include <random>
3 
4 #if defined(__SGX_ENABLED__)
5 #include <sgx_trts.h>
6 #endif
7 
8 #ifndef _WIN32
9 #include <fcntl.h>
10 #include <unistd.h>
11 #else
12 #include <chrono>
13 #endif
14 
15 namespace c10 {
16 
17 /**
18  * GeneratorImpl class implementation
19  */
GeneratorImpl(Device device_in,DispatchKeySet key_set)20 GeneratorImpl::GeneratorImpl(Device device_in, DispatchKeySet key_set)
21     : device_{device_in}, key_set_(key_set) {}
22 
23 /**
24  * Clone this generator. Note that clone() is the only
25  * method for copying for Generators in ATen.
26  */
clone() const27 c10::intrusive_ptr<GeneratorImpl> GeneratorImpl::clone() const {
28   auto res = this->clone_impl();
29   c10::raw::intrusive_ptr::incref(res);
30   c10::raw::weak_intrusive_ptr::incref(res);
31   return c10::intrusive_ptr<GeneratorImpl>::reclaim(res);
32 }
33 
graphsafe_set_state(const c10::intrusive_ptr<c10::GeneratorImpl> & state)34 void GeneratorImpl::graphsafe_set_state(
35     const c10::intrusive_ptr<c10::GeneratorImpl>& state) {
36   TORCH_CHECK_NOT_IMPLEMENTED(
37       false, "graphsafe_set_state is not supported in this Generator");
38 }
39 
graphsafe_get_state() const40 c10::intrusive_ptr<c10::GeneratorImpl> GeneratorImpl::graphsafe_get_state()
41     const {
42   TORCH_CHECK_NOT_IMPLEMENTED(
43       false, "graphsafe_get_state is not supported in this Generator");
44 }
45 
46 /**
47  * Gets the device of a generator.
48  */
device() const49 Device GeneratorImpl::device() const {
50   return device_;
51 }
52 
53 namespace detail {
54 
55 /**
56  * Gets a random number for /dev/urandom
57  * Note this is a legacy method (from THRandom.cpp)
58  * FIXME: use std::random_device with entropy information
59  */
60 #if !defined(_WIN32)
readURandomLong()61 static uint64_t readURandomLong() {
62   int randDev = open("/dev/urandom", O_RDONLY);
63   TORCH_CHECK(randDev >= 0, "Unable to open /dev/urandom");
64   uint64_t randValue{};
65   ssize_t readBytes = read(randDev, &randValue, sizeof(randValue));
66   close(randDev);
67   TORCH_CHECK(
68       readBytes >= (ssize_t)sizeof(randValue),
69       "Unable to read from /dev/urandom");
70   return randValue;
71 }
72 #endif // _WIN32
73 
74 /**
75  * Gets a non deterministic random number number from either the
76  * /dev/urandom or the current time. For CUDA, gets random from
77  * std::random_device and adds a transformation on it. For Intel SGX
78  * platform use sgx_read_rand as reading from /dev/urandom is
79  * prohibited on that platform.
80  *
81  * FIXME: The behavior in this function is from legacy code
82  * (THRandom_seed/THCRandom_seed) and is probably not the right thing to do,
83  * even though our tests pass. Figure out if tests get perturbed
84  * - when the same algorithm is used for all backends. Note that the current
85  * behavior is different for CPU, CUDA and Windows CPU.
86  * - when using C++11 std objects, such as std::random_device
87  * - when constructing a 64 bit seed properly, rather than static casting
88  *   a 32 bit number to 64 bit.
89  */
getNonDeterministicRandom(bool is_cuda)90 uint64_t getNonDeterministicRandom(bool is_cuda) {
91   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
92   uint64_t s;
93   if (!is_cuda) {
94 #ifdef _WIN32
95     s = (uint64_t)std::chrono::high_resolution_clock::now()
96             .time_since_epoch()
97             .count();
98 #elif defined(__SGX_ENABLED__)
99     TORCH_CHECK(
100         sgx_read_rand(reinterpret_cast<uint8_t*>(&s), sizeof(s)) == SGX_SUCCESS,
101         "Could not generate random number with sgx_read_rand.");
102 #else
103     s = readURandomLong();
104 #endif
105   } else {
106     std::random_device rd;
107     // limit to 53 bits to ensure unique representation in double
108     s = ((((uint64_t)rd()) << 32) + rd()) & 0x1FFFFFFFFFFFFF;
109   }
110   return s;
111 }
112 
113 } // namespace detail
114 } // namespace c10
115