xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/RuyUtils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#ifdef USE_RUY_QMATMUL

#include <ATen/ATen.h>
#include <ATen/native/quantized/cpu/RuyUtils.h>

// Standard headers for std::frexp/std::round, fixed-width ints, and
// std::numeric_limits used below (otherwise pulled in only transitively).
#include <cmath>
#include <cstdint>
#include <limits>

namespace at {
namespace native {
namespace ruy_utils {

static thread_local ruy::Context context;

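// Returns the Ruy context for the calling thread. Keeping the context
// thread_local lets each thread reuse its own Ruy resources (thread pool,
// scratch allocations) without cross-thread synchronization.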
ruy::Context* get_ruy_context() {
  return &context;
}

// Adapted from Ruy:
// https://github.com/google/ruy/blob/2d950b3bfa7ebfbe7a97ecb44b1cc4da5ac1d6f0/ruy/test.h#L1602
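// Decomposes a positive floating-point requantization scale into a Q31
// fixed-point multiplier and a power-of-two exponent such that
//   scale ~= (multiplier_fixedpoint / 2^31) * 2^multiplier_exponent.
// Example: scale = 0.5 gives multiplier_fixedpoint = 1 << 30 and
// multiplier_exponent = 0, since (2^30 / 2^31) * 2^0 = 0.5.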
void quantize_multiplier(double scale,
                         int* multiplier_fixedpoint,
                         int* multiplier_exponent) {
  TORCH_CHECK(scale > 0, "Quantization scale (", scale, ") must be positive.");
  const double q = std::frexp(scale, multiplier_exponent);
  auto q_fixed = static_cast<std::int64_t>(std::round(q * (1ll << 31)));
  TORCH_CHECK(q_fixed <= (1ll << 31));
  if (q_fixed == (1ll << 31)) {
    // Rounding pushed the mantissa up to exactly 1.0; halve it and bump the
    // exponent so the fixed-point value still fits in 31 fractional bits.
    q_fixed /= 2;
    ++*multiplier_exponent;
  }
  TORCH_CHECK(q_fixed <= std::numeric_limits<std::int32_t>::max());
  *multiplier_fixedpoint = static_cast<std::int32_t>(q_fixed);
}
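
// Usage sketch (not part of this file): the (fixedpoint, exponent) pair
// produced above is what Ruy's quantized matmul path consumes, e.g.
//   ruy::MulParams<std::int32_t, std::int8_t> mul_params;
//   mul_params.set_multiplier_fixedpoint(multiplier_fixedpoint);
//   mul_params.set_multiplier_exponent(multiplier_exponent);
// before calling ruy::Mul with the thread-local context obtained from
// get_ruy_context().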

} // namespace ruy_utils
} // namespace native
} // namespace at

#endif // USE_RUY_QMATMUL