1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_KERNELS_RANDOM_OPS_UTIL_H_
17 #define TENSORFLOW_CORE_KERNELS_RANDOM_OPS_UTIL_H_
18
19 #include "tensorflow/core/lib/random/philox_random.h"
20 #include "tensorflow/core/platform/types.h"
21
22 namespace tensorflow {
23
24 using random::PhiloxRandom;
25
26 // The following 2 functions use the contract "lower 32 bits for the first
27 // uint32, higher 32 bits for the second". Note that this is endian-neutral,
28 // unlike a direct memory copy `memcpy(output, &input, 8)`.
Uint64ToUint32s(uint64 input,uint32 * output1,uint32 * output2)29 PHILOX_DEVICE_INLINE void Uint64ToUint32s(uint64 input, uint32* output1,
30 uint32* output2) {
31 *output1 = static_cast<uint32>(input);
32 *output2 = static_cast<uint32>(input >> 32);
33 }
34
Uint32sToUint64(uint32 input1,uint32 input2)35 PHILOX_DEVICE_INLINE uint64 Uint32sToUint64(uint32 input1, uint32 input2) {
36 auto u64_1 = static_cast<uint64>(input1);
37 auto u64_2 = static_cast<uint64>(input2);
38 return u64_1 | (u64_2 << 32);
39 }
40
GetCounterFromMem(uint64 const * ptr)41 PHILOX_DEVICE_INLINE PhiloxRandom::ResultType GetCounterFromMem(
42 uint64 const* ptr) {
43 PhiloxRandom::ResultType counter;
44 Uint64ToUint32s(ptr[0], &counter[0], &counter[1]);
45 Uint64ToUint32s(ptr[1], &counter[2], &counter[3]);
46 return counter;
47 }
48
WriteCounterToMem(PhiloxRandom::ResultType const & counter,uint64 * ptr)49 PHILOX_DEVICE_INLINE void WriteCounterToMem(
50 PhiloxRandom::ResultType const& counter, uint64* ptr) {
51 ptr[0] = Uint32sToUint64(counter[0], counter[1]);
52 ptr[1] = Uint32sToUint64(counter[2], counter[3]);
53 }
54
GetKeyFromMem(uint64 const * ptr)55 PHILOX_DEVICE_INLINE PhiloxRandom::Key GetKeyFromMem(uint64 const* ptr) {
56 PhiloxRandom::Key key;
57 Uint64ToUint32s(ptr[0], &key[0], &key[1]);
58 return key;
59 }
60
WriteKeyToMem(PhiloxRandom::Key const & key,uint64 * ptr)61 PHILOX_DEVICE_INLINE void WriteKeyToMem(PhiloxRandom::Key const& key,
62 uint64* ptr) {
63 *ptr = Uint32sToUint64(key[0], key[1]);
64 }
65
GetPhiloxRandomFromCounterKeyMem(uint64 const * counter_ptr,uint64 const * key_ptr)66 PHILOX_DEVICE_INLINE PhiloxRandom GetPhiloxRandomFromCounterKeyMem(
67 uint64 const* counter_ptr, uint64 const* key_ptr) {
68 return PhiloxRandom(GetCounterFromMem(counter_ptr), GetKeyFromMem(key_ptr));
69 }
70
71 } // end namespace tensorflow
72
73 #endif // TENSORFLOW_CORE_KERNELS_RANDOM_OPS_UTIL_H_
74