1 /* Copyright 2019 Google LLC. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef RUY_RUY_MUL_PARAMS_H_ 17 #define RUY_RUY_MUL_PARAMS_H_ 18 19 #include <cstdint> 20 #include <limits> 21 #include <type_traits> 22 23 #include "ruy/check_macros.h" 24 #include "ruy/size_util.h" 25 26 namespace ruy { 27 28 // Enumeration to designate which dimension is the 'channels', for MulParams 29 // features that are 'per-channel', namely the bias-vector and the quantized 30 // multiplier. 31 enum class ChannelDimension : std::int8_t { 32 // kRow means that 'per-channel' means 'per row of the destination matrix' 33 kRow, 34 // kCol means that 'per-channel' means 'per column of the destination matrix' 35 kCol 36 }; 37 38 namespace detail { 39 template <typename tAccumScalar, typename tDstScalar> 40 struct MulParamsStorage; 41 } 42 43 // MulParams describes all about a matrix multiplication that 44 // isn't encoded in the LHS, RHS and destination matrices. Some of that 45 // information is encoded as compile-time constants and types (for instance, the 46 // choice of accumulator type, AccumScalar). Some of that information is encoded 47 // as runtime values (for instance, the optional bias vector). 48 // 49 // Template parameters: 50 // AccumScalar: Accumulator type. The type of accumulators used to compute the 51 // dot-products before being ultimately casted to the destination type. 52 // DstScalar: The destination scalar type. 53 // 54 // Constraints on these template parameters (see also the ruy::Mul comment): 55 // * If DstScalar is floating-point then AccumScalar must also be. 56 // * If DstScalar is integral then AccumScalar must be std::int32_t. Moreover 57 // in that integral case, there is a mode switch: 58 // - If DstScalar is std::int32_t then the multiplier_* fields are all 59 // disabled, and ruy::Mul will just return raw (unscaled) accumulators. 60 // - If DstScalar is not std::int32_t then the multiplier_* fields are 61 // enabled, and ruy::Mul will use them to scale internal std::int32_t 62 // accumulators before casting them to the DstScalar type. The default 63 // values are such that the effective multiplier is 1 (no scaling). 64 // 65 // For the latter case (DstScalar integral and narrower than std::int32_t), 66 // reference code can be found in the implementation of ruy::ApplyMultiplier. 67 // If you look there, you'll find warnings like this: 68 // 69 // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 70 // Warning: this code is not meant to be bit-exact-normative. 71 // Please refer to the class comment of ruy::MulParams, in mul_params.h. 72 // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 73 // 74 // The explanation of this warning is that as of early 2021, we still don't know 75 // whether it is advisable to let this code as-is have normative value, or 76 // whether that would become advisable after some specific final change. 77 // 78 // Ruy's CPU backends (x86 and ARM) as of early 2021 happen to conform 79 // bit-exactly to this reference, but we also know that x86 could be faster if 80 // it didn't, and so could NEON-less ARM (such as Cortex-M) (see [2]). We don't 81 // know that this particular reference code is inherently better than other 82 // forms that could perform better on these architectures --- in fact, the 83 // alternative that was proposed in [2] as better performing on ARM Cortex-M 84 // is also inherently more accurate thanks to rounding only once, but it would 85 // perform worse on both ARM NEON, and x86. 86 // 87 // In fact, if we look at other hardware architectures beyond current Ruy 88 // targets, namely "hardware accelerators", it becomes clear that there is no 89 // hope for any form of this to be efficiently implementable simultaneously on 90 // all current relevant hardware. Indeed, some accelerators prefer to perform 91 // the multiplication in IEEE float32, others in IEEE float16, others in 92 // bfloat16, others in 16-bit fixed-point... 93 // 94 // See: 95 // [1] https://github.com/google/ruy/pull/227 96 // [2] https://github.com/tensorflow/tensorflow/issues/25087 97 template <typename tAccumScalar, typename tDstScalar> 98 class MulParams final { 99 public: 100 using AccumScalar = tAccumScalar; 101 using DstScalar = tDstScalar; 102 103 // The bias vector data, if not null. bias()104 const AccumScalar* bias() const { return storage_.bias; } set_bias(const AccumScalar * ptr)105 void set_bias(const AccumScalar* ptr) { storage_.bias = ptr; } 106 // Only for non-floating-point cases. The fixed-point part (i.e. the mantissa) 107 // of the multiplier by which accumulators are multiplied before being casted 108 // to the destination type. multiplier_fixedpoint()109 AccumScalar multiplier_fixedpoint() const { 110 return storage_.perchannel ? 0 : storage_.multiplier_fixedpoint; 111 } set_multiplier_fixedpoint(const AccumScalar value)112 void set_multiplier_fixedpoint(const AccumScalar value) { 113 set_perchannel(false); 114 storage_.multiplier_fixedpoint = value; 115 } 116 // Only for non-floating-point cases. The exponent part of the aforementioned 117 // multiplier. multiplier_exponent()118 int multiplier_exponent() const { 119 return storage_.perchannel ? 0 : storage_.multiplier_exponent; 120 } set_multiplier_exponent(const int value)121 void set_multiplier_exponent(const int value) { 122 set_perchannel(false); 123 storage_.multiplier_exponent = value; 124 } 125 // Per-channel variant of multiplier_fixedpoint. Setting this switches 126 // to per-channel mode, where `multiplier_fixedpoint` and 127 // `multiplier_exponent` are disabled and `multiplier_fixedpoint_perchannel` 128 // and `multiplier_exponent_perchannel` are used instead. 129 // 130 // This must point to a buffer of as many values as there are rows in the 131 // destination matrix. Each row of the destination matrix will use the 132 // corresponding buffer element instead of multiplier_fixedpoint. multiplier_fixedpoint_perchannel()133 const AccumScalar* multiplier_fixedpoint_perchannel() const { 134 return storage_.perchannel ? storage_.multiplier_fixedpoint_perchannel 135 : nullptr; 136 } set_multiplier_fixedpoint_perchannel(const AccumScalar * ptr)137 void set_multiplier_fixedpoint_perchannel(const AccumScalar* ptr) { 138 set_perchannel(true); 139 storage_.multiplier_fixedpoint_perchannel = ptr; 140 } 141 // Per-channel variant of multiplier_exponent. Same comments as for 142 // multiplier_fixedpoint_perchannel. multiplier_exponent_perchannel()143 const int* multiplier_exponent_perchannel() const { 144 return storage_.perchannel ? storage_.multiplier_exponent_perchannel 145 : nullptr; 146 } set_multiplier_exponent_perchannel(const int * ptr)147 void set_multiplier_exponent_perchannel(const int* ptr) { 148 set_perchannel(true); 149 storage_.multiplier_exponent_perchannel = ptr; 150 } 151 // min clamp bound of destination values. clamp_min()152 DstScalar clamp_min() const { return storage_.clamp_min; } set_clamp_min(const DstScalar value)153 void set_clamp_min(const DstScalar value) { storage_.clamp_min = value; } 154 // max clamp bound of destination values. clamp_max()155 DstScalar clamp_max() const { return storage_.clamp_max; } set_clamp_max(const DstScalar value)156 void set_clamp_max(const DstScalar value) { storage_.clamp_max = value; } 157 // Designates which dimension is the 'channels', for per-channel features 158 // such as bias-addition and per-channel quantization multipliers. channel_dimension()159 ChannelDimension channel_dimension() const { 160 return storage_.channel_dimension; 161 } set_channel_dimension(ChannelDimension value)162 void set_channel_dimension(ChannelDimension value) { 163 storage_.channel_dimension = value; 164 } 165 // Specifies the upward rounding of the allocated capacity of per-channel 166 // buffers such as bias vectors and per-channel quantization multipliers. 167 // The unit is matrix entries, not bytes. 168 // 169 // This value must be a power of two. 170 // 171 // The default value, 1, means no upward rounding, meaning that the buffers 172 // are not required to have a capacity greater than the size of the 173 // corresponding matrix dimension, i.e. the number of rows (respectively 174 // columns) of the destination matrix if `channel_dimension()` is kRow 175 // (respectively kCol). 176 // 177 // Higher values allow the implementation to assume that it is OK to access 178 // these buffers a little past this boundary, which is useful in SIMD 179 // optimized kernels. In practice, when this value is lower than what the 180 // kernel requires, ruy has to internally reallocate and copy per-channel 181 // buffers. When this value is high enough, this reallocation and copy is 182 // avoided. 183 // 184 // When a value greater than 1 is specified, the tail region of the buffer 185 // (past the end of the values actually corresponding to channels) is required 186 // to be zero-initialized. 187 // 188 // As of 2020, values as high as 16 may be useful on some CPU architectures 189 // (corresponding to the widest kernels used on any CPU architecture). perchannel_buffers_capacity_rounding()190 int perchannel_buffers_capacity_rounding() const { 191 return 1 << storage_.perchannel_buffers_capacity_rounding_log2; 192 } set_perchannel_buffers_capacity_rounding(int value)193 void set_perchannel_buffers_capacity_rounding(int value) { 194 // Note: pot_log2 asserts (debug-only) that its argument is a power-of-two. 195 storage_.perchannel_buffers_capacity_rounding_log2 = pot_log2(value); 196 } 197 198 private: 199 detail::MulParamsStorage<AccumScalar, DstScalar> storage_; 200 set_perchannel(bool perchannel)201 void set_perchannel(bool perchannel) { 202 if (storage_.perchannel == perchannel) { 203 return; 204 } 205 if (perchannel) { 206 RUY_DCHECK_EQ(storage_.multiplier_fixedpoint, 0); 207 RUY_DCHECK_EQ(storage_.multiplier_exponent, 0); 208 } else { 209 RUY_DCHECK_EQ(storage_.multiplier_fixedpoint_perchannel, nullptr); 210 RUY_DCHECK_EQ(storage_.multiplier_exponent_perchannel, nullptr); 211 } 212 storage_.perchannel = perchannel; 213 } 214 }; 215 216 namespace detail { 217 218 // Floating-point case. 219 template <typename AccumScalar, typename DstScalar> 220 struct MulParamsStorage final { 221 static_assert(std::is_floating_point<AccumScalar>::value, ""); 222 static_assert(std::is_floating_point<DstScalar>::value, ""); 223 static_assert(sizeof(DstScalar) <= sizeof(AccumScalar), ""); 224 225 const AccumScalar* bias = nullptr; 226 DstScalar clamp_min = -std::numeric_limits<DstScalar>::infinity(); 227 DstScalar clamp_max = std::numeric_limits<DstScalar>::infinity(); 228 ChannelDimension channel_dimension = ChannelDimension::kRow; 229 std::int8_t perchannel_buffers_capacity_rounding_log2 = 0; 230 231 // Data members that are disabled in this case are left as `static constexpr` 232 // so that one can write some generic code. 233 static constexpr const AccumScalar* multiplier_fixedpoint_perchannel = 234 nullptr; 235 static constexpr const int* multiplier_exponent_perchannel = nullptr; 236 static constexpr AccumScalar multiplier_fixedpoint = 0; 237 static constexpr int multiplier_exponent = 0; 238 static constexpr bool perchannel = false; 239 }; 240 241 // Specialization for the integer-quantized type, with down-quantization of 242 // int32 accumulators to a narrower destination scalar type. 243 template <typename DstScalar> 244 struct MulParamsStorage<std::int32_t, DstScalar> final { 245 using AccumScalar = std::int32_t; 246 static_assert(std::is_integral<DstScalar>::value, ""); 247 static_assert(sizeof(DstScalar) < sizeof(AccumScalar), ""); 248 249 const AccumScalar* bias = nullptr; 250 // union { // This used to be a union, temporarily flattened to debug a crash 251 const AccumScalar* multiplier_fixedpoint_perchannel = nullptr; 252 // Let the default multiplier be effecively a multiplication by 1, so that 253 // the matmul behaves as a (saturating) plain integer matmul. Unfortunately 254 // 1 is not exactly representable in fixedpoint with 0 integer bits, but 255 // using the highest representable value is a sufficiently good 256 // approximation: since this specialization of MulParams is for the case 257 // where DstScalar is at least 2x narrower than MulScalar, the values 258 // for which there would be a difference will get saturated anyway. 259 AccumScalar multiplier_fixedpoint = 0; 260 //}; 261 // union { // This used to be a union, temporarily flattened to debug a crash 262 const int* multiplier_exponent_perchannel = nullptr; 263 // See the above comment about the default value of multiplier_fixedpoint. 264 int multiplier_exponent = 0; 265 // }; 266 DstScalar clamp_min = std::numeric_limits<DstScalar>::lowest(); 267 DstScalar clamp_max = std::numeric_limits<DstScalar>::max(); 268 ChannelDimension channel_dimension = ChannelDimension::kRow; 269 bool perchannel = false; 270 std::int8_t perchannel_buffers_capacity_rounding_log2 = 0; 271 }; 272 273 // Specialization used in the integer case when outputting raw int32 274 // accumulators, without down-quantization to a narrower destination scalar 275 // type. In this case, the feature of clamping destination values is not 276 // available. 277 template <> 278 struct MulParamsStorage<std::int32_t, std::int32_t> final { 279 using AccumScalar = std::int32_t; 280 using DstScalar = std::int32_t; 281 282 const AccumScalar* bias = nullptr; 283 ChannelDimension channel_dimension = ChannelDimension::kRow; 284 std::int8_t perchannel_buffers_capacity_rounding_log2 = 0; 285 286 // Data members that are disabled in this case are left as `static constexpr` 287 // so that one can write some generic code. 288 static constexpr const AccumScalar* multiplier_fixedpoint_perchannel = 289 nullptr; 290 static constexpr const int* multiplier_exponent_perchannel = nullptr; 291 static constexpr AccumScalar multiplier_fixedpoint = 0; 292 static constexpr int multiplier_exponent = 0; 293 static constexpr DstScalar clamp_min = 294 std::numeric_limits<DstScalar>::lowest(); 295 static constexpr DstScalar clamp_max = std::numeric_limits<DstScalar>::max(); 296 static constexpr bool perchannel = false; 297 }; 298 299 } // namespace detail 300 301 } // namespace ruy 302 303 #endif // RUY_RUY_MUL_PARAMS_H_ 304