xref: /aosp_15_r20/external/pytorch/test/cpp/tensorexpr/test_quantization.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <ATen/native/quantized/PackedParams.h>
4 #include <test/cpp/tensorexpr/test_base.h>
5 #include <torch/csrc/jit/ir/ir.h>
6 #include <torch/csrc/jit/ir/irparser.h>
7 #include <torch/csrc/jit/tensorexpr/kernel.h>
8 #include <torch/csrc/jit/tensorexpr/loopnest.h>
9 #include <torch/csrc/jit/tensorexpr/tensor.h>
10 #include <torch/csrc/jit/testing/file_check.h>
11 #include <torch/torch.h>
12 #include <cmath>
13 #include <sstream>
14 #include "torch/csrc/jit/tensorexpr/eval.h"
15 #include "torch/csrc/jit/tensorexpr/ir.h"
16 
17 namespace torch {
18 namespace jit {
19 
20 using namespace torch::jit::tensorexpr;
21 using SimpleIRExprEval = ExprEval<SimpleIREvaluator>;
22 using namespace torch::indexing;
23 using namespace torch::jit::tensorexpr;
24 
25 class Quantization : public ::testing::Test {
26  public:
SetUp()27   void SetUp() override {
28     getTEMustUseLLVMOnCPU() = false;
29   }
30 };
31 
// Round-trips a float tensor through qint8 quantize/dequantize in NNC and
// checks it against the eager ATen result.
TEST_F(Quantization, QuantDequantInt8) {
  const auto graph_string = R"IR(
      graph(%x.1 : Float(2, 2, strides=[2, 1], device=cpu)):
        %2 : int = prim::Constant[value=12]()
        %3 : int = prim::Constant[value=13]()
        %4 : float = prim::Constant[value=0.1]()
        %q.1 : QInt8(2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2)
        %6 : Float(2, 2) = aten::dequantize(%q.1)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, graph.get());

  // Eager reference with the same scale (0.1) and zero point (13).
  auto x = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto y_expected =
      at::dequantize(at::quantize_per_tensor(x, 0.1f, 13, at::kQInt8));

  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();

  std::vector<at::Tensor> inputs = {x};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();

  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  TORCH_CHECK_EQ(check, 1);
}
61 
// Same round-trip as QuantDequantInt8, but with quint8 and a zero point of
// 122; the input is scaled by 2 to exercise a wider value range.
TEST_F(Quantization, QuantDequantUInt8) {
  const auto graph_string = R"IR(
      graph(%x.1 : Float(2, 2, strides=[2, 1], device=cpu)):
        %2 : int = prim::Constant[value=13]()
        %3 : int = prim::Constant[value=122]()
        %4 : float = prim::Constant[value=0.1]()
        %q.1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2)
        %6 : Float(2, 2) = aten::dequantize(%q.1)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, graph.get());

  auto x = 2 * at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto y_expected =
      at::dequantize(at::quantize_per_tensor(x, 0.1f, 122, at::kQUInt8));

  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();

  std::vector<at::Tensor> inputs = {x};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();

  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  TORCH_CHECK_EQ(check, 1);
}
91 
// quint8 round-trip with a non-contiguous ("NLC") input layout: the strides
// are explicitly set to [4, 1, 2] to match the graph's input annotation.
TEST_F(Quantization, QuantDequantUInt8_NLC) {
  const auto graph_string = R"IR(
      graph(%x.1 : Float(1, 2, 2, strides=[4, 1, 2], device=cpu)):
        %2 : int = prim::Constant[value=13]()
        %3 : int = prim::Constant[value=122]()
        %4 : float = prim::Constant[value=0.1]()
        %q.1 : QUInt8(1, 2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2)
        %6 : Float(1, 2, 2) = aten::dequantize(%q.1)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, graph.get());

  auto x = 2 * at::rand({1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  // Force the non-contiguous strides directly on the TensorImpl.
  x.unsafeGetTensorImpl()->set_sizes_and_strides(
      std::initializer_list<int64_t>{1, 2, 2}, {4, 1, 2});
  auto y_expected =
      at::dequantize(at::quantize_per_tensor(x, 0.1f, 122, at::kQUInt8));

  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();

  std::vector<at::Tensor> inputs = {x};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();

  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "x:\n" << x << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  TORCH_CHECK_EQ(check, 1);
}
124 
quantized_add(at::Tensor x1,at::Tensor x2,double scale,int64_t zero)125 at::Tensor quantized_add(
126     at::Tensor x1,
127     at::Tensor x2,
128     double scale,
129     int64_t zero) {
130   const auto qadd_op =
131       c10::Dispatcher::singleton()
132           .findSchemaOrThrow("quantized::add", "")
133           .typed<at::Tensor(at::Tensor, at::Tensor, double, int64_t)>();
134   return qadd_op.call(x1, x2, scale, zero);
135 }
136 
// Quantizes two inputs to qint8, adds them with quantized::add, dequantizes,
// and compares against the same sequence run eagerly.
TEST_F(Quantization, QuantAddDequantInt8) {
  const auto graph_string = R"IR(
      graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)):
        %2 : int = prim::Constant[value=12]()
        %qz1 : int = prim::Constant[value=13]()
        %qs1 : float = prim::Constant[value=0.1]()
        %qz2 : int = prim::Constant[value=13]()
        %qs2 : float = prim::Constant[value=0.1]()
        %qza : int = prim::Constant[value=13]()
        %qsa : float = prim::Constant[value=0.1]()
        %q1 : QInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2)
        %q2 : QInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2)
        %qa : QInt8(2, 2) = quantized::add(%q1, %q2, %qsa, %qza)
        %6 : Float(2, 2) = aten::dequantize(%qa)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, graph.get());

  auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  // Eager reference; q1/q2 are kept around for the failure diagnostics below.
  auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQInt8);
  auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQInt8);
  auto y_expected = at::dequantize(quantized_add(q1, q2, 0.1f, 13));

  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();

  std::vector<at::Tensor> inputs = {x1, x2};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();

  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "x1:\n" << x1 << std::endl;
    std::cout << "q1:\n" << q1 << std::endl;
    std::cout << "x2:\n" << x2 << std::endl;
    std::cout << "q2:\n" << q2 << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  TORCH_CHECK_EQ(check, 1);
}
179 
// Same as QuantAddDequantInt8 but with the quint8 dtype (scalar type 13).
TEST_F(Quantization, QuantAddDequantUInt8) {
  const auto graph_string = R"IR(
      graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)):
        %2 : int = prim::Constant[value=13]()
        %qz1 : int = prim::Constant[value=13]()
        %qs1 : float = prim::Constant[value=0.1]()
        %qz2 : int = prim::Constant[value=13]()
        %qs2 : float = prim::Constant[value=0.1]()
        %qza : int = prim::Constant[value=13]()
        %qsa : float = prim::Constant[value=0.1]()
        %q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2)
        %q2 : QUInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2)
        %qa : QUInt8(2, 2) = quantized::add(%q1, %q2, %qsa, %qza)
        %6 : Float(2, 2) = aten::dequantize(%qa)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, graph.get());

  auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8);
  auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQUInt8);
  auto y_expected = at::dequantize(quantized_add(q1, q2, 0.1f, 13));

  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();

  std::vector<at::Tensor> inputs = {x1, x2};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();

  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "x1:\n" << x1 << std::endl;
    std::cout << "q1:\n" << q1 << std::endl;
    std::cout << "x2:\n" << x2 << std::endl;
    std::cout << "q2:\n" << q2 << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  TORCH_CHECK_EQ(check, 1);
}
223 
// Applies aten::sigmoid directly on a quint8 tensor inside the fused graph
// and checks the dequantized result against eager execution.
TEST_F(Quantization, QuantSigmoidDequantUInt8) {
  const auto graph_string = R"IR(
      graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu)):
        %2 : int = prim::Constant[value=13]()
        %qz1 : int = prim::Constant[value=13]()
        %qs1 : float = prim::Constant[value=0.1]()
        %q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2)
        %qa : QUInt8(2, 2) = aten::sigmoid(%q1)
        %6 : Float(2, 2) = aten::dequantize(%qa)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, graph.get());

  auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8);
  auto qs = at::sigmoid(q1);
  auto y_expected = at::dequantize(qs);

  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();

  std::vector<at::Tensor> inputs = {x1};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();

  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "x1:\n" << x1 << std::endl;
    std::cout << "q1:\n" << q1 << std::endl;
    std::cout << "qs:\n" << qs << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  TORCH_CHECK_EQ(check, 1);
}
259 
quantized_mul(at::Tensor x1,at::Tensor x2,double scale,int64_t zero)260 at::Tensor quantized_mul(
261     at::Tensor x1,
262     at::Tensor x2,
263     double scale,
264     int64_t zero) {
265   const auto op =
266       c10::Dispatcher::singleton()
267           .findSchemaOrThrow("quantized::mul", "")
268           .typed<at::Tensor(at::Tensor, at::Tensor, double, int64_t)>();
269   return op.call(x1, x2, scale, zero);
270 }
271 
// Quantizes two inputs to quint8, multiplies them with quantized::mul,
// dequantizes, and compares against the eager result.
TEST_F(Quantization, QuantMulDequantUInt8) {
  const auto graph_string = R"IR(
      graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)):
        %2 : int = prim::Constant[value=13]()
        %qz1 : int = prim::Constant[value=13]()
        %qs1 : float = prim::Constant[value=0.1]()
        %qz2 : int = prim::Constant[value=13]()
        %qs2 : float = prim::Constant[value=0.1]()
        %qza : int = prim::Constant[value=13]()
        %qsa : float = prim::Constant[value=0.1]()
        %q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2)
        %q2 : QUInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2)
        %qa : QUInt8(2, 2) = quantized::mul(%q1, %q2, %qsa, %qza)
        %6 : Float(2, 2) = aten::dequantize(%qa)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, graph.get());

  auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8);
  auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQUInt8);
  auto y_expected = at::dequantize(quantized_mul(q1, q2, 0.1f, 13));

  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();

  std::vector<at::Tensor> inputs = {x1, x2};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();

  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "x1:\n" << x1 << std::endl;
    std::cout << "q1:\n" << q1 << std::endl;
    std::cout << "x2:\n" << x2 << std::endl;
    std::cout << "q2:\n" << q2 << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  TORCH_CHECK_EQ(check, 1);
}
315 
// Runs upsample_nearest2d on a quint8 tensor (4x4 -> 6x6) inside the fused
// graph and checks the dequantized output against eager execution.
TEST_F(Quantization, QuantUpsampleNearst2dDequantUInt8) {
  const auto graph_string = R"IR(
      graph(%x : Float(1, 1, 4, 4, strides=[16, 16, 4, 1], device=cpu)):
        %2 : int = prim::Constant[value=13]()
        %4 : NoneType = prim::Constant()
        %3 : int[] = prim::Constant[value=[6, 6]]()
        %qz : int = prim::Constant[value=13]()
        %qs : float = prim::Constant[value=0.1]()
        %q : QUInt8(1, 1, 4, 4) = aten::quantize_per_tensor(%x, %qs, %qz, %2)
        %qu : QUInt8(1, 1, 6, 6) = aten::upsample_nearest2d(%q, %3, %4)
        %6 : Float(1, 1, 6, 6) = aten::dequantize(%qu)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, graph.get());

  auto x = at::rand({1, 1, 4, 4}, TensorOptions(kCPU).dtype(at::kFloat));
  auto q = at::quantize_per_tensor(x, 0.1f, 13, at::kQUInt8);
  auto qu = at::upsample_nearest2d(q, {6, 6});
  auto y_expected = at::dequantize(qu);

  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();

  std::vector<at::Tensor> inputs = {x};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();

  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "x:\n" << x << std::endl;
    std::cout << "q:\n" << q << std::endl;
    std::cout << "qu:\n" << qu << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  TORCH_CHECK_EQ(check, 1);
}
353 
// Non-quantized baseline: plain float upsample_nearest2d (2x2 -> 4x4)
// through the TE kernel.
TEST_F(Quantization, UpsampleNearst2d) {
  const auto graph_string = R"IR(
      graph(%x : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu)):
        %4 : NoneType = prim::Constant()
        %3 : int[] = prim::Constant[value=[4, 4]]()
        %u : Float(1, 1, 4, 4) = aten::upsample_nearest2d(%x, %3, %4)
        return (%u))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, graph.get());

  auto x = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto y_expected = at::upsample_nearest2d(x, {4, 4});

  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();

  std::vector<at::Tensor> inputs = {x};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();

  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "x:\n" << x << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  TORCH_CHECK_EQ(check, 1);
}
382 
quantized_cat(c10::List<at::Tensor> const & xs,int64_t dim,double scale,int64_t zero)383 at::Tensor quantized_cat(
384     c10::List<at::Tensor> const& xs,
385     int64_t dim,
386     double scale,
387     int64_t zero) {
388   const auto op = c10::Dispatcher::singleton()
389                       .findSchemaOrThrow("quantized::cat", "")
390                       .typed<at::Tensor(
391                           c10::List<at::Tensor> const&,
392                           int64_t,
393                           std::optional<double>,
394                           std::optional<int64_t>)>();
395   return op.redispatch(
396       DispatchKeySet({DispatchKey::QuantizedCPU}), xs, dim, scale, zero);
397 }
398 
// Concatenates three quint8 tensors (each with its own scale/zero point)
// with quantized::cat, dequantizes, and compares against eager execution.
TEST_F(Quantization, QuantCatDequantUInt8) {
  const auto graph_string = R"IR(
      graph(%x : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu), %y : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu), %z : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu)):
        %qdt : int = prim::Constant[value=13]()
        %qxz : int = prim::Constant[value=13]()
        %qxs : float = prim::Constant[value=0.1]()
        %qyz : int = prim::Constant[value=16]()
        %qys : float = prim::Constant[value=0.15]()
        %qzz : int = prim::Constant[value=19]()
        %qzs : float = prim::Constant[value=0.2]()
        %qx : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%x, %qxs, %qxz, %qdt)
        %qy : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%y, %qys, %qyz, %qdt)
        %qz : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%z, %qzs, %qzz, %qdt)
        %catx : Tensor[] = prim::ListConstruct(%qx, %qy, %qz)
        %catd : int = prim::Constant[value=0]()
        %qcat : QUInt8(3, 1, 2, 2) = quantized::cat(%catx, %catd, %qxs, %qxz)
        %cat : Float(3, 1, 2, 2) = aten::dequantize(%qcat)
        return (%cat))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, graph.get());

  auto x = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto y = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto z = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  // Eager reference: per-input quantization params, output reuses x's params.
  auto qx = at::quantize_per_tensor(x, 0.1f, 13, at::kQUInt8);
  auto qy = at::quantize_per_tensor(y, 0.15f, 16, at::kQUInt8);
  auto qz = at::quantize_per_tensor(z, 0.2f, 19, at::kQUInt8);
  auto qcat = quantized_cat({qx, qy, qz}, 0, 0.1f, 13);
  auto expected = at::dequantize(qcat);

  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();

  std::vector<at::Tensor> inputs = {x, y, z};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto result = stack[0].toTensor();

  bool check = at::allclose(expected, result);
  if (!check) {
    std::cout << "x:\n" << x << std::endl;
    std::cout << "y:\n" << y << std::endl;
    std::cout << "z:\n" << z << std::endl;
    std::cout << "qx:\n" << qx << std::endl;
    std::cout << "qy:\n" << qy << std::endl;
    std::cout << "qz:\n" << qz << std::endl;
    std::cout << "qcat:\n" << qcat << std::endl;
    std::cout << "expected:\n" << expected << std::endl;
    std::cout << "result:\n" << result << std::endl;
  }
  TORCH_CHECK_EQ(check, 1);
}
450 
451 } // namespace jit
452 } // namespace torch
453