/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Output kernels for fusing computation into Eigen Tensor contractions:
//   (1) FusedConv2DOp
//   (2) FusedMatMulOp
//
// Supported fused computations:
//   (1) {Conv2D/MatMul} + BiasAdd + <Activation>
//   (2) {Conv2D/MatMul} + FusedBatchNorm + <Activation>
//
// Activation: Relu, Relu6, Elu, etc...

#ifndef TENSORFLOW_CORE_KERNELS_FUSED_EIGEN_OUTPUT_KERNELS_H_
#define TENSORFLOW_CORE_KERNELS_FUSED_EIGEN_OUTPUT_KERNELS_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"

namespace tensorflow {

enum class FusedComputationType {
  kUndefined,
  kBiasAdd,
  kBiasAddWithRelu,
  kBiasAddWithRelu6,
  kBiasAddWithElu,
  kBiasAddWithLeakyRelu,
  kBiasAddWithGeluApproximate,
  kFusedBatchNorm,
  kFusedBatchNormWithRelu,
  kFusedBatchNormWithRelu6,
  kFusedBatchNormWithElu,
  kFusedBatchNormWithLeakyRelu
};

// We have to pass around additional arguments for all possible fusion types.
struct FusedComputationArgs {
  float epsilon = 0.0;          // Used by `FusedBatchNorm` fusion only
  float leakyrelu_alpha = 0.0;  // Used by `LeakyRelu` fusion only
};

struct FusedComputationPattern {
  FusedComputationType fused_computation;
  std::vector<string> fused_ops;
};

// Parses attributes from the kernel construction context, and verifies that
// they specify a valid fused computation pattern.
Status InitializeFusedComputation(
    OpKernelConstruction* context, const string& kernel_name,
    const std::vector<FusedComputationPattern>& patterns,
    FusedComputationType* fused_computation,
    FusedComputationArgs* fused_computation_args);
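//
// Illustrative usage sketch (the kernel name and pattern list here are
// hypothetical; callers register whatever patterns their op supports):
//
//   std::vector<FusedComputationPattern> patterns = {
//       {FusedComputationType::kBiasAdd, {"BiasAdd"}},
//       {FusedComputationType::kBiasAddWithRelu, {"BiasAdd", "Relu"}},
//   };
//   FusedComputationType fused_computation = FusedComputationType::kUndefined;
//   FusedComputationArgs fused_computation_args;
//   OP_REQUIRES_OK(context,
//                  InitializeFusedComputation(context, "MyFusedKernel",
//                                             patterns, &fused_computation,
//                                             &fused_computation_args));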

// Type alias for the tensor contraction output mapper.
template <typename Scalar, typename StorageIndex>
using ContractionOutputMapper =
    Eigen::internal::blas_data_mapper<Scalar, StorageIndex, Eigen::ColMajor>;

// Returns input expression without any transformations.
struct Identity {
  template <typename XprType>
  static auto apply(XprType expr) -> XprType {
    return expr;
  };
};

// Applies `Relu` to the passed input expression.
struct Relu {
  template <typename XprType>
  static auto apply(XprType expr)
      -> decltype(expr.cwiseMax(std::declval<typename XprType::Scalar>())) {
    return expr.cwiseMax(static_cast<typename XprType::Scalar>(0));
  };
};

// Applies `Relu6` to the passed input expression.
struct Relu6 {
  template <typename XprType>
  static auto apply(XprType expr)
      -> decltype(expr.cwiseMax(std::declval<typename XprType::Scalar>())
                      .cwiseMin(std::declval<typename XprType::Scalar>())) {
    return expr.cwiseMax(static_cast<typename XprType::Scalar>(0))
        .cwiseMin(static_cast<typename XprType::Scalar>(6));
  };
};

// Applies `Elu` to the passed input expression.
struct Elu {
  template <typename XprType>
  static auto apply(XprType expr) -> decltype(
      (expr < std::declval<typename XprType::Scalar>())
          .select(expr.exp() -
                      expr.constant(std::declval<typename XprType::Scalar>()),
                  expr)) {
    return (expr < static_cast<typename XprType::Scalar>(0))
        .select(expr.exp() -
                    expr.constant(static_cast<typename XprType::Scalar>(1)),
                expr);
  };
};

// Applies `LeakyRelu` to the passed input expression.
struct LeakyRelu {
  template <typename XprType>
  static auto apply(XprType expr, const float leakyrelu_alpha) -> decltype(
      (expr < std::declval<typename XprType::Scalar>())
          .select(expr *
                      expr.constant(std::declval<typename XprType::Scalar>()),
                  expr)) {
    return (expr < static_cast<typename XprType::Scalar>(0))
        .select(expr * expr.constant(static_cast<typename XprType::Scalar>(
                           leakyrelu_alpha)),
                expr);
  };
};

template <typename T>
struct BiasAddArgs {
  const T* bias_add_data = nullptr;
  float leakyrelu_alpha;

  static bool IsSupported(FusedComputationType fusion) {
    return fusion == FusedComputationType::kBiasAdd ||
           fusion == FusedComputationType::kBiasAddWithRelu ||
           fusion == FusedComputationType::kBiasAddWithRelu6 ||
           fusion == FusedComputationType::kBiasAddWithElu ||
           fusion == FusedComputationType::kBiasAddWithLeakyRelu;
  }
};

template <typename T>
struct FusedBatchNormArgs {
  const T* scale_data = nullptr;
  const T* offset_data = nullptr;
  const T* estimated_mean_data = nullptr;
  const T* estimated_variance_data = nullptr;

  // Precomputed expression:
  //   scaling_factor = (estimated_variance + epsilon).rsqrt() * scale
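  //
  // Derivation (the standard inference-mode batch normalization formula,
  // restated here for clarity):
  //   y = (x - estimated_mean) / sqrt(estimated_variance + epsilon) * scale + offset
  //     = (x - estimated_mean) * scaling_factor + offset
  // so once `scaling_factor` is precomputed, the output kernels below only
  // need a per-element subtract, multiply, and add.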
  Eigen::Tensor<T, 1, Eigen::RowMajor> scaling_factor;

  float leakyrelu_alpha;

  static bool IsSupported(FusedComputationType fusion) {
    return fusion == FusedComputationType::kFusedBatchNorm ||
           fusion == FusedComputationType::kFusedBatchNormWithRelu ||
           fusion == FusedComputationType::kFusedBatchNormWithRelu6 ||
           fusion == FusedComputationType::kFusedBatchNormWithElu ||
           fusion == FusedComputationType::kFusedBatchNormWithLeakyRelu;
  }
};

// TensorContraction swaps lhs with rhs and changes the layout from RowMajor
// (the default in TensorFlow) to ColMajor (preferred in Eigen), then computes
// the matmul using these tensors.
//
// (1) Spatial Convolution (see eigen_spatial_convolutions.h):
//
//   TensorContraction output matrix (before reshape) has a ColMajor layout, and
//   has dimensions:
//   - rows: output_channels
//   - cols: all other dimensions
//
//   First element in every column is:
//     [batch ??, height ??, width ??, out_channel = i]
//
//   We do not know the values of 'batch', 'height', and 'width' here (if the
//   original dimensions are known, they can be computed from 'j').
//
//   Each column of an output block is a contiguous slice along the output
//   channel dimension, so we can use it to efficiently compute any
//   transformation that depends only on a channel value (e.g. add a channel
//   bias).
//
// (2) Matrix Multiplication (see matmul_op.cc):
//
//   For an `MxK * KxN` matrix multiplication, the output matrix has dimensions
//   `MxN`. Each column in an output block is a slice of the innermost
//   dimension of the output matrix starting at offset 'i'.
//
//   Example: In a TensorFlow MatMul [8x32] * [32x64], each output block column
//   will correspond to a MatMul output row of size 64 (because TensorFlow uses
//   row-major storage order).
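//
// Illustrative example (hypothetical sizes): for a Conv2D with 64 output
// channels, an output block with i = 16 and num_rows = 32 covers channels
// [16, 48) in each of its num_cols columns, so a per-channel transformation
// only needs the length-num_rows slice of the per-channel data starting at
// offset 'i'; this is why the kernels below read `bias_data + i`,
// `offset_data + i`, and so on.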

// Output kernel that fuses BiasAdd operation into the output of tensor
// contraction + activation function defined by Activation.
template <typename T, typename Activation = Identity>
struct BiasAddOutputKernel {
  explicit BiasAddOutputKernel(const BiasAddArgs<T>& args)
      : bias_data(args.bias_add_data) {}

  template <typename StorageIndex, typename Scalar>
  EIGEN_ALWAYS_INLINE void operator()(
      const ContractionOutputMapper<Scalar, StorageIndex>& output_mapper,
      const Eigen::TensorContractionParams& params, StorageIndex i,
      StorageIndex j, StorageIndex num_rows, StorageIndex num_cols) const {
    DCHECK(params.swapped_arguments);

    const T* bias_base = bias_data + i;
    typename TTypes<T>::UnalignedConstTensor bias(bias_base, num_rows);

    for (int col = 0; col < num_cols; ++col) {
      T* output_base = &output_mapper(0, col);
      typename TTypes<T>::UnalignedTensor output(output_base, num_rows);
      const auto expr = output + bias;
      output = Activation::template apply<decltype(expr)>(expr);
    }
  }

 private:
  const T* bias_data;
};

// Partial specialization for the LeakyRelu activation, which takes an extra
// `leakyrelu_alpha` argument that the generic `Activation::apply` call above
// does not pass.
template <typename T>
struct BiasAddOutputKernel<T, LeakyRelu> {
  explicit BiasAddOutputKernel(const BiasAddArgs<T>& args)
      : bias_data(args.bias_add_data), leakyrelu_alpha(args.leakyrelu_alpha) {}

  template <typename StorageIndex, typename Scalar>
  EIGEN_ALWAYS_INLINE void operator()(
      const ContractionOutputMapper<Scalar, StorageIndex>& output_mapper,
      const Eigen::TensorContractionParams& params, StorageIndex i,
      StorageIndex j, StorageIndex num_rows, StorageIndex num_cols) const {
    DCHECK(params.swapped_arguments);

    const T* bias_base = bias_data + i;
    typename TTypes<T>::UnalignedConstTensor bias(bias_base, num_rows);

    for (int col = 0; col < num_cols; ++col) {
      T* output_base = &output_mapper(0, col);
      typename TTypes<T>::UnalignedTensor output(output_base, num_rows);
      const auto expr = output + bias;
      output = LeakyRelu::template apply<decltype(expr)>(expr, leakyrelu_alpha);
    }
  }

 private:
  const T* bias_data;
  float leakyrelu_alpha;
};

// Output kernel that fuses FusedBatchNorm operation into the output of tensor
// contraction + activation function defined by Activation.
template <typename T, typename Activation = Identity>
struct FusedBatchNormOutputKernel {
  FusedBatchNormOutputKernel(T epsilon, const FusedBatchNormArgs<T>& args)
      : epsilon(epsilon),
        scaling_factor_data(args.scaling_factor.data()),
        offset_data(args.offset_data),
        estimated_mean_data(args.estimated_mean_data) {}

  template <typename StorageIndex, typename Scalar>
  EIGEN_ALWAYS_INLINE void operator()(
      const ContractionOutputMapper<Scalar, StorageIndex>& output_mapper,
      const Eigen::TensorContractionParams& params, StorageIndex i,
      StorageIndex j, StorageIndex num_rows, StorageIndex num_cols) const {
    DCHECK(params.swapped_arguments);

    const T* scaling_factor_base = scaling_factor_data + i;
    const T* offset_base = offset_data + i;
    const T* mean_base = estimated_mean_data + i;

    typename TTypes<T>::UnalignedConstTensor scaling_factor(scaling_factor_base,
                                                            num_rows);
    typename TTypes<T>::UnalignedConstTensor offset(offset_base, num_rows);
    typename TTypes<T>::UnalignedConstTensor mean(mean_base, num_rows);

    for (int col = 0; col < num_cols; ++col) {
      T* output_base = &output_mapper(0, col);
      typename TTypes<T>::UnalignedTensor output(output_base, num_rows);

      auto scaled = (output - mean) * scaling_factor;
      auto shifted = scaled + offset;

      output = Activation::template apply<decltype(shifted)>(shifted);
    }
  }

 private:
  T epsilon;
  const T* scaling_factor_data;
  const T* offset_data;
  const T* estimated_mean_data;
};

// Partial specialization for the LeakyRelu activation (see the BiasAdd
// specialization above for the rationale).
template <typename T>
struct FusedBatchNormOutputKernel<T, LeakyRelu> {
  FusedBatchNormOutputKernel(T epsilon, const FusedBatchNormArgs<T>& args)
      : epsilon(epsilon),
        scaling_factor_data(args.scaling_factor.data()),
        offset_data(args.offset_data),
        estimated_mean_data(args.estimated_mean_data),
        leakyrelu_alpha(args.leakyrelu_alpha) {}

  template <typename StorageIndex, typename Scalar>
  EIGEN_ALWAYS_INLINE void operator()(
      const ContractionOutputMapper<Scalar, StorageIndex>& output_mapper,
      const Eigen::TensorContractionParams& params, StorageIndex i,
      StorageIndex j, StorageIndex num_rows, StorageIndex num_cols) const {
    DCHECK(params.swapped_arguments);

    const T* scaling_factor_base = scaling_factor_data + i;
    const T* offset_base = offset_data + i;
    const T* mean_base = estimated_mean_data + i;

    typename TTypes<T>::UnalignedConstTensor scaling_factor(scaling_factor_base,
                                                            num_rows);
    typename TTypes<T>::UnalignedConstTensor offset(offset_base, num_rows);
    typename TTypes<T>::UnalignedConstTensor mean(mean_base, num_rows);

    for (int col = 0; col < num_cols; ++col) {
      T* output_base = &output_mapper(0, col);
      typename TTypes<T>::UnalignedTensor output(output_base, num_rows);

      auto scaled = (output - mean) * scaling_factor;
      auto shifted = scaled + offset;

      output = LeakyRelu::template apply<decltype(shifted)>(shifted,
                                                            leakyrelu_alpha);
    }
  }

 private:
  T epsilon;
  const T* scaling_factor_data;
  const T* offset_data;
  const T* estimated_mean_data;
  float leakyrelu_alpha;
};

// Type aliases for the output kernels, purely to make the launch dispatching
// code more readable.
template <typename T>
using WithBiasAdd = BiasAddOutputKernel<T>;
template <typename T>
using WithBiasAddAndRelu = BiasAddOutputKernel<T, Relu>;
template <typename T>
using WithBiasAddAndRelu6 = BiasAddOutputKernel<T, Relu6>;
template <typename T>
using WithBiasAddAndElu = BiasAddOutputKernel<T, Elu>;
template <typename T>
using WithBiasAddAndLeakyRelu = BiasAddOutputKernel<T, LeakyRelu>;
template <typename T>
using WithFusedBatchNorm = FusedBatchNormOutputKernel<T>;
template <typename T>
using WithFusedBatchNormAndRelu = FusedBatchNormOutputKernel<T, Relu>;
template <typename T>
using WithFusedBatchNormAndRelu6 = FusedBatchNormOutputKernel<T, Relu6>;
template <typename T>
using WithFusedBatchNormAndElu = FusedBatchNormOutputKernel<T, Elu>;
template <typename T>
using WithFusedBatchNormAndLeakyRelu = FusedBatchNormOutputKernel<T, LeakyRelu>;
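//
// Illustrative usage sketch (the surrounding launch code is not shown here and
// varies by kernel; this only shows how the aliases combine with the
// Init*Args helpers below):
//
//   BiasAddArgs<float> bias_add_args;
//   OP_REQUIRES_OK(context, InitBiasAddArgs<float>(context, &bias_add_args));
//   WithBiasAddAndRelu<float> output_kernel(bias_add_args);
//   // `output_kernel` is then passed to the Eigen tensor contraction (or to
//   // the convolution/matmul launcher), which invokes it on every output
//   // block.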

template <typename T>
Status InitBiasAddArgs(OpKernelContext* context, BiasAddArgs<T>* args,
                       const float* leakyrelu_alpha = nullptr) {
  // Bias of the following dimensions: [ output_depth ]
  const Tensor& bias = context->input(2);

  if (bias.dims() != 1)
    return errors::InvalidArgument("bias must be 1-dimensional",
                                   bias.shape().DebugString());

  const auto data_ptr = [](const Tensor& tensor) -> const T* {
    return reinterpret_cast<const T*>(tensor.tensor_data().data());
  };

  args->bias_add_data = data_ptr(bias);

  if (leakyrelu_alpha) {
    args->leakyrelu_alpha = *leakyrelu_alpha;
  }

  return OkStatus();
}

template <typename T>
Status InitFusedBatchNormArgs(OpKernelContext* context, float epsilon,
                              FusedBatchNormArgs<T>* args,
                              const float* leakyrelu_alpha = nullptr) {
  const Tensor& scale = context->input(2);
  const Tensor& offset = context->input(3);
  const Tensor& estimated_mean = context->input(4);
  const Tensor& estimated_variance = context->input(5);

  if (scale.dims() != 1)
    return errors::InvalidArgument("scale must be 1-dimensional",
                                   scale.shape().DebugString());
  if (offset.dims() != 1)
    return errors::InvalidArgument("offset must be 1-dimensional",
                                   offset.shape().DebugString());
  if (estimated_mean.dims() != 1)
    return errors::InvalidArgument("estimated_mean must be 1-dimensional",
                                   estimated_mean.shape().DebugString());
  if (estimated_variance.dims() != 1)
    return errors::InvalidArgument("estimated_variance must be 1-dimensional",
                                   estimated_variance.shape().DebugString());

  const auto data_ptr = [](const Tensor& tensor) -> const T* {
    return reinterpret_cast<const T*>(tensor.tensor_data().data());
  };

  args->scale_data = data_ptr(scale);
  args->offset_data = data_ptr(offset);
  args->estimated_mean_data = data_ptr(estimated_mean);
  args->estimated_variance_data = data_ptr(estimated_variance);

  // Precompute scaling factor once for all output blocks (kernels).
  args->scaling_factor =
      (estimated_variance.flat<T>() + static_cast<T>(epsilon)).rsqrt() *
      scale.flat<T>();

  if (leakyrelu_alpha) {
    args->leakyrelu_alpha = *leakyrelu_alpha;
  }

  return OkStatus();
}

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_FUSED_EIGEN_OUTPUT_KERNELS_H_