#include <c10/core/GeneratorImpl.h>

#include <cerrno>
#include <random>

#if defined(__SGX_ENABLED__)
#include <sgx_trts.h>
#endif

#ifndef _WIN32
#include <fcntl.h>
#include <unistd.h>
#else
#include <chrono>
#endif
14
15 namespace c10 {
16
17 /**
18 * GeneratorImpl class implementation
19 */
GeneratorImpl(Device device_in,DispatchKeySet key_set)20 GeneratorImpl::GeneratorImpl(Device device_in, DispatchKeySet key_set)
21 : device_{device_in}, key_set_(key_set) {}
22
23 /**
 * Clone this generator. Note that clone() is the only
 * way to copy Generators in ATen.
26 */
// Returns a new intrusive_ptr owning a copy produced by the virtual
// clone_impl() hook.
c10::intrusive_ptr<GeneratorImpl> GeneratorImpl::clone() const {
  // The subclass-specific copy is produced by clone_impl(); `res` is the raw
  // pointer it hands back.
  auto res = this->clone_impl();
  // Bump both the strong and the weak reference counts by hand before
  // wrapping. NOTE(review): this presumes clone_impl() returns an object whose
  // counts have not yet been incremented, and that reclaim() adopts an
  // already-counted reference instead of adding its own — confirm against the
  // c10::intrusive_ptr documentation before reordering or removing these.
  c10::raw::intrusive_ptr::incref(res);
  c10::raw::weak_intrusive_ptr::incref(res);
  return c10::intrusive_ptr<GeneratorImpl>::reclaim(res);
}
33
graphsafe_set_state(const c10::intrusive_ptr<c10::GeneratorImpl> & state)34 void GeneratorImpl::graphsafe_set_state(
35 const c10::intrusive_ptr<c10::GeneratorImpl>& state) {
36 TORCH_CHECK_NOT_IMPLEMENTED(
37 false, "graphsafe_set_state is not supported in this Generator");
38 }
39
graphsafe_get_state() const40 c10::intrusive_ptr<c10::GeneratorImpl> GeneratorImpl::graphsafe_get_state()
41 const {
42 TORCH_CHECK_NOT_IMPLEMENTED(
43 false, "graphsafe_get_state is not supported in this Generator");
44 }
45
46 /**
47 * Gets the device of a generator.
48 */
device() const49 Device GeneratorImpl::device() const {
50 return device_;
51 }
52
53 namespace detail {
54
55 /**
 * Gets a random number from /dev/urandom
57 * Note this is a legacy method (from THRandom.cpp)
58 * FIXME: use std::random_device with entropy information
59 */
60 #if !defined(_WIN32)
readURandomLong()61 static uint64_t readURandomLong() {
62 int randDev = open("/dev/urandom", O_RDONLY);
63 TORCH_CHECK(randDev >= 0, "Unable to open /dev/urandom");
64 uint64_t randValue{};
65 ssize_t readBytes = read(randDev, &randValue, sizeof(randValue));
66 close(randDev);
67 TORCH_CHECK(
68 readBytes >= (ssize_t)sizeof(randValue),
69 "Unable to read from /dev/urandom");
70 return randValue;
71 }
72 #endif // _WIN32
73
74 /**
 * Gets a non-deterministic random number from either the
76 * /dev/urandom or the current time. For CUDA, gets random from
77 * std::random_device and adds a transformation on it. For Intel SGX
78 * platform use sgx_read_rand as reading from /dev/urandom is
79 * prohibited on that platform.
80 *
81 * FIXME: The behavior in this function is from legacy code
82 * (THRandom_seed/THCRandom_seed) and is probably not the right thing to do,
83 * even though our tests pass. Figure out if tests get perturbed
84 * - when the same algorithm is used for all backends. Note that the current
85 * behavior is different for CPU, CUDA and Windows CPU.
86 * - when using C++11 std objects, such as std::random_device
87 * - when constructing a 64 bit seed properly, rather than static casting
88 * a 32 bit number to 64 bit.
89 */
getNonDeterministicRandom(bool is_cuda)90 uint64_t getNonDeterministicRandom(bool is_cuda) {
91 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
92 uint64_t s;
93 if (!is_cuda) {
94 #ifdef _WIN32
95 s = (uint64_t)std::chrono::high_resolution_clock::now()
96 .time_since_epoch()
97 .count();
98 #elif defined(__SGX_ENABLED__)
99 TORCH_CHECK(
100 sgx_read_rand(reinterpret_cast<uint8_t*>(&s), sizeof(s)) == SGX_SUCCESS,
101 "Could not generate random number with sgx_read_rand.");
102 #else
103 s = readURandomLong();
104 #endif
105 } else {
106 std::random_device rd;
107 // limit to 53 bits to ensure unique representation in double
108 s = ((((uint64_t)rd()) << 32) + rd()) & 0x1FFFFFFFFFFFFF;
109 }
110 return s;
111 }
112
113 } // namespace detail
114 } // namespace c10
115