#include <ATen/ATen.h>
#include <torch/library.h>

#ifdef USE_RUY_QMATMUL
#include <ATen/Parallel.h>
#include <ATen/native/quantized/cpu/RuyUtils.h>
#include <ruy/ruy.h>
#endif

namespace at {
namespace native {

namespace {

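// check_inputs: both operands must be 8-bit quantized tensors (kQInt8 or
// kQUInt8) with matching dtypes, and both must use per-tensor quantization;
// per-channel quantized inputs are rejected.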
inline void check_inputs(const Tensor& qa, const Tensor& qb) {
  TORCH_CHECK(
      qa.scalar_type() == c10::kQInt8 || qa.scalar_type() == c10::kQUInt8,
      "MatMul operands should use QInt8 or QUInt8 data types.");
  TORCH_CHECK(
      qa.scalar_type() == qb.scalar_type(),
      "MatMul operands should have the same data type.");
  TORCH_CHECK(
      qa.qscheme() == kPerTensorAffine || qa.qscheme() == kPerTensorSymmetric,
      "Only per-tensor quantization is supported in Matmul.");
  TORCH_CHECK(
      qa.qscheme() == qb.qscheme(),
      "Both inputs to Matmul must have the same quantization scheme.");
}

#ifdef USE_RUY_QMATMUL

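// Ruy-backed kernel: treats the inputs as batches of 2-D matrices, runs each
// slice through ruy's 8-bit GEMM, and requantizes the int32 accumulators to
// the requested output scale and zero point.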
Tensor qmatmul(
    const Tensor& qa,
    const Tensor& qb,
    const double output_scale,
    const int64_t output_zero_point) {
  check_inputs(qa, qb);

  const int64_t num_dims = qa.dim();
  const int64_t b_num_dims = qb.dim();

  TORCH_CHECK(
      num_dims == b_num_dims,
      "MatMul operands should have the same dimensionality. (", num_dims,
      " and ", b_num_dims, " provided)");
  TORCH_CHECK(
      num_dims >= 2,
      "Quantized Matmul currently only supports operands which are at least 2-dimensional. (",
      num_dims, " provided)");

  const int64_t m = qa.size(num_dims - 2);
  const int64_t k = qa.size(num_dims - 1);
  const int64_t b_k = qb.size(num_dims - 2);
  const int64_t n = qb.size(num_dims - 1);

  TORCH_CHECK(
      b_k == k,
      "For Quantized Matmul, the size of tensor a (", k,
      ") at dimension ", num_dims - 1, " must match the size of tensor b (",
      b_k, ") at dimension ", num_dims - 2, ".");

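  // Every leading (batch) dimension must match exactly between the operands;
  // broadcasting is not supported here. num_matmuls is the number of
  // independent (m x k) @ (k x n) products in the batch.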
  std::vector<int64_t> out_size_vec(num_dims);
  size_t num_matmuls = 1;
  for (int64_t i = 0; i < num_dims - 2; i++) {
    const int64_t dim = qa.size(i);
    const int64_t qb_dim = qb.size(i);

    TORCH_CHECK(
        dim == qb_dim,
        "For Quantized Matmul, the size of tensor a (", dim,
        ") must match the size of tensor b (", qb_dim,
        ") at dimension ", i);

    out_size_vec[i] = dim;
    num_matmuls *= dim;
  }
  out_size_vec[num_dims - 2] = m;
  out_size_vec[num_dims - 1] = n;

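  // Allocate the output as a per-tensor affine quantized tensor using the
  // caller-provided output scale and zero point, in the same dtype as qa.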
  Tensor out = at::_empty_affine_quantized(
      IntArrayRef(out_size_vec),
      at::device(kCPU)
          .dtype(qa.scalar_type())
          .memory_format(qa.suggest_memory_format()),
      output_scale,
      output_zero_point,
      std::nullopt);

  const Tensor& qa_contig = qa.contiguous();
  const Tensor& qb_contig = qb.contiguous();

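  // Dispatch on qint8 vs quint8; ruy consumes the raw underlying
  // int8_t/uint8_t buffers of the (contiguous) quantized tensors.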
  AT_DISPATCH_QINT_BYTE_TYPES(qa.scalar_type(), "qmatmul", [&] {
    using underlying_t = typename scalar_t::underlying;

    const underlying_t* qa_data = reinterpret_cast<const underlying_t*>(
        qa_contig.data_ptr<scalar_t>());
    const underlying_t* qb_data = reinterpret_cast<const underlying_t*>(
        qb_contig.data_ptr<scalar_t>());
    underlying_t* out_data =
        reinterpret_cast<underlying_t*>(out.data_ptr<scalar_t>());

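    // Element strides between consecutive 2-D slices of the flattened batch.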
    const size_t qa_stride = m * k;
    const size_t qb_stride = k * n;
    const size_t out_stride = m * n;

    auto matmuls = [&](int64_t begin, int64_t end) {

      ruy::Matrix<underlying_t> qa_matrix;
      ruy::MakeSimpleLayout(
          m, k, ruy::Order::kRowMajor, qa_matrix.mutable_layout());
      qa_matrix.set_zero_point(qa.q_zero_point());

      ruy::Matrix<underlying_t> qb_matrix;
      ruy::MakeSimpleLayout(
          k, n, ruy::Order::kRowMajor, qb_matrix.mutable_layout());
      qb_matrix.set_zero_point(qb.q_zero_point());

      ruy::Matrix<underlying_t> out_matrix;
      ruy::MakeSimpleLayout(
          m, n, ruy::Order::kRowMajor, out_matrix.mutable_layout());
      out_matrix.set_zero_point(output_zero_point);

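      // With per-tensor affine quantization the int32 accumulator is
      //   acc = sum((a_q - a_zp) * (b_q - b_zp)),
      // whose real value is (a_scale * b_scale) * acc, so the output is
      //   out_q = out_zp + round((a_scale * b_scale / out_scale) * acc).
      // ruy applies that multiplier as a fixed-point mantissa plus a
      // power-of-two exponent, produced by quantize_multiplier below.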
      // Requantization explanation:
      // https://github.com/google/gemmlowp/blob/e844ffd17118c1e17d94e1ba4354c075a4577b88/doc/quantization.md
      const double requantization_scale_inv =
          (qa.q_scale() * qb.q_scale()) / output_scale;

      ruy::MulParams<int32_t, underlying_t> mul_params;

      int multiplier_fixedpoint;
      int multiplier_exponent;
      ruy_utils::quantize_multiplier(requantization_scale_inv,
                                     &multiplier_fixedpoint,
                                     &multiplier_exponent);
      mul_params.set_multiplier_fixedpoint(multiplier_fixedpoint);
      mul_params.set_multiplier_exponent(multiplier_exponent);

      const underlying_t* qa_subtensor = qa_data + begin * qa_stride;
      const underlying_t* qb_subtensor = qb_data + begin * qb_stride;
      underlying_t* out_subtensor = out_data + begin * out_stride;

      for (int64_t i = begin; i < end; i++) {
        qa_matrix.set_data(qa_subtensor);
        qb_matrix.set_data(qb_subtensor);
        out_matrix.set_data(out_subtensor);
        ruy::Mul(qa_matrix,
                 qb_matrix,
                 mul_params,
                 ruy_utils::get_ruy_context(),
                 &out_matrix);

        qa_subtensor += qa_stride;
        qb_subtensor += qb_stride;
        out_subtensor += out_stride;
      }
    };

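    // Each batch entry is an independent matmul, so split the batch across
    // threads with a grain size of 1.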
    at::parallel_for(0, num_matmuls, 1, matmuls);
  });

  return out;
}

#else // ifdef USE_RUY_QMATMUL

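// Reference fallback when ruy is not available: dequantize to float, use the
// regular at::matmul, and requantize the result with the requested output
// scale and zero point.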
Tensor qmatmul(
    const Tensor& qa,
    const Tensor& qb,
    const double output_scale,
    const int64_t output_zero_point) {
  check_inputs(qa, qb);
  Tensor ra = at::dequantize(qa);
  Tensor rb = at::dequantize(qb);
  Tensor rc = at::matmul(ra, rb);
  return at::quantize_per_tensor(
      rc, output_scale, output_zero_point, qa.scalar_type());
}

#endif // ifdef USE_RUY_QMATMUL

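// Register the kernel for the QuantizedCPU dispatch key. A minimal usage
// sketch from Python, assuming quantized::matmul is declared with the same
// arguments as the kernel above (tensors and quantization parameters are
// illustrative):
//
//   qa = torch.quantize_per_tensor(a, scale=0.05, zero_point=0, dtype=torch.qint8)
//   qb = torch.quantize_per_tensor(b, scale=0.05, zero_point=0, dtype=torch.qint8)
//   qc = torch.ops.quantized.matmul(qa, qb, 0.1, 0)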
TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("quantized::matmul"), TORCH_FN(qmatmul));
}

} // namespace

} // namespace native
} // namespace at