/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements a quantized eight-bit version of the matmul operation.
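//
// A rough sketch of the quantization scheme this kernel relies on (see
// quantization_utils.h for the exact helpers): a quint8 value q stored with a
// float range [min, max] represents approximately
//   f = min + q * (max - min) / 255
// so the matmul can be carried out on the raw uint8 values once the zero
// points (the quantized representations of 0.0f) are subtracted, accumulating
// into a qint32 output.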

#define EIGEN_USE_THREADS

#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
#include "public/gemmlowp.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/kernels/reference_gemm.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

// We have to break this out as a separate function because there are multiple
// combinations of transpose attributes we need to support, and they have to be
// compile-time constants to work with the templates used internally.
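// For example, transpose_a=true with transpose_b=false and a non-transposed
// output maps to the GemmlowpMultiply<true, false, false> instantiation used
// in Compute() below.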
template <bool TransposeA, bool TransposeB, bool TransposeC>
void GemmlowpMultiply(OpKernelContext* op_context, const quint8* a_data,
                      const quint8* b_data, qint32* c_data, int m, int n, int k,
                      int offset_a, int offset_b, int lda, int ldb, int ldc) {
  const uint8* a_data_as_uint8 = &(a_data->value);
  const uint8* b_data_as_uint8 = &(b_data->value);
  int32* c_data_as_int32 = &(c_data->value);
  static const gemmlowp::MapOrder ResultOrder =
      !TransposeC ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor;
  static const gemmlowp::MapOrder LhsOrder =
      !TransposeA ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor;
  static const gemmlowp::MapOrder RhsOrder =
      !TransposeB ? gemmlowp::MapOrder::RowMajor : gemmlowp::MapOrder::ColMajor;
  gemmlowp::MatrixMap<const std::uint8_t, LhsOrder> lhs(a_data_as_uint8, m, k,
                                                        lda);
  gemmlowp::MatrixMap<const std::uint8_t, RhsOrder> rhs(b_data_as_uint8, k, n,
                                                        ldb);
  gemmlowp::MatrixMap<std::int32_t, ResultOrder> result(c_data_as_int32, m, n,
                                                        ldc);
  const std::tuple<> empty_pipeline = {};
  auto& worker_threads =
      *(op_context->device()->tensorflow_cpu_worker_threads());
  TensorflowGemmContext context(worker_threads.num_threads,
                                worker_threads.workers);
  gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t,
                                   gemmlowp::DefaultL8R8BitDepthParams>(
      &context, lhs, rhs, &result, -offset_a, -offset_b, empty_pipeline);
  // Since gemmlowp uses assembly to write to the output, msan won't detect
  // the output buffer as written to, so we mark it manually.
  TF_ANNOTATE_MEMORY_IS_INITIALIZED(c_data_as_int32, m * n * sizeof(int32));
}

template <class T1, class T2, class Toutput>
class QuantizedMatMulOp : public OpKernel {
 public:
  explicit QuantizedMatMulOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &transpose_b_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& a = context->input(0);
    const Tensor& b = context->input(1);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(context->input(2).shape()),
                errors::InvalidArgument("min_a must be a scalar, but got shape ",
                                        context->input(2).shape()));
    const float min_a = context->input(2).flat<float>()(0);
    OP_REQUIRES(context, context->input(3).NumElements() == 1,
                errors::InvalidArgument("max_a must be a scalar, but got shape ",
                                        context->input(3).shape()));
    const float max_a = context->input(3).flat<float>()(0);
    OP_REQUIRES(context, context->input(4).NumElements() == 1,
                errors::InvalidArgument("min_b must be a scalar, but got shape ",
                                        context->input(4).shape()));
    const float min_b = context->input(4).flat<float>()(0);
    OP_REQUIRES(context, context->input(5).NumElements() == 1,
                errors::InvalidArgument("max_b must be a scalar, but got shape ",
                                        context->input(5).shape()));
    const float max_b = context->input(5).flat<float>()(0);

    // Make sure that we have valid quantization ranges for the input buffers.
    // If the difference between the min and max is negative or zero, it makes
    // it hard to do meaningful intermediate operations on the values.
    OP_REQUIRES(context, (max_a > min_a),
                errors::InvalidArgument("max_a must be larger than min_a."));
    OP_REQUIRES(context, (max_b > min_b),
                errors::InvalidArgument("max_b must be larger than min_b."));
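    // offset_a and offset_b are the zero points: the quantized values that
    // represent 0.0f in each input's range. They are passed (negated) to the
    // gemm so that accumulation happens relative to a true zero. offset_c,
    // mult_c and shift_c describe the (here trivial) requantization of the
    // int32 output used by ReferenceGemm below.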
    const int32_t offset_a = FloatToQuantizedUnclamped<T1>(0.0f, min_a, max_a);
    const int32_t offset_b = FloatToQuantizedUnclamped<T2>(0.0f, min_b, max_b);
    const int32_t offset_c = 0;
    const int32_t mult_c = 1;
    const int32_t shift_c = 0;

    // Check that the dimensions of the two matrices are valid.
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(a.shape()),
                errors::InvalidArgument("In[0] is not a matrix"));
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(b.shape()),
                errors::InvalidArgument("In[1] is not a matrix"));
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
    dim_pair[0].first = transpose_a_ ? 0 : 1;
    dim_pair[0].second = transpose_b_ ? 1 : 0;

    OP_REQUIRES(context,
                a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
                errors::InvalidArgument("Matrix size-incompatible: In[0]: ",
                                        a.shape().DebugString(),
                                        ", In[1]: ", b.shape().DebugString()));

    OP_REQUIRES(context, ((shift_c >= 0) && (shift_c <= 31)),
                errors::InvalidArgument("shift_c must be between 0 and 31, "
                                        "inclusive."));

    int a_dim_remaining = 1 - dim_pair[0].first;
    int b_dim_remaining = 1 - dim_pair[0].second;
    TensorShape out_shape(
        {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
    Tensor* c = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &c));
    CHECK(c);

    const T1* a_data = a.flat<T1>().data();
    const T2* b_data = b.flat<T2>().data();
    Toutput* c_data = c->flat<Toutput>().data();

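    // Dimensions for the underlying gemm: the output is m x n, accumulated
    // over k. lda/ldb/ldc are the row strides of the row-major input and
    // output buffers.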
    const bool transpose_c = false;
    const size_t m = a.dim_size(a_dim_remaining);
    const size_t n = b.dim_size(b_dim_remaining);
    const size_t k = a.dim_size(dim_pair[0].first);
    const size_t lda = a.dim_size(1);
    const size_t ldb = b.dim_size(1);
    const size_t ldc = n;

    if (meta::IsSupportedAndEnabled() && std::is_same<T1, quint8>() &&
        std::is_same<T2, quint8>() && std::is_same<Toutput, qint32>() &&
        (offset_c == 0) && (mult_c == 1) && (shift_c == 0) &&
        (transpose_c == false) && (k <= 2048)) {
      // The gemmlowp/meta code path works on 32- and 64-bit Arm with NEON SIMD
      // and provides an optimized quantized 8-bit to 32-bit gemm.
      meta::QuantizedGemm(context, transpose_a_, transpose_b_, a_data, b_data,
                          c_data, m, n, k, -offset_a, -offset_b, lda, ldb, ldc);
    } else if (std::is_same<T1, quint8>() && std::is_same<T2, quint8>() &&
               std::is_same<Toutput, qint32>() && (offset_c == 0) &&
               (mult_c == 1) && (shift_c == 0) && (transpose_c == false)) {
      // The optimized gemmlowp library only works for this particular
      // combination of data types and parameters, which the conditions above
      // have verified; anything else falls through to the slower reference
      // implementation below.
      if (transpose_a_) {
        if (transpose_b_) {
          GemmlowpMultiply<true, true, false>(context, a_data, b_data, c_data,
                                              m, n, k, offset_a, offset_b, lda,
                                              ldb, ldc);
        } else {
          GemmlowpMultiply<true, false, false>(context, a_data, b_data, c_data,
                                               m, n, k, offset_a, offset_b, lda,
                                               ldb, ldc);
        }
      } else {
        if (transpose_b_) {
          GemmlowpMultiply<false, true, false>(context, a_data, b_data, c_data,
                                               m, n, k, offset_a, offset_b, lda,
                                               ldb, ldc);
        } else {
          GemmlowpMultiply<false, false, false>(context, a_data, b_data, c_data,
                                                m, n, k, offset_a, offset_b,
                                                lda, ldb, ldc);
        }
      }
    } else {
      ReferenceGemm<T1, T2, Toutput>(
          transpose_a_, transpose_b_, transpose_c, m, n, k, a_data, offset_a,
          lda, b_data, offset_b, ldb, c_data, shift_c, offset_c, mult_c, ldc);
    }

    float min_c_value;
    float max_c_value;
    QuantizationRangeForMultiplication<T1, T2, Toutput>(
        min_a, max_a, min_b, max_b, &min_c_value, &max_c_value);
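    // Roughly speaking, QuantizationRangeForMultiplication picks the output
    // range so that one qint32 step corresponds to the product of one
    // quantized step of a and one quantized step of b; min_c_value and
    // max_c_value are then the float values of the qint32 extremes (see
    // quantization_utils.h).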
    Tensor* c_min = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &c_min));
    c_min->flat<float>()(0) = min_c_value;

    Tensor* c_max = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &c_max));
    c_max->flat<float>()(0) = max_c_value;
  }

 private:
  bool transpose_a_;
  bool transpose_b_;
};

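// Registered for the common quint8 x quint8 -> qint32 case. The kernel reads
// a, b and their float ranges as inputs 0-5 and writes the qint32 product
// together with its float range as outputs 0-2 (see Compute() above).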
REGISTER_KERNEL_BUILDER(Name("QuantizedMatMul")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("T1")
                            .TypeConstraint<quint8>("T2")
                            .TypeConstraint<qint32>("Toutput"),
                        QuantizedMatMulOp<quint8, quint8, qint32>);

}  // namespace tensorflow