xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/matmul_op_fused.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements matmul operations with other kernels baked into the
// processing, to optimize latency and memory usage:
//  - MatMul + BiasAdd + <Activation>
//  - MatMul + FusedBatchNorm + <Activation>
//
// Activation: Relu, Relu6, Elu, etc...
//
// Supported on CPU devices and, when built with CUDA, on GPU devices.
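//
// For example, with fused_ops = ["BiasAdd", "Relu"], a single _FusedMatMul
// invocation computes (sketch; bias has one element per output column and is
// broadcast across rows):
//
//   output = Relu(MatMul(a, b) + bias)
//
// instead of materializing the intermediate MatMul and BiasAdd results.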

#ifndef TENSORFLOW_CORE_KERNELS_MATMUL_OP_FUSED_H_
#define TENSORFLOW_CORE_KERNELS_MATMUL_OP_FUSED_H_

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/fused_eigen_output_kernels.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/matmul_autotune.h"
#include "tensorflow/core/util/tensor_format.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if GOOGLE_CUDA
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/kernels/matmul_op_impl.h"
#include "tensorflow/core/kernels/matmul_util.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
struct LaunchFusedMatMulOp {
  void operator()(
      OpKernelContext* context, const Tensor& a, const Tensor& b,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
      FusedComputationType fusion, const FusedComputationArgs& fusion_args,
      Tensor* output, bool use_autotune);
};

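// Only the declaration above is generic. The definitions live in the
// per-device specializations below: a CPU path based on Eigen tensor
// contractions with fused output kernels, and (when built with GOOGLE_CUDA)
// a GPU path based on cuBLASLt matmul epilogues.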
template <typename T>
struct LaunchFusedMatMulOp<CPUDevice, T> {
  void operator()(
      OpKernelContext* context, const Tensor& a, const Tensor& b,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
      FusedComputationType fusion, const FusedComputationArgs& fusion_args,
      Tensor* output, bool use_autotune) {
    OP_REQUIRES(context, DataTypeToEnum<T>::value != DT_HALF,
                errors::InvalidArgument("_FusedMatMul doesn't support DT_HALF "
                                        "data type on CPU devices."));
    auto lhs = a.matrix<T>();
    auto rhs = b.matrix<T>();
    auto out = output->matrix<T>();

    auto& d = context->eigen_device<CPUDevice>();

    // Executes the Eigen contraction with the output kernel wrapped into a
    // type-erased wrapper to reduce the number of unique template
    // instantiations.
    auto executeWithOutputKernel = [&](auto output_kernel) {
      OutputKernelWrapper output_kernel_wrapper(
          [&output_kernel](
              const ContractionOutputMapper<T, Eigen::Index>& output_mapper,
              const Eigen::TensorContractionParams& params, Eigen::Index i,
              Eigen::Index j, Eigen::Index num_rows, Eigen::Index num_cols) {
            output_kernel(output_mapper, params, i, j, num_rows, num_cols);
          });

      out.device(d) = lhs.contract(rhs, dim_pair, output_kernel_wrapper);
    };
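
    // Eigen calls the output kernel once per finished block of the
    // contraction result, so bias and activation are applied while the block
    // is still warm in cache instead of in a separate pass over the output.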

    BiasAddArgs<T> bias_add_args;
    if (BiasAddArgs<T>::IsSupported(fusion)) {
      if (fusion == FusedComputationType::kBiasAddWithLeakyRelu) {
        OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args,
                                                &fusion_args.leakyrelu_alpha));
      } else {
        OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args));
      }
    }
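    // At this point bias_add_args holds the bias data (and, for the LeakyRelu
    // fusion, the alpha coefficient) consumed by the fused output kernels
    // selected below.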

    switch (fusion) {
      case FusedComputationType::kBiasAdd:
        executeWithOutputKernel(WithBiasAdd<T>(bias_add_args));
        break;
      case FusedComputationType::kBiasAddWithRelu:
        executeWithOutputKernel(WithBiasAddAndRelu<T>(bias_add_args));
        break;
      case FusedComputationType::kBiasAddWithRelu6:
        executeWithOutputKernel(WithBiasAddAndRelu6<T>(bias_add_args));
        break;
      case FusedComputationType::kBiasAddWithElu:
        executeWithOutputKernel(WithBiasAddAndElu<T>(bias_add_args));
        break;
      case FusedComputationType::kBiasAddWithLeakyRelu:
        executeWithOutputKernel(WithBiasAddAndLeakyRelu<T>(bias_add_args));
        break;
      case FusedComputationType::kUndefined:
        OP_REQUIRES_OK(context, errors::Internal("Fusion type is undefined"));
        break;
      default:
        OP_REQUIRES_OK(context,
                       errors::Internal("Fusion type is not supported"));
    }
  }

 private:
  // Wrap output_kernel into a type-erased struct to reduce the number of
  // unique template instantiations for Eigen Tensor contraction expressions.
  //
  // We do not pass std::function directly as an output kernel because it blows
  // up the binary size in debug mode with super long symbol names.
  struct OutputKernelWrapper {
    using OutputKernelFn =
        std::function<void(const ContractionOutputMapper<T, Eigen::Index>&,
                           const Eigen::TensorContractionParams&, Eigen::Index,
                           Eigen::Index, Eigen::Index, Eigen::Index)>;

    explicit OutputKernelWrapper(OutputKernelFn fn)
        : output_kernel_fn(std::move(fn)) {}

    void operator()(
        const ContractionOutputMapper<T, Eigen::Index>& output_mapper,
        const Eigen::TensorContractionParams& params, Eigen::Index i,
        Eigen::Index j, Eigen::Index num_rows, Eigen::Index num_cols) const {
      output_kernel_fn(output_mapper, params, i, j, num_rows, num_cols);
    }

    OutputKernelFn output_kernel_fn;
  };
};

#if GOOGLE_CUDA
namespace {

StatusOr<se::cuda::BlasLt::Epilogue> GetBlasLtEpilogOp(
    FusedComputationType fusion) {
  if (fusion == FusedComputationType::kBiasAdd) {
    return se::cuda::BlasLt::Epilogue::kBias;
  } else if (fusion == FusedComputationType::kBiasAddWithRelu) {
    return se::cuda::BlasLt::Epilogue::kBiasThenReLU;
  } else if (fusion == FusedComputationType::kBiasAddWithGeluApproximate) {
    return se::cuda::BlasLt::Epilogue::kBiasThenGeLUApproximate;
  } else {
    return errors::Internal("Unsupported fusion for BlasLt Matmul");
  }
}
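
// The epilogue instructs cuBLASLt to apply the bias add (and, where selected,
// the activation) inside the same GEMM kernel, so the GPU path is a true
// fusion rather than a matmul followed by separate pointwise kernels.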

template <typename LaunchFunc>
se::blas::AlgorithmConfig AutotuneMatmul(
    const std::vector<se::cuda::BlasLt::MatmulAlgorithm>& algorithms,
    BlasLtMatmulPlanParams& matmul_params, OpKernelContext* context,
    const LaunchFunc& launch_func) {
  // Note that algorithm_config.algorithm() here is used to refer
  // to the index within the algorithms vector, not the algorithm
  // itself.
  se::blas::AlgorithmConfig algorithm_config(se::blas::kNoAlgorithm);
  if (!AutoTuneBatchMatmul::GetInstance()->Find(matmul_params,
                                                &algorithm_config)) {
    VLOG(4) << "Autotuning BlasLtMatmul over " << algorithms.size()
            << " algorithms.";
    se::blas::ProfileResult best_result;
    se::blas::ProfileResult profile_result;

    for (size_t i = 0; i != algorithms.size(); ++i) {
      const auto& profile_algorithm = algorithms[i];

      // Create a new scratch allocator with every autotuning run so that
      // scratch space is deallocated between runs.
      BlasScratchAllocator scratch_allocator(context);

      Status cublaslt_launch =
          launch_func(scratch_allocator, profile_algorithm, &profile_result);

      VLOG(4) << "  Autotune algorithm " << i
              << " result: " << profile_result.elapsed_time_in_ms()
              << " ms, valid=" << profile_result.is_valid()
              << ", workspace_size=" << profile_algorithm.workspace_size;

      if (cublaslt_launch.ok() && profile_result.is_valid() &&
          profile_result.elapsed_time_in_ms() <
              best_result.elapsed_time_in_ms()) {
        best_result = profile_result;
        // Use index into algorithms array, instead of cublas internal ID.
        best_result.set_algorithm(i);
      }
    }

    if (best_result.is_valid()) {
      algorithm_config.set_algorithm(best_result.algorithm());
    }
    // We make sure that each matmul parameter set only gets one pass of
    // autotune. If no algorithm works, we add kNoAlgorithm to the autotune
    // map.
    AutoTuneBatchMatmul::GetInstance()->Insert(matmul_params, algorithm_config);
  }
  return algorithm_config;
}
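
// Autotuning results are cached in AutoTuneBatchMatmul keyed by matmul_params,
// so a given combination of shapes, dtype and epilogue is profiled at most
// once per process; later calls reuse the stored AlgorithmConfig.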

}  // namespace

template <typename T>
struct LaunchFusedMatMulOp<GPUDevice, T> {
  void operator()(
      OpKernelContext* context, const Tensor& a, const Tensor& b,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
      FusedComputationType fusion, const FusedComputationArgs& fusion_args,
      Tensor* output, bool use_autotune) {
    OP_REQUIRES(
        context, DataTypeToEnum<T>::value != DT_BFLOAT16,
        errors::InvalidArgument("_FusedMatMul doesn't support "
                                "DT_BFLOAT16 data type on GPU devices."));
    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    // All fusion patterns supported on GPU are of the form MatMul + BiasAdd
    // + <other pointwise operations>. Therefore, the bias tensor is required.
    const Tensor& bias = context->input(2);

    OP_REQUIRES(context, bias.dims() == 1,
                errors::InvalidArgument("bias must be 1-dimensional: ",
                                        bias.shape().DebugString()));

    auto a_ptr = AsDeviceMemory(a.template flat<T>().data(),
                                a.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(b.template flat<T>().data(),
                                b.template flat<T>().size());
    auto bias_ptr = AsDeviceMemory(bias.template flat<T>().data(),
                                   bias.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
                                output->template flat<T>().size());

    auto epilog_op_or = GetBlasLtEpilogOp(fusion);
    OP_REQUIRES_OK(context, epilog_op_or.status());
    se::cuda::BlasLt::Epilogue epilog_op = epilog_op_or.ValueOrDie();

    bool trans_a = dim_pair[0].first == 0;
    bool trans_b = dim_pair[0].second == 1;

    const int64_t m = a.dim_size(trans_a ? 1 : 0);
    const int64_t k = a.dim_size(trans_a ? 0 : 1);
    const int64_t n = b.dim_size(trans_b ? 0 : 1);
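    // For example, with trans_a == false and trans_b == false, a is [m, k],
    // b is [k, n] and the output is [m, n]; with trans_a == true, a is stored
    // as [k, m], and with trans_b == true, b is stored as [n, k].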

    se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
                                   se::blas::Transpose::kTranspose};

    BlasLtMatmulPlanParams matmul_params{se::blas::ToDataType<T>::value,
                                         static_cast<size_t>(m),
                                         static_cast<size_t>(n),
                                         static_cast<size_t>(k),
                                         trans[trans_a ? 1 : 0],
                                         trans[trans_b ? 1 : 0],
                                         /*batch_size=*/1,
                                         /*broadcast_a=*/false,
                                         /*broadcast_b=*/false,
                                         epilog_op};

    auto plan_and_algorithms_or = GetPlanAndAlgorithms(stream, matmul_params);
    OP_REQUIRES_OK(context, plan_and_algorithms_or.status());
    const auto* plan_and_algorithms = std::move(plan_and_algorithms_or).value();
    const auto& plan = plan_and_algorithms->plan;
    const auto& algorithms = plan_and_algorithms->algorithms;
    OP_REQUIRES(context, algorithms.size() > 0,
                errors::InvalidArgument("No matmul algorithm returned!"));

    auto launch_func = [&](BlasScratchAllocator& scratch_allocator,
                           const se::cuda::BlasLt::MatmulAlgorithm& algorithm,
                           se::blas::ProfileResult* profile_result) {
      return DoBlasLtMatmul(stream, plan, a_ptr, b_ptr, c_ptr, algorithm,
                            scratch_allocator, bias_ptr, profile_result);
    };

    se::cuda::BlasLt::MatmulAlgorithm algorithm = algorithms[0];
    if (use_autotune) {
      se::blas::AlgorithmConfig algorithm_config =
          AutotuneMatmul(algorithms, matmul_params, context, launch_func);

      se::blas::AlgorithmType algorithm_idx = algorithm_config.algorithm();
      // If autotuning did not find a working algorithm, fall back to
      // algorithms[0] instead of indexing with kNoAlgorithm.
      if (algorithm_idx != se::blas::kNoAlgorithm) {
        algorithm = algorithms[algorithm_idx];
      }
    }

    BlasScratchAllocator scratch_allocator(context);
    OP_REQUIRES_OK(context, launch_func(scratch_allocator, algorithm, nullptr));
  }
};

#endif  // GOOGLE_CUDA

template <typename Device, typename T>
class FusedMatMulOp : public OpKernel {
 public:
  explicit FusedMatMulOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &transpose_b_));

    std::vector<FusedComputationPattern> patterns;

    using FCT = FusedComputationType;
    if (std::is_same<Device, CPUDevice>::value) {
      patterns = {
          {FCT::kBiasAdd, {"BiasAdd"}},
          {FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}},
          {FCT::kBiasAddWithRelu6, {"BiasAdd", "Relu6"}},
          {FCT::kBiasAddWithElu, {"BiasAdd", "Elu"}},
          {FCT::kBiasAddWithLeakyRelu, {"BiasAdd", "LeakyRelu"}},
      };
    } else if (std::is_same<Device, GPUDevice>::value) {
      patterns = {
          {FCT::kBiasAdd, {"BiasAdd"}},
          {FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}},
          {FCT::kBiasAddWithGeluApproximate, {"BiasAdd", "GeluApproximate"}}};
    }
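    // The pattern is selected from the op's fused_ops attribute, e.g.
    // fused_ops = ["BiasAdd", "Relu"] maps to FCT::kBiasAddWithRelu;
    // InitializeFusedComputation below returns an error for any combination
    // that does not match one of the patterns above.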

    OP_REQUIRES_OK(context, InitializeFusedComputation(
                                context, "MatMul", patterns,
                                &fused_computation_, &fused_computation_args_));
    use_autotune_ = MatmulAutotuneEnable();
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& a = ctx->input(0);
    const Tensor& b = ctx->input(1);

    // Check that the dimensions of the two matrices are valid.
    OP_REQUIRES(ctx, a.dims() == b.dims(),
                errors::InvalidArgument("In[0] and In[1] have different ndims: ",
                                        a.shape().DebugString(), " vs. ",
                                        b.shape().DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsMatrix(a.shape()),
        errors::InvalidArgument("In[0] is not a matrix. Instead it has shape ",
                                a.shape().DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsMatrix(b.shape()),
        errors::InvalidArgument("In[1] is not a matrix. Instead it has shape ",
                                b.shape().DebugString()));
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
    dim_pair[0].first = transpose_a_ ? 0 : 1;
    dim_pair[0].second = transpose_b_ ? 1 : 0;
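    // dim_pair[0] names the contracted axis of each operand: without
    // transposes we contract a's columns (axis 1) with b's rows (axis 0),
    // i.e. the standard [m, k] x [k, n] -> [m, n] product.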

    OP_REQUIRES(
        ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
        errors::InvalidArgument(
            "Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
            ", In[1]: ", b.shape().DebugString()));
    int a_dim_remaining = 1 - dim_pair[0].first;
    int b_dim_remaining = 1 - dim_pair[0].second;
    TensorShape out_shape(
        {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));

    if (out->NumElements() == 0) {
      // If a has shape [0, x] or b has shape [x, 0], the output shape
      // is a 0-element matrix, so there is nothing to do.
      return;
    }

    if (a.NumElements() == 0 && b.NumElements() == 0) {
      // If a has shape [x, 0] and b has shape [0, y], the
      // output shape is [x, y] where x and y are non-zero, so we fill
      // the output with zeros.
      functor::SetZeroFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), out->flat<T>());
      return;
    }

    auto launch = LaunchFusedMatMulOp<Device, T>();
    launch(ctx, a, b, dim_pair, fused_computation_, fused_computation_args_,
           out, use_autotune_);
  }

 private:
  bool transpose_a_;
  bool transpose_b_;
  bool use_autotune_;

  FusedComputationType fused_computation_ = FusedComputationType::kUndefined;
  FusedComputationArgs fused_computation_args_;

  TF_DISALLOW_COPY_AND_ASSIGN(FusedMatMulOp);
};

// Registration of the CPU implementations.
#define REGISTER_FUSED_CPU_MATMUL(T)                                  \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("_FusedMatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      FusedMatMulOp<CPUDevice, T>);

TF_CALL_float(REGISTER_FUSED_CPU_MATMUL);
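// Note that only float is registered for the CPU kernel: the CPU launcher
// above rejects DT_HALF at runtime, and no other CPU dtypes are wired up here.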

#undef REGISTER_FUSED_CPU_MATMUL

#if GOOGLE_CUDA

// Registration of the GPU implementations.
#define REGISTER_FUSED_GPU_MATMUL(T)                                  \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("_FusedMatMul").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      FusedMatMulOp<GPUDevice, T>);

TF_CALL_float(REGISTER_FUSED_GPU_MATMUL);
TF_CALL_half(REGISTER_FUSED_GPU_MATMUL);

#undef REGISTER_FUSED_GPU_MATMUL

#endif  // GOOGLE_CUDA

}  // namespace tensorflow
#endif  // TENSORFLOW_CORE_KERNELS_MATMUL_OP_FUSED_H_