xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/cpu_backend_gemm_x86.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_X86_H_
17 #define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_X86_H_
18 
// If TFLITE_WITH_RUY is set, Ruy is the only GEMM option. In this header
// we select either Ruy or an alternative based on the SIMD extensions
// available on the given x86 platform.
22 #ifndef TFLITE_WITH_RUY
23 
24 #include "tensorflow/lite/kernels/cpu_backend_context.h"
25 #include "tensorflow/lite/kernels/cpu_backend_gemm_eigen.h"
26 #include "tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h"
27 #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
28 #include "tensorflow/lite/kernels/cpu_backend_gemm_ruy.h"
29 #include "tensorflow/lite/kernels/internal/compatibility.h"
30 
31 namespace tflite {
32 namespace cpu_backend_gemm {
33 namespace detail {
34 
35 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
36           typename DstScalar, QuantizationFlavor quantization_flavor>
37 struct GemmImplX86 {
RunGemmImplX8638   static void Run(
39       const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
40       const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
41       const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
42       const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
43       CpuBackendContext* context) {
44     // TODO(b/168923364) Ruy is preferred on x86, but check if the deprecated
45     // path is enabled.
46     if (context->PreferGemmlowpOnX86()) {
47       // Dispatch to gemmlowp.
48       detail::GemmImplUsingGemmlowp<
49           LhsScalar, RhsScalar, AccumScalar, DstScalar,
50           quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data,
51                                     dst_params, dst_data, params, context);
52 
53       return;
54     }
55     // Run-time dispatch to Ruy for platforms with AVX or above.
56     detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar,
57                              quantization_flavor>::Run(lhs_params, lhs_data,
58                                                        rhs_params, rhs_data,
59                                                        dst_params, dst_data,
60                                                        params, context);
61   }
62 };
63 
64 // For float, defer to eigen for now.
65 template <>
66 struct GemmImplX86<float, float, float, float,
67                    QuantizationFlavor::kFloatingPoint> {
68   static void Run(const MatrixParams<float>& lhs_params, const float* lhs_data,
69                   const MatrixParams<float>& rhs_params, const float* rhs_data,
70                   const MatrixParams<float>& dst_params, float* dst_data,
71                   const GemmParams<float, float,
72                                    QuantizationFlavor::kFloatingPoint>& params,
73                   CpuBackendContext* context) {
74     GemmImplUsingEigen::Run(lhs_params, lhs_data, rhs_params, rhs_data,
75                             dst_params, dst_data, params, context);
76   }
77 };
78 
79 // gemmlowp requires NEON for certain quantization cases. See note in
80 // cpu_backend_gemm.h
81 #if !defined(GEMMLOWP_NEON)
82 template <typename SrcScalar, QuantizationFlavor quantization_flavor>
83 struct GemmImplX86<SrcScalar, SrcScalar, std::int32_t, std::int8_t,
84                    quantization_flavor>
85     : detail::GemmImplUsingRuy<SrcScalar, SrcScalar, std::int32_t, std::int8_t,
86                                quantization_flavor> {};
87 
88 template <typename DstScalar, QuantizationFlavor quantization_flavor>
89 struct GemmImplX86<std::int8_t, std::int8_t, std::int32_t, DstScalar,
90                    quantization_flavor>
91     : detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t,
92                                DstScalar, quantization_flavor> {};
93 
94 template <QuantizationFlavor quantization_flavor>
95 struct GemmImplX86<std::int8_t, std::int8_t, std::int32_t, std::int8_t,
96                    quantization_flavor>
97     : detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t,
98                                std::int8_t, quantization_flavor> {};
99 #endif  // not GEMMLOWP_NEON
100 }  // namespace detail
101 }  // namespace cpu_backend_gemm
102 }  // namespace tflite
103 
104 #endif  // not TFLITE_WITH_RUY
105 
106 #endif  // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_X86_H_
107