// aten/src/ATen/native/quantized/cpu/qmatmul.cpp
#include <ATen/ATen.h>
#include <torch/library.h>

#ifdef USE_RUY_QMATMUL
#include <ATen/Parallel.h>
#include <ATen/native/quantized/cpu/RuyUtils.h>
#include <ruy/ruy.h>
#endif

namespace at {
namespace native {

namespace {

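// Shared validation for both backends: operands must be 8-bit quantized
// (qint8/quint8), share the same dtype, and use a per-tensor quantization
// scheme.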
inline void check_inputs(const Tensor& qa, const Tensor& qb) {
  TORCH_CHECK(
      qa.scalar_type() == c10::kQInt8 || qa.scalar_type() == c10::kQUInt8,
      "MatMul operands should use QInt8 or QUInt8 data types.");
  TORCH_CHECK(
      qa.scalar_type() == qb.scalar_type(),
      "MatMul operands should have the same data type.");
  TORCH_CHECK(
      qa.qscheme() == kPerTensorAffine || qa.qscheme() == kPerTensorSymmetric,
      "Only per-tensor quantization is supported in Matmul.");
  TORCH_CHECK(
      qa.qscheme() == qb.qscheme(),
      "Both inputs to Matmul must have the same quantization scheme.");
}

#ifdef USE_RUY_QMATMUL

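// Ruy-backed implementation: performs the batched int8 matmuls directly in
// the quantized domain and requantizes the results into the requested output
// scale and zero point.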
Tensor qmatmul(
    const Tensor& qa,
    const Tensor& qb,
    const double output_scale,
    const int64_t output_zero_point) {
  check_inputs(qa, qb);

  const int64_t num_dims = qa.dim();
  const int64_t b_num_dims = qb.dim();

  TORCH_CHECK(
      num_dims == b_num_dims,
      "MatMul operands should have the same dimensionality. (", num_dims,
      " and ", b_num_dims, " provided)");
  TORCH_CHECK(
      num_dims >= 2,
      "Quantized Matmul currently only supports operands which are at least 2-dimensional. (",
      num_dims, " provided)");

  const int64_t m = qa.size(num_dims - 2);
  const int64_t k = qa.size(num_dims - 1);
  const int64_t b_k = qb.size(num_dims - 2);
  const int64_t n = qb.size(num_dims - 1);

  TORCH_CHECK(
      b_k == k,
      "For Quantized Matmul, the size of tensor a (", k,
      ") at dimension ", num_dims - 1, " must match the size of tensor b (",
      b_k, ") at dimension ", num_dims - 2, ".");

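  // Collapse all leading (batch) dimensions into a single loop count: each
  // batch index corresponds to one independent (m x k) * (k x n) matmul, and
  // the batch dims must match exactly (no broadcasting).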
  std::vector<int64_t> out_size_vec(num_dims);
  size_t num_matmuls = 1;
  for (int64_t i = 0; i < num_dims - 2; i++) {
    const int64_t dim = qa.size(i);
    const int64_t qb_dim = qb.size(i);

    TORCH_CHECK(
        dim == qb_dim,
        "For Quantized Matmul, the size of tensor a (", dim,
        ") must match the size of tensor b (", qb_dim,
        ") at dimension ", i);

    out_size_vec[i] = dim;
    num_matmuls *= dim;
  }
  out_size_vec[num_dims - 2] = m;
  out_size_vec[num_dims - 1] = n;

  Tensor out = at::_empty_affine_quantized(
      IntArrayRef(out_size_vec),
      at::device(kCPU)
          .dtype(qa.scalar_type())
          .memory_format(qa.suggest_memory_format()),
      output_scale,
      output_zero_point,
      std::nullopt);

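  // Force dense, contiguous storage so each batch entry can be addressed by
  // a fixed pointer stride below.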
  const Tensor& qa_contig = qa.contiguous();
  const Tensor& qb_contig = qb.contiguous();

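  // Dispatch over qint8/quint8; underlying_t is the raw 8-bit storage type
  // (int8_t or uint8_t) that ruy operates on.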
  AT_DISPATCH_QINT_BYTE_TYPES(qa.scalar_type(), "qmatmul", [&] {
    using underlying_t = typename scalar_t::underlying;

    const underlying_t* qa_data = reinterpret_cast<const underlying_t*>(
        qa_contig.data_ptr<scalar_t>());
    const underlying_t* qb_data = reinterpret_cast<const underlying_t*>(
        qb_contig.data_ptr<scalar_t>());
    underlying_t* out_data =
        reinterpret_cast<underlying_t*>(out.data_ptr<scalar_t>());

    const size_t qa_stride = m * k;
    const size_t qb_stride = k * n;
    const size_t out_stride = m * n;

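    // Worker body: multiplies the batch entries in [begin, end). The ruy
    // matrices are lightweight views that are re-pointed at each iteration.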
    auto matmuls = [&](int64_t begin, int64_t end) {

      ruy::Matrix<underlying_t> qa_matrix;
      ruy::MakeSimpleLayout(
          m, k, ruy::Order::kRowMajor, qa_matrix.mutable_layout());
      qa_matrix.set_zero_point(qa.q_zero_point());

      ruy::Matrix<underlying_t> qb_matrix;
      ruy::MakeSimpleLayout(
          k, n, ruy::Order::kRowMajor, qb_matrix.mutable_layout());
      qb_matrix.set_zero_point(qb.q_zero_point());

      ruy::Matrix<underlying_t> out_matrix;
      ruy::MakeSimpleLayout(
          m, n, ruy::Order::kRowMajor, out_matrix.mutable_layout());
      out_matrix.set_zero_point(output_zero_point);

      // Requantization explanation:
      // https://github.com/google/gemmlowp/blob/e844ffd17118c1e17d94e1ba4354c075a4577b88/doc/quantization.md
      const double requantization_scale_inv =
          (qa.q_scale() * qb.q_scale()) / output_scale;

      ruy::MulParams<int32_t, underlying_t> mul_params;

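      // Decompose the real-valued requantization scale into a fixed-point
      // int32 multiplier plus a power-of-two exponent, so ruy can requantize
      // with integer-only arithmetic.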
      int multiplier_fixedpoint;
      int multiplier_exponent;
      ruy_utils::quantize_multiplier(requantization_scale_inv,
                                     &multiplier_fixedpoint,
                                     &multiplier_exponent);
      mul_params.set_multiplier_fixedpoint(multiplier_fixedpoint);
      mul_params.set_multiplier_exponent(multiplier_exponent);

      const underlying_t* qa_subtensor = qa_data + begin * qa_stride;
      const underlying_t* qb_subtensor = qb_data + begin * qb_stride;
      underlying_t* out_subtensor = out_data + begin * out_stride;

      for (int64_t i = begin; i < end; i++) {
        qa_matrix.set_data(qa_subtensor);
        qb_matrix.set_data(qb_subtensor);
        out_matrix.set_data(out_subtensor);
        ruy::Mul(qa_matrix,
                 qb_matrix,
                 mul_params,
                 ruy_utils::get_ruy_context(),
                 &out_matrix);

        qa_subtensor += qa_stride;
        qb_subtensor += qb_stride;
        out_subtensor += out_stride;
      }
    };

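    // Parallelize over the flattened batch; a grain size of 1 allows each
    // matmul to be scheduled as its own task.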
    at::parallel_for(0, num_matmuls, 1, matmuls);
  });

  return out;
}

#else // ifdef USE_RUY_QMATMUL

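// Reference fallback when ruy is not available: dequantize to float, use the
// regular at::matmul, then requantize with the requested output parameters.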
Tensor qmatmul(
    const Tensor& qa,
    const Tensor& qb,
    const double output_scale,
    const int64_t output_zero_point) {
  check_inputs(qa, qb);
  Tensor ra = at::dequantize(qa);
  Tensor rb = at::dequantize(qb);
  Tensor rc = at::matmul(ra, rb);
  return at::quantize_per_tensor(
      rc, output_scale, output_zero_point, qa.scalar_type());
}

#endif // ifdef USE_RUY_QMATMUL

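// Register this kernel as the QuantizedCPU implementation of
// quantized::matmul.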
TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("quantized::matmul"), TORCH_FN(qmatmul));
}

} // namespace

} // namespace native
} // namespace at