1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_X86_H_ 17 #define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_X86_H_ 18 19 // If TFLITE_WITH_RUY is set, Ruy is the only GEMM option. In this header 20 // we select either Ruy or an alternative based on the SIMD extentions 21 // available on the given x86 platform. 22 #ifndef TFLITE_WITH_RUY 23 24 #include "tensorflow/lite/kernels/cpu_backend_context.h" 25 #include "tensorflow/lite/kernels/cpu_backend_gemm_eigen.h" 26 #include "tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h" 27 #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" 28 #include "tensorflow/lite/kernels/cpu_backend_gemm_ruy.h" 29 #include "tensorflow/lite/kernels/internal/compatibility.h" 30 31 namespace tflite { 32 namespace cpu_backend_gemm { 33 namespace detail { 34 35 template <typename LhsScalar, typename RhsScalar, typename AccumScalar, 36 typename DstScalar, QuantizationFlavor quantization_flavor> 37 struct GemmImplX86 { RunGemmImplX8638 static void Run( 39 const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data, 40 const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data, 41 const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data, 42 const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params, 43 CpuBackendContext* context) { 44 // TODO(b/168923364) Ruy is preferred on x86, but check if the deprecated 45 // path is enabled. 46 if (context->PreferGemmlowpOnX86()) { 47 // Dispatch to gemmlowp. 48 detail::GemmImplUsingGemmlowp< 49 LhsScalar, RhsScalar, AccumScalar, DstScalar, 50 quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data, 51 dst_params, dst_data, params, context); 52 53 return; 54 } 55 // Run-time dispatch to Ruy for platforms with AVX or above. 56 detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar, 57 quantization_flavor>::Run(lhs_params, lhs_data, 58 rhs_params, rhs_data, 59 dst_params, dst_data, 60 params, context); 61 } 62 }; 63 64 // For float, defer to eigen for now. 65 template <> 66 struct GemmImplX86<float, float, float, float, 67 QuantizationFlavor::kFloatingPoint> { 68 static void Run(const MatrixParams<float>& lhs_params, const float* lhs_data, 69 const MatrixParams<float>& rhs_params, const float* rhs_data, 70 const MatrixParams<float>& dst_params, float* dst_data, 71 const GemmParams<float, float, 72 QuantizationFlavor::kFloatingPoint>& params, 73 CpuBackendContext* context) { 74 GemmImplUsingEigen::Run(lhs_params, lhs_data, rhs_params, rhs_data, 75 dst_params, dst_data, params, context); 76 } 77 }; 78 79 // gemmlowp requires NEON for certain quantization cases. See note in 80 // cpu_backend_gemm.h 81 #if !defined(GEMMLOWP_NEON) 82 template <typename SrcScalar, QuantizationFlavor quantization_flavor> 83 struct GemmImplX86<SrcScalar, SrcScalar, std::int32_t, std::int8_t, 84 quantization_flavor> 85 : detail::GemmImplUsingRuy<SrcScalar, SrcScalar, std::int32_t, std::int8_t, 86 quantization_flavor> {}; 87 88 template <typename DstScalar, QuantizationFlavor quantization_flavor> 89 struct GemmImplX86<std::int8_t, std::int8_t, std::int32_t, DstScalar, 90 quantization_flavor> 91 : detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t, 92 DstScalar, quantization_flavor> {}; 93 94 template <QuantizationFlavor quantization_flavor> 95 struct GemmImplX86<std::int8_t, std::int8_t, std::int32_t, std::int8_t, 96 quantization_flavor> 97 : detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t, 98 std::int8_t, quantization_flavor> {}; 99 #endif // not GEMMLOWP_NEON 100 } // namespace detail 101 } // namespace cpu_backend_gemm 102 } // namespace tflite 103 104 #endif // not TFLITE_WITH_RUY 105 106 #endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_X86_H_ 107