xref: /aosp_15_r20/external/ruy/ruy/mul_params.h (revision bb86c7ed5fb1b98a7eac808e443a46cc8b90dfc0)
1 /* Copyright 2019 Google LLC. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef RUY_RUY_MUL_PARAMS_H_
17 #define RUY_RUY_MUL_PARAMS_H_
18 
19 #include <cstdint>
20 #include <limits>
21 #include <type_traits>
22 
23 #include "ruy/check_macros.h"
24 #include "ruy/size_util.h"
25 
26 namespace ruy {
27 
28 // Enumeration to designate which dimension is the 'channels', for MulParams
29 // features that are 'per-channel', namely the bias-vector and the quantized
30 // multiplier.
// Narrow (int8_t) underlying type: this enum is stored inside
// detail::MulParamsStorage, so a small representation keeps that struct
// compact.
enum class ChannelDimension : std::int8_t {
  // kRow means that 'per-channel' means 'per row of the destination matrix'
  kRow,
  // kCol means that 'per-channel' means 'per column of the destination matrix'
  kCol
};
37 
namespace detail {
// Forward declaration only: MulParamsStorage holds the data members of
// MulParams. The primary template (floating-point case) and its integer
// specializations are defined at the bottom of this file.
template <typename tAccumScalar, typename tDstScalar>
struct MulParamsStorage;
}  // namespace detail
42 
43 // MulParams describes all about a matrix multiplication that
44 // isn't encoded in the LHS, RHS and destination matrices. Some of that
45 // information is encoded as compile-time constants and types (for instance, the
46 // choice of accumulator type, AccumScalar). Some of that information is encoded
47 // as runtime values (for instance, the optional bias vector).
48 //
49 // Template parameters:
// AccumScalar: Accumulator type. The type of the accumulators used to compute
// the dot-products before they are ultimately cast to the destination type.
52 // DstScalar: The destination scalar type.
53 //
54 // Constraints on these template parameters (see also the ruy::Mul comment):
55 // * If DstScalar is floating-point then AccumScalar must also be.
56 // * If DstScalar is integral then AccumScalar must be std::int32_t. Moreover
57 //   in that integral case, there is a mode switch:
58 //   - If DstScalar is std::int32_t then the multiplier_* fields are all
59 //     disabled, and ruy::Mul will just return raw (unscaled) accumulators.
60 //   - If DstScalar is not std::int32_t then the multiplier_* fields are
61 //     enabled, and ruy::Mul will use them to scale internal std::int32_t
62 //     accumulators before casting them to the DstScalar type. The default
63 //     values are such that the effective multiplier is 1 (no scaling).
64 //
65 // For the latter case (DstScalar integral and narrower than std::int32_t),
66 // reference code can be found in the implementation of ruy::ApplyMultiplier.
67 // If you look there, you'll find warnings like this:
68 //
69 //   !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
70 //   Warning: this code is not meant to be bit-exact-normative.
71 //   Please refer to the class comment of ruy::MulParams, in mul_params.h.
72 //   !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
73 //
74 // The explanation of this warning is that as of early 2021, we still don't know
75 // whether it is advisable to let this code as-is have normative value, or
76 // whether that would become advisable after some specific final change.
77 //
78 // Ruy's CPU backends (x86 and ARM) as of early 2021 happen to conform
79 // bit-exactly to this reference, but we also know that x86 could be faster if
80 // it didn't, and so could NEON-less ARM (such as Cortex-M) (see [2]). We don't
81 // know that this particular reference code is inherently better than other
82 // forms that could perform better on these architectures --- in fact, the
83 // alternative that was proposed in [2] as better performing on ARM Cortex-M
84 // is also inherently more accurate thanks to rounding only once, but it would
85 // perform worse on both ARM NEON, and x86.
86 //
87 // In fact, if we look at other hardware architectures beyond current Ruy
88 // targets, namely "hardware accelerators", it becomes clear that there is no
89 // hope for any form of this to be efficiently implementable simultaneously on
90 // all current relevant hardware. Indeed, some accelerators prefer to perform
91 // the multiplication in IEEE float32, others in IEEE float16, others in
92 // bfloat16, others in 16-bit fixed-point...
93 //
94 // See:
95 //   [1] https://github.com/google/ruy/pull/227
96 //   [2] https://github.com/tensorflow/tensorflow/issues/25087
97 template <typename tAccumScalar, typename tDstScalar>
98 class MulParams final {
99  public:
100   using AccumScalar = tAccumScalar;
101   using DstScalar = tDstScalar;
102 
103   // The bias vector data, if not null.
bias()104   const AccumScalar* bias() const { return storage_.bias; }
set_bias(const AccumScalar * ptr)105   void set_bias(const AccumScalar* ptr) { storage_.bias = ptr; }
106   // Only for non-floating-point cases. The fixed-point part (i.e. the mantissa)
107   // of the multiplier by which accumulators are multiplied before being casted
108   // to the destination type.
multiplier_fixedpoint()109   AccumScalar multiplier_fixedpoint() const {
110     return storage_.perchannel ? 0 : storage_.multiplier_fixedpoint;
111   }
set_multiplier_fixedpoint(const AccumScalar value)112   void set_multiplier_fixedpoint(const AccumScalar value) {
113     set_perchannel(false);
114     storage_.multiplier_fixedpoint = value;
115   }
116   // Only for non-floating-point cases. The exponent part of the aforementioned
117   // multiplier.
multiplier_exponent()118   int multiplier_exponent() const {
119     return storage_.perchannel ? 0 : storage_.multiplier_exponent;
120   }
set_multiplier_exponent(const int value)121   void set_multiplier_exponent(const int value) {
122     set_perchannel(false);
123     storage_.multiplier_exponent = value;
124   }
125   // Per-channel variant of multiplier_fixedpoint. Setting this switches
126   // to per-channel mode, where `multiplier_fixedpoint` and
127   // `multiplier_exponent` are disabled and `multiplier_fixedpoint_perchannel`
128   // and `multiplier_exponent_perchannel` are used instead.
129   //
130   // This must point to a buffer of as many values as there are rows in the
131   // destination matrix. Each row of the destination matrix will use the
132   // corresponding buffer element instead of multiplier_fixedpoint.
multiplier_fixedpoint_perchannel()133   const AccumScalar* multiplier_fixedpoint_perchannel() const {
134     return storage_.perchannel ? storage_.multiplier_fixedpoint_perchannel
135                                : nullptr;
136   }
set_multiplier_fixedpoint_perchannel(const AccumScalar * ptr)137   void set_multiplier_fixedpoint_perchannel(const AccumScalar* ptr) {
138     set_perchannel(true);
139     storage_.multiplier_fixedpoint_perchannel = ptr;
140   }
141   // Per-channel variant of multiplier_exponent. Same comments as for
142   // multiplier_fixedpoint_perchannel.
multiplier_exponent_perchannel()143   const int* multiplier_exponent_perchannel() const {
144     return storage_.perchannel ? storage_.multiplier_exponent_perchannel
145                                : nullptr;
146   }
set_multiplier_exponent_perchannel(const int * ptr)147   void set_multiplier_exponent_perchannel(const int* ptr) {
148     set_perchannel(true);
149     storage_.multiplier_exponent_perchannel = ptr;
150   }
151   // min clamp bound of destination values.
clamp_min()152   DstScalar clamp_min() const { return storage_.clamp_min; }
set_clamp_min(const DstScalar value)153   void set_clamp_min(const DstScalar value) { storage_.clamp_min = value; }
154   // max clamp bound of destination values.
clamp_max()155   DstScalar clamp_max() const { return storage_.clamp_max; }
set_clamp_max(const DstScalar value)156   void set_clamp_max(const DstScalar value) { storage_.clamp_max = value; }
157   // Designates which dimension is the 'channels', for per-channel features
158   // such as bias-addition and per-channel quantization multipliers.
channel_dimension()159   ChannelDimension channel_dimension() const {
160     return storage_.channel_dimension;
161   }
set_channel_dimension(ChannelDimension value)162   void set_channel_dimension(ChannelDimension value) {
163     storage_.channel_dimension = value;
164   }
165   // Specifies the upward rounding of the allocated capacity of per-channel
166   // buffers such as bias vectors and per-channel quantization multipliers.
167   // The unit is matrix entries, not bytes.
168   //
169   // This value must be a power of two.
170   //
171   // The default value, 1, means no upward rounding, meaning that the buffers
172   // are not required to have a capacity greater than the size of the
173   // corresponding matrix dimension, i.e. the number of rows (respectively
174   // columns) of the destination matrix if `channel_dimension()` is kRow
175   // (respectively kCol).
176   //
177   // Higher values allow the implementation to assume that it is OK to access
178   // these buffers a little past this boundary, which is useful in SIMD
179   // optimized kernels. In practice, when this value is lower than what the
180   // kernel requires, ruy has to internally reallocate and copy per-channel
181   // buffers. When this value is high enough, this reallocation and copy is
182   // avoided.
183   //
184   // When a value greater than 1 is specified, the tail region of the buffer
185   // (past the end of the values actually corresponding to channels) is required
186   // to be zero-initialized.
187   //
188   // As of 2020, values as high as 16 may be useful on some CPU architectures
189   // (corresponding to the widest kernels used on any CPU architecture).
perchannel_buffers_capacity_rounding()190   int perchannel_buffers_capacity_rounding() const {
191     return 1 << storage_.perchannel_buffers_capacity_rounding_log2;
192   }
set_perchannel_buffers_capacity_rounding(int value)193   void set_perchannel_buffers_capacity_rounding(int value) {
194     // Note: pot_log2 asserts (debug-only) that its argument is a power-of-two.
195     storage_.perchannel_buffers_capacity_rounding_log2 = pot_log2(value);
196   }
197 
198  private:
199   detail::MulParamsStorage<AccumScalar, DstScalar> storage_;
200 
set_perchannel(bool perchannel)201   void set_perchannel(bool perchannel) {
202     if (storage_.perchannel == perchannel) {
203       return;
204     }
205     if (perchannel) {
206       RUY_DCHECK_EQ(storage_.multiplier_fixedpoint, 0);
207       RUY_DCHECK_EQ(storage_.multiplier_exponent, 0);
208     } else {
209       RUY_DCHECK_EQ(storage_.multiplier_fixedpoint_perchannel, nullptr);
210       RUY_DCHECK_EQ(storage_.multiplier_exponent_perchannel, nullptr);
211     }
212     storage_.perchannel = perchannel;
213   }
214 };
215 
216 namespace detail {
217 
218 // Floating-point case.
219 template <typename AccumScalar, typename DstScalar>
220 struct MulParamsStorage final {
221   static_assert(std::is_floating_point<AccumScalar>::value, "");
222   static_assert(std::is_floating_point<DstScalar>::value, "");
223   static_assert(sizeof(DstScalar) <= sizeof(AccumScalar), "");
224 
225   const AccumScalar* bias = nullptr;
226   DstScalar clamp_min = -std::numeric_limits<DstScalar>::infinity();
227   DstScalar clamp_max = std::numeric_limits<DstScalar>::infinity();
228   ChannelDimension channel_dimension = ChannelDimension::kRow;
229   std::int8_t perchannel_buffers_capacity_rounding_log2 = 0;
230 
231   // Data members that are disabled in this case are left as `static constexpr`
232   // so that one can write some generic code.
233   static constexpr const AccumScalar* multiplier_fixedpoint_perchannel =
234       nullptr;
235   static constexpr const int* multiplier_exponent_perchannel = nullptr;
236   static constexpr AccumScalar multiplier_fixedpoint = 0;
237   static constexpr int multiplier_exponent = 0;
238   static constexpr bool perchannel = false;
239 };
240 
// Specialization for the integer-quantized case, with down-quantization of
// std::int32_t accumulators to a narrower integral destination scalar type.
template <typename DstScalar>
struct MulParamsStorage<std::int32_t, DstScalar> final {
  using AccumScalar = std::int32_t;
  static_assert(std::is_integral<DstScalar>::value, "");
  static_assert(sizeof(DstScalar) < sizeof(AccumScalar), "");

  const AccumScalar* bias = nullptr;
  // union {  // This used to be a union, temporarily flattened to debug a crash
  const AccumScalar* multiplier_fixedpoint_perchannel = nullptr;
  // Let the default multiplier be effectively a multiplication by 1, so that
  // the matmul behaves as a (saturating) plain integer matmul. Unfortunately
  // 1 is not exactly representable in fixedpoint with 0 integer bits, but
  // using the highest representable value is a sufficiently good
  // approximation: since this specialization of MulParams is for the case
  // where DstScalar is at least 2x narrower than AccumScalar, the values
  // for which there would be a difference will get saturated anyway.
  // NOTE(review): the comment above describes a default of
  // std::numeric_limits<AccumScalar>::max(), yet the initializer below is 0
  // (which MulParams::set_perchannel's DCHECK also expects). Confirm against
  // ruy::ApplyMultiplier which value actually yields an identity multiplier.
  AccumScalar multiplier_fixedpoint = 0;
  //};
  // union {  // This used to be a union, temporarily flattened to debug a crash
  const int* multiplier_exponent_perchannel = nullptr;
  // See the above comment about the default value of multiplier_fixedpoint.
  int multiplier_exponent = 0;
  // };
  // Default clamp bounds are the full range of DstScalar, i.e. no clamping.
  DstScalar clamp_min = std::numeric_limits<DstScalar>::lowest();
  DstScalar clamp_max = std::numeric_limits<DstScalar>::max();
  ChannelDimension channel_dimension = ChannelDimension::kRow;
  // True while the *_perchannel pointers are in use; managed by
  // MulParams::set_perchannel.
  bool perchannel = false;
  // log2 of MulParams::perchannel_buffers_capacity_rounding; 0 means no
  // rounding.
  std::int8_t perchannel_buffers_capacity_rounding_log2 = 0;
};
272 
273 // Specialization used in the integer case when outputting raw int32
274 // accumulators, without down-quantization to a narrower destination scalar
275 // type. In this case, the feature of clamping destination values is not
276 // available.
277 template <>
278 struct MulParamsStorage<std::int32_t, std::int32_t> final {
279   using AccumScalar = std::int32_t;
280   using DstScalar = std::int32_t;
281 
282   const AccumScalar* bias = nullptr;
283   ChannelDimension channel_dimension = ChannelDimension::kRow;
284   std::int8_t perchannel_buffers_capacity_rounding_log2 = 0;
285 
286   // Data members that are disabled in this case are left as `static constexpr`
287   // so that one can write some generic code.
288   static constexpr const AccumScalar* multiplier_fixedpoint_perchannel =
289       nullptr;
290   static constexpr const int* multiplier_exponent_perchannel = nullptr;
291   static constexpr AccumScalar multiplier_fixedpoint = 0;
292   static constexpr int multiplier_exponent = 0;
293   static constexpr DstScalar clamp_min =
294       std::numeric_limits<DstScalar>::lowest();
295   static constexpr DstScalar clamp_max = std::numeric_limits<DstScalar>::max();
296   static constexpr bool perchannel = false;
297 };
298 
299 }  // namespace detail
300 
301 }  // namespace ruy
302 
303 #endif  // RUY_RUY_MUL_PARAMS_H_
304