/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/gpu_prim.h"
#include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h"
#include "tensorflow/core/kernels/reduction_ops_common.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

namespace {

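// strict_cast<U>(t): only the identity cast and the Eigen::half <-> float
// specializations below have definitions. The mismatched-type overload is
// declared but never defined, so any other conversion cannot be used and
// accidental lossy casts are caught at build time.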
template <typename U, typename T>
__device__ __host__ EIGEN_STRONG_INLINE
    typename std::enable_if<!std::is_same<T, U>::value, U>::type
    strict_cast(T t);

template <typename U, typename T>
__device__ __host__ EIGEN_STRONG_INLINE
    typename std::enable_if<std::is_same<T, U>::value, U>::type
    strict_cast(T t) {
  return t;
}

template <>
__device__ __host__ EIGEN_STRONG_INLINE float strict_cast<float, Eigen::half>(
    Eigen::half t) {
  return functor::HalfToFloat()(t);
}

template <>
__device__ __host__ EIGEN_STRONG_INLINE Eigen::half
strict_cast<Eigen::half, float>(float t) {
  return functor::FloatToHalf()(t);
}

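// Maps the input type to the type used for intermediate accumulation:
// Eigen::half inputs accumulate in float, every other type accumulates in
// itself.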
template <typename T>
struct softmax_traits {
  using accumulator_type = T;
};

template <>
struct softmax_traits<Eigen::half> {
  using accumulator_type = float;
};

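// Final normalization pass. Given the per-row maxima and the per-row sums of
// exp(logit - row_max), each thread handles kUnroll grid-strided elements and
// writes
//   softmax(x)_j     = exp(x_j - max_i x_i) / sum_k exp(x_k - max_i x_i)
//   log_softmax(x)_j = (x_j - max_i x_i) - log(sum_k exp(x_k - max_i x_i))
// depending on in_log_space. Values are widened to the accumulator type U for
// the arithmetic and narrowed back to T on output.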
template <typename T, typename U, int kUnroll>
__global__ void GenerateNormalizedProb(const T* logits, const U* sum_probs,
                                       const T* max_logits, T* output,
                                       const int num_rows, const int num_cols,
                                       const bool in_log_space) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int row, col;

  // TODO(jamesqin): change to half2 load when inputs are Eigen::half.
  U input[kUnroll];
  U max_val[kUnroll];
  U result[kUnroll];
  for (int i = 0; i < kUnroll; i++) {
    row = tid / num_cols;
    col = tid % num_cols;
    if (row < num_rows && col < num_cols) {
      input[i] = strict_cast<U>(logits[tid]);
      max_val[i] = strict_cast<U>(ldg(max_logits + row));
    }
    tid += gridDim.x * blockDim.x;
  }

  tid = blockIdx.x * blockDim.x + threadIdx.x;
  for (int i = 0; i < kUnroll; i++) {
    row = tid / num_cols;
    col = tid % num_cols;
    if (row < num_rows && col < num_cols) {
      if (in_log_space) {
        result[i] = input[i] - max_val[i] - log(ldg(sum_probs + row));
      } else {
        result[i] = exp(input[i] - max_val[i]) / ldg(sum_probs + row);
      }
      output[tid] = strict_cast<T>(result[i]);
    }
    tid += gridDim.x * blockDim.x;
  }
}

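// Vectorized specialization for Eigen::half with float accumulation: each
// thread handles 8 consecutive half values, loading and storing them as a
// single 16-byte ulonglong2 when the whole group is in bounds, and falling
// back to element-wise accesses for the tail.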
template <>
__global__ void GenerateNormalizedProb<Eigen::half, float, 8>(
    const Eigen::half* logits, const float* sum_probs,
    const Eigen::half* max_logits, Eigen::half* output, const int num_rows,
    const int num_cols, const bool in_log_space) {
  const int kUnroll = 8;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int idx[kUnroll];
  int row[kUnroll];

  float input[kUnroll];
  float max_val[kUnroll];
  float result[kUnroll];

  if (tid * kUnroll + kUnroll - 1 < num_rows * num_cols) {
    ulonglong2 logits_d =
        *reinterpret_cast<const ulonglong2*>(logits + tid * kUnroll);
    Eigen::half* logits_h = reinterpret_cast<Eigen::half*>(&logits_d);
    ulonglong2 output_d;
    Eigen::half* output_h = reinterpret_cast<Eigen::half*>(&output_d);

    for (int i = 0; i < kUnroll; i++) {
      idx[i] = tid * kUnroll + i;
      row[i] = idx[i] / num_cols;
      input[i] = strict_cast<float>(logits_h[i]);
      max_val[i] = strict_cast<float>(ldg(max_logits + row[i]));
      if (in_log_space) {
        result[i] = input[i] - max_val[i] - log(ldg(sum_probs + row[i]));
      } else {
        result[i] = exp(input[i] - max_val[i]) / ldg(sum_probs + row[i]);
      }
      output_h[i] = strict_cast<Eigen::half>(result[i]);
    }

    *reinterpret_cast<ulonglong2*>(output + tid * kUnroll) = output_d;
  } else {
    for (int i = 0; i < kUnroll; i++) {
      if (tid * kUnroll + i < num_rows * num_cols) {
        idx[i] = tid * kUnroll + i;
        row[i] = idx[i] / num_cols;
        input[i] = strict_cast<float>(logits[idx[i]]);
        max_val[i] = strict_cast<float>(ldg(max_logits + row[i]));
        if (in_log_space) {
          result[i] = input[i] - max_val[i] - log(ldg(sum_probs + row[i]));
        } else {
          result[i] = exp(input[i] - max_val[i]) / ldg(sum_probs + row[i]);
        }
        output[idx[i]] = strict_cast<Eigen::half>(result[i]);
      }
    }
  }
}

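// Transform functor that maps a flat element index gid to
// exp(logits[gid] - row_max) in the accumulator type U. Used with a
// TransformInputIterator so the row-sum reduction consumes the shifted
// exponentials on the fly instead of materializing them in a temporary.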
template <typename T, typename U>
struct SubtractAndExpFunctor {
  __host__ __device__ SubtractAndExpFunctor(const T* __restrict__ logits,
                                            const T* __restrict__ max_logits,
                                            const int num_cols)
      : logits_(logits), max_logits_(max_logits), num_cols_(num_cols) {}

  __host__ __device__ U operator()(const int gid) const {
    // TODO(jamesqin): change to half2 load when inputs are Eigen::half.
    const U diff =
        strict_cast<U>(logits_[gid] - ldg(max_logits_ + gid / num_cols_));
    return exp(diff);
  }

  const T* logits_;
  const T* max_logits_;
  const int num_cols_;
};

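// Reduces each row of a rows x cols matrix (read through InputIter) to a
// single value per row with the reduction operator Op, delegating to
// functor::ReduceImpl.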
template <typename T, typename Op, typename InputIter>
void DoRowReduction(OpKernelContext* context, T* output, InputIter input,
                    int rows, int cols) {
  typedef const Eigen::array<TTypes<float>::Tensor::Index, 1>& ReductionAxes;
  Constants<GPUDevice> constants;

  Op op;

  functor::ReduceImpl<T, Op, T*, InputIter, ReductionAxes>(
      context, output, input, 2, rows, cols, 1, 1, constants.kOne, op);
}
}  // namespace

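// GPU implementation shared by the Softmax and LogSoftmax ops. The
// computation runs in three passes over the [rows, cols] logits:
//   1. reduce each row to its maximum logit,
//   2. reduce each row to sum_k exp(logit_k - row_max),
//   3. normalize every element with GenerateNormalizedProb.
// Subtracting the row maximum keeps the exponentials in a safe range.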
template <typename T>
class SoftmaxOpGPU : public OpKernel {
 public:
  explicit SoftmaxOpGPU(OpKernelConstruction* context) : OpKernel(context) {
    log_ = absl::StartsWith(type_string(), "Log");
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& logits_in_ = context->input(0);
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(logits_in_.shape()),
                errors::InvalidArgument("logits must have >= 1 dimension, got ",
                                        logits_in_.shape().DebugString()));
    auto logits_in = logits_in_.flat_inner_dims<T>();
    const int rows = logits_in.dimension(0);
    const int cols = logits_in.dimension(1);
    Tensor* softmax_out = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, logits_in_.shape(), &softmax_out));

    const auto& cu_stream = GetGpuStream(context);
    if (logits_in_.NumElements() > 0) {
      Tensor max_logits;
      Tensor sum_probs;
      OP_REQUIRES_OK(context,
                     context->allocate_temp(DataTypeToEnum<T>::value,
                                            softmax_out->shape(), &max_logits));

      typedef typename softmax_traits<T>::accumulator_type acc_type;
      OP_REQUIRES_OK(context,
                     context->allocate_temp(DataTypeToEnum<acc_type>::value,
                                            softmax_out->shape(), &sum_probs));

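      // Pass 1: per-row maximum of the logits, stored in max_logits.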
      DoRowReduction<T, gpuprim::Max, const T*>(
          context, const_cast<T*>(max_logits.flat<T>().data()),
          reinterpret_cast<const T*>(logits_in_.flat<T>().data()), rows, cols);

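      // Pass 2: per-row sum of exp(logit - row_max). The transform iterator
      // feeds SubtractAndExpFunctor's results directly into the reduction.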
      gpuprim::CountingInputIterator<int> counting_iterator(0);
      using InputIterType =
          gpuprim::TransformInputIterator<acc_type,
                                          SubtractAndExpFunctor<T, acc_type>,
                                          gpuprim::CountingInputIterator<int>>;

      InputIterType input_itr(
          counting_iterator,
          SubtractAndExpFunctor<T, acc_type>(
              reinterpret_cast<const T*>(logits_in_.flat<T>().data()),
              reinterpret_cast<const T*>(max_logits.flat<T>().data()), cols));

      DoRowReduction<acc_type, gpuprim::Sum, InputIterType>(
          context, const_cast<acc_type*>(sum_probs.flat<acc_type>().data()),
          input_itr, rows, cols);

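      // Pass 3: normalization. Use the 8-way vectorized half kernel only when
      // both the input and output pointers are 16-byte aligned (8 halfs per
      // load/store); otherwise fall back to the generic 4-way unrolled kernel.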
      auto in_ptr = reinterpret_cast<uintptr_t>(logits_in_.flat<T>().data());
      auto out_ptr = reinterpret_cast<uintptr_t>(softmax_out->flat<T>().data());
      bool aligned = in_ptr % 16 == 0 && out_ptr % 16 == 0;

      const int numThreadsPerBlock = 128;
      if (DataTypeToEnum<T>::value == DT_HALF && aligned) {
        const int kUnroll = 8;
        const int numBlocks =
            Eigen::divup(rows * cols, numThreadsPerBlock * kUnroll);
        TF_CHECK_OK(GpuLaunchKernel(
            GenerateNormalizedProb<T, acc_type, kUnroll>, numBlocks,
            numThreadsPerBlock, 0, cu_stream,
            reinterpret_cast<const T*>(logits_in_.flat<T>().data()),
            reinterpret_cast<const acc_type*>(
                sum_probs.flat<acc_type>().data()),
            reinterpret_cast<const T*>(max_logits.flat<T>().data()),
            const_cast<T*>(softmax_out->flat<T>().data()), rows, cols, log_));
      } else {
        const int kUnroll = 4;
        const int numBlocks =
            Eigen::divup(rows * cols, numThreadsPerBlock * kUnroll);
        TF_CHECK_OK(GpuLaunchKernel(
            GenerateNormalizedProb<T, acc_type, kUnroll>, numBlocks,
            numThreadsPerBlock, 0, cu_stream,
            reinterpret_cast<const T*>(logits_in_.flat<T>().data()),
            reinterpret_cast<const acc_type*>(
                sum_probs.flat<acc_type>().data()),
            reinterpret_cast<const T*>(max_logits.flat<T>().data()),
            const_cast<T*>(softmax_out->flat<T>().data()), rows, cols, log_));
      }
    }
  }

 private:
  bool log_;
};

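// Register the kernel for both Softmax and LogSoftmax for every type covered
// by TF_CALL_GPU_NUMBER_TYPES; the constructor distinguishes the two ops by
// the "Log" prefix of the op name.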
#define REGISTER_GPU(T)                                          \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("Softmax").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      SoftmaxOpGPU<T>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);

#undef REGISTER_GPU
#define REGISTER_GPU(T)                                             \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("LogSoftmax").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      SoftmaxOpGPU<T>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);

#undef REGISTER_GPU

}  // end namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM