/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/gpu_prim.h"
#include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h"
#include "tensorflow/core/kernels/reduction_ops_common.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

namespace {

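// strict_cast converts between the storage type T and the accumulator type U.
// The mixed-type primary template is only declared; definitions exist for the
// identity case and for the Eigen::half <-> float specializations below.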
template <typename U, typename T>
__device__ __host__ EIGEN_STRONG_INLINE
    typename std::enable_if<!std::is_same<T, U>::value, U>::type
    strict_cast(T t);

template <typename U, typename T>
__device__ __host__ EIGEN_STRONG_INLINE
    typename std::enable_if<std::is_same<T, U>::value, U>::type
    strict_cast(T t) {
  return t;
}

template <>
__device__ __host__ EIGEN_STRONG_INLINE float strict_cast<float, Eigen::half>(
    Eigen::half t) {
  return functor::HalfToFloat()(t);
}

template <>
__device__ __host__ EIGEN_STRONG_INLINE Eigen::half
strict_cast<Eigen::half, float>(float t) {
  return functor::FloatToHalf()(t);
}

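// softmax_traits selects the accumulator type used for the intermediate
// reductions; Eigen::half inputs accumulate in float for better precision.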
template <typename T>
struct softmax_traits {
  using accumulator_type = T;
};

template <>
struct softmax_traits<Eigen::half> {
  using accumulator_type = float;
};

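// Writes softmax (or log softmax when in_log_space is true) for every element,
// given the precomputed per-row maxima and per-row sums of exp(logit - max).
// Each thread handles kUnroll elements spaced gridDim.x * blockDim.x apart.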
template <typename T, typename U, int kUnroll>
__global__ void GenerateNormalizedProb(const T* logits, const U* sum_probs,
                                       const T* max_logits, T* output,
                                       const int num_rows, const int num_cols,
                                       const bool in_log_space) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int row, col;

  // TODO(jamesqin): change to half2 load when inputs are Eigen::half.
  U input[kUnroll];
  U max_val[kUnroll];
  U result[kUnroll];
  for (int i = 0; i < kUnroll; i++) {
    row = tid / num_cols;
    col = tid % num_cols;
    if (row < num_rows && col < num_cols) {
      input[i] = strict_cast<U>(logits[tid]);
      max_val[i] = strict_cast<U>(ldg(max_logits + row));
    }
    tid += gridDim.x * blockDim.x;
  }

  tid = blockIdx.x * blockDim.x + threadIdx.x;
  for (int i = 0; i < kUnroll; i++) {
    row = tid / num_cols;
    col = tid % num_cols;
    if (row < num_rows && col < num_cols) {
      if (in_log_space) {
        result[i] = input[i] - max_val[i] - log(ldg(sum_probs + row));
      } else {
        result[i] = exp(input[i] - max_val[i]) / ldg(sum_probs + row);
      }
      output[tid] = strict_cast<T>(result[i]);
    }
    tid += gridDim.x * blockDim.x;
  }
}

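// Specialization for Eigen::half with kUnroll == 8: eight consecutive half
// values are read and written as a single 16-byte (ulonglong2) transaction
// when the whole vector lies in bounds; the tail falls back to scalar access.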
template <>
__global__ void GenerateNormalizedProb<Eigen::half, float, 8>(
    const Eigen::half* logits, const float* sum_probs,
    const Eigen::half* max_logits, Eigen::half* output, const int num_rows,
    const int num_cols, const bool in_log_space) {
  const int kUnroll = 8;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int idx[kUnroll];
  int row[kUnroll];

  float input[kUnroll];
  float max_val[kUnroll];
  float result[kUnroll];

  if (tid * kUnroll + kUnroll - 1 < num_rows * num_cols) {
    ulonglong2 logits_d =
        *reinterpret_cast<const ulonglong2*>(logits + tid * kUnroll);
    Eigen::half* logits_h = reinterpret_cast<Eigen::half*>(&logits_d);
    ulonglong2 output_d;
    Eigen::half* output_h = reinterpret_cast<Eigen::half*>(&output_d);

    for (int i = 0; i < kUnroll; i++) {
      idx[i] = tid * kUnroll + i;
      row[i] = idx[i] / num_cols;
      input[i] = strict_cast<float>(logits_h[i]);
      max_val[i] = strict_cast<float>(ldg(max_logits + row[i]));
      if (in_log_space) {
        result[i] = input[i] - max_val[i] - log(ldg(sum_probs + row[i]));
      } else {
        result[i] = exp(input[i] - max_val[i]) / ldg(sum_probs + row[i]);
      }
      output_h[i] = strict_cast<Eigen::half>(result[i]);
    }

    *reinterpret_cast<ulonglong2*>(output + tid * kUnroll) = output_d;
  } else {
    for (int i = 0; i < kUnroll; i++) {
      if (tid * kUnroll + i < num_rows * num_cols) {
        idx[i] = tid * kUnroll + i;
        row[i] = idx[i] / num_cols;
        input[i] = strict_cast<float>(logits[idx[i]]);
        max_val[i] = strict_cast<float>(ldg(max_logits + row[i]));
        if (in_log_space) {
          result[i] = input[i] - max_val[i] - log(ldg(sum_probs + row[i]));
        } else {
          result[i] = exp(input[i] - max_val[i]) / ldg(sum_probs + row[i]);
        }
        output[idx[i]] = strict_cast<Eigen::half>(result[i]);
      }
    }
  }
}

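// Maps a flat element index to exp(logit - row_max) in the accumulator type.
// Used below as the transform of a TransformInputIterator so the row-wise sum
// reduction consumes the shifted exponentials without materializing them.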
template <typename T, typename U>
struct SubtractAndExpFunctor {
  __host__ __device__ SubtractAndExpFunctor(const T* __restrict__ logits,
                                            const T* __restrict__ max_logits,
                                            const int num_cols)
      : logits_(logits), max_logits_(max_logits), num_cols_(num_cols) {}

  __host__ __device__ U operator()(const int gid) const {
    // TODO(jamesqin): change to half2 load when inputs are Eigen::half.
    const U diff =
        strict_cast<U>(logits_[gid] - ldg(max_logits_ + gid / num_cols_));
    return exp(diff);
  }

  const T* logits_;
  const T* max_logits_;
  const int num_cols_;
};

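// Reduces each row of a rows x cols input down to one value per row using the
// reduction operator Op (gpuprim::Max or gpuprim::Sum in this file).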
template <typename T, typename Op, typename InputIter>
void DoRowReduction(OpKernelContext* context, T* output, InputIter input,
                    int rows, int cols) {
  typedef const Eigen::array<TTypes<float>::Tensor::Index, 1>& ReductionAxes;
  Constants<GPUDevice> constants;

  Op op;

  functor::ReduceImpl<T, Op, T*, InputIter, ReductionAxes>(
      context, output, input, 2, rows, cols, 1, 1, constants.kOne, op);
}
}  // namespace

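// GPU implementation of both Softmax and LogSoftmax; the registered op name
// (checked for a "Log" prefix) selects which variant is computed. The result
// is numerically stabilized by subtracting the per-row maximum.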
template <typename T>
class SoftmaxOpGPU : public OpKernel {
 public:
  explicit SoftmaxOpGPU(OpKernelConstruction* context) : OpKernel(context) {
    log_ = absl::StartsWith(type_string(), "Log");
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& logits_in_ = context->input(0);
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(logits_in_.shape()),
                errors::InvalidArgument("logits must have >= 1 dimension, got ",
                                        logits_in_.shape().DebugString()));
    auto logits_in = logits_in_.flat_inner_dims<T>();
    const int rows = logits_in.dimension(0);
    const int cols = logits_in.dimension(1);
    Tensor* softmax_out = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, logits_in_.shape(), &softmax_out));

    const auto& cu_stream = GetGpuStream(context);
    if (logits_in_.NumElements() > 0) {
      Tensor max_logits;
      Tensor sum_probs;
      OP_REQUIRES_OK(context,
                     context->allocate_temp(DataTypeToEnum<T>::value,
                                            softmax_out->shape(), &max_logits));

      typedef typename softmax_traits<T>::accumulator_type acc_type;
      OP_REQUIRES_OK(context,
                     context->allocate_temp(DataTypeToEnum<acc_type>::value,
                                            softmax_out->shape(), &sum_probs));

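      // Step 1: reduce each row of the logits to its maximum.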
      DoRowReduction<T, gpuprim::Max, const T*>(
          context, const_cast<T*>(max_logits.flat<T>().data()),
          reinterpret_cast<const T*>(logits_in_.flat<T>().data()), rows, cols);

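      // Step 2: sum exp(logit - row_max) over each row, reading the shifted
      // exponentials through a transform iterator instead of a temporary.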
      gpuprim::CountingInputIterator<int> counting_iterator(0);
      using InputIterType =
          gpuprim::TransformInputIterator<acc_type,
                                          SubtractAndExpFunctor<T, acc_type>,
                                          gpuprim::CountingInputIterator<int>>;

      InputIterType input_itr(
          counting_iterator,
          SubtractAndExpFunctor<T, acc_type>(
              reinterpret_cast<const T*>(logits_in_.flat<T>().data()),
              reinterpret_cast<const T*>(max_logits.flat<T>().data()), cols));

      DoRowReduction<acc_type, gpuprim::Sum, InputIterType>(
          context, const_cast<acc_type*>(sum_probs.flat<acc_type>().data()),
          input_itr, rows, cols);

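      // Step 3: normalize. Use the vectorized Eigen::half specialization
      // (kUnroll == 8) when the type is half and both buffers are 16-byte
      // aligned; otherwise launch the generic kernel with kUnroll == 4.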
      auto in_ptr = reinterpret_cast<uintptr_t>(logits_in_.flat<T>().data());
      auto out_ptr = reinterpret_cast<uintptr_t>(softmax_out->flat<T>().data());
      bool aligned = in_ptr % 16 == 0 && out_ptr % 16 == 0;

      const int numThreadsPerBlock = 128;
      if (DataTypeToEnum<T>::value == DT_HALF && aligned) {
        const int kUnroll = 8;
        const int numBlocks =
            Eigen::divup(rows * cols, numThreadsPerBlock * kUnroll);
        TF_CHECK_OK(GpuLaunchKernel(
            GenerateNormalizedProb<T, acc_type, kUnroll>, numBlocks,
            numThreadsPerBlock, 0, cu_stream,
            reinterpret_cast<const T*>(logits_in_.flat<T>().data()),
            reinterpret_cast<const acc_type*>(
                sum_probs.flat<acc_type>().data()),
            reinterpret_cast<const T*>(max_logits.flat<T>().data()),
            const_cast<T*>(softmax_out->flat<T>().data()), rows, cols, log_));
      } else {
        const int kUnroll = 4;
        const int numBlocks =
            Eigen::divup(rows * cols, numThreadsPerBlock * kUnroll);
        TF_CHECK_OK(GpuLaunchKernel(
            GenerateNormalizedProb<T, acc_type, kUnroll>, numBlocks,
            numThreadsPerBlock, 0, cu_stream,
            reinterpret_cast<const T*>(logits_in_.flat<T>().data()),
            reinterpret_cast<const acc_type*>(
                sum_probs.flat<acc_type>().data()),
            reinterpret_cast<const T*>(max_logits.flat<T>().data()),
            const_cast<T*>(softmax_out->flat<T>().data()), rows, cols, log_));
      }
    }
  }

 private:
  bool log_;
};

#define REGISTER_GPU(T)                                          \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("Softmax").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      SoftmaxOpGPU<T>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);

#undef REGISTER_GPU
#define REGISTER_GPU(T)                                             \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("LogSoftmax").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      SoftmaxOpGPU<T>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);

#undef REGISTER_GPU

}  // end namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM