1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_
17 #define TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_
18
19 #include "tensorflow/core/kernels/random_ops_util.h"
20 #include "tensorflow/core/kernels/stateful_random_ops.h"
21
22 namespace tensorflow {
23
24 PHILOX_DEVICE_INLINE PhiloxRandom
GetPhiloxRandomFromMem(StateElementType const * ptr)25 GetPhiloxRandomFromMem(StateElementType const* ptr) {
26 auto ptr_ = reinterpret_cast<uint64 const*>(ptr);
27 return GetPhiloxRandomFromCounterKeyMem(ptr_, ptr_ + 2);
28 }
29
WritePhiloxRandomToMem(PhiloxRandom const & philox,StateElementType * ptr)30 PHILOX_DEVICE_INLINE void WritePhiloxRandomToMem(PhiloxRandom const& philox,
31 StateElementType* ptr) {
32 auto ptr_ = reinterpret_cast<uint64*>(ptr);
33 WriteCounterToMem(philox.counter(), ptr_);
34 WriteKeyToMem(philox.key(), ptr_ + 2);
35 }
36
SkipPhiloxRandom(PhiloxRandom const & philox,uint64 output_size)37 PHILOX_DEVICE_INLINE PhiloxRandom SkipPhiloxRandom(PhiloxRandom const& philox,
38 uint64 output_size) {
39 auto new_philox = philox;
40 // Multiplier 256 is the same as in FillPhiloxRandomTask; do not change it
41 // just here.
42 auto delta = output_size * 256;
43 new_philox.Skip(delta); // do the actual increasing
44 return new_philox;
45 }
46
UpdateMemWithPhiloxRandom(PhiloxRandom const & philox,uint64 output_size,StateElementType * ptr)47 PHILOX_DEVICE_INLINE void UpdateMemWithPhiloxRandom(PhiloxRandom const& philox,
48 uint64 output_size,
49 StateElementType* ptr) {
50 auto new_philox = SkipPhiloxRandom(philox, output_size);
51 WritePhiloxRandomToMem(new_philox, ptr);
52 }
53
UpdateCounterMemWithPhiloxRandom(PhiloxRandom::ResultType const & counter,uint64 output_size,StateElementType * ptr)54 PHILOX_DEVICE_INLINE void UpdateCounterMemWithPhiloxRandom(
55 PhiloxRandom::ResultType const& counter, uint64 output_size,
56 StateElementType* ptr) {
57 auto philox = PhiloxRandom(counter, PhiloxRandom::Key() /*dummy*/);
58 auto new_philox = SkipPhiloxRandom(philox, output_size);
59 WriteCounterToMem(new_philox.counter(), reinterpret_cast<uint64*>(ptr));
60 }
61
namespace functor {

// Per-device functor that does the actual work behind
// `UpdateVariableAndFill`. It is a struct rather than a function template
// because C++ forbids partial specialization of function templates, and the
// device-specific versions are partial specializations on `Device`.
template <class Device, class Distribution>
struct UpdateVariableAndFill_Philox;

// Per-device functor for skipping a Philox state ahead. Only declared here;
// device-specific specializations supply the definition.
template <class Device>
struct RngSkip_Philox;

}  // end namespace functor
75
76 using CPUDevice = Eigen::ThreadPoolDevice;
77
78 class ScopedUnlockUnrefVar;
79
80 struct UpdateVariableAndFill_Philox_Arg {
81 int64_t output_size;
82 int64_t alg_tag_skip;
83 ScopedUnlockUnrefVar* state_var_guard;
84 Tensor* state_tensor;
85 };
86
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

using GPUDevice = Eigen::GpuDevice;

namespace functor {

// GPU partial specialization of UpdateVariableAndFill_Philox. Only the
// declaration lives here; the definition is presumably in the GPU
// translation unit.
// NOTE: operator() must keep at most 6 arguments because of a gcc/clang ABI
// incompatibility bug, so extra inputs travel in
// UpdateVariableAndFill_Philox_Arg.
template <class Distribution>
struct UpdateVariableAndFill_Philox<GPUDevice, Distribution> {
  void operator()(OpKernelContext* ctx, GPUDevice const& device,
                  Distribution dist, UpdateVariableAndFill_Philox_Arg* arg,
                  typename Distribution::ResultElementType* output_data);
};

// GPU full specialization of RngSkip_Philox. Presumably reads a state from
// `in_data`, skips it ahead by `delta`, and writes it to `out_data`.
template <>
struct RngSkip_Philox<GPUDevice> {
  void operator()(GPUDevice const& device, StateElementType const* in_data,
                  uint64 delta, StateElementType* out_data);
};

}  // end namespace functor

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
111
112 } // end namespace tensorflow
113
114 #endif // TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_
115