// torch/csrc/jit/tensorexpr/operators/quantization.cpp
#include <c10/core/ScalarType.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/operators/misc.h>
#include <torch/csrc/jit/tensorexpr/operators/pointwise.h>
#include <torch/csrc/jit/tensorexpr/operators/quantization.h>

using namespace torch::jit::tensorexpr;

namespace torch::jit::tensorexpr {
namespace {
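// Normalizes an int-or-pair argument (e.g. stride, padding, dilation) to a
// two-element {h, w} vector, duplicating a single int into both positions.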
std::vector<int64_t> _pair_int(ArgValue v) {
  if (auto t = std::get_if<IntList>(&v)) {
    return {(*t)[0], (*t)[1]};
  }
  auto i = std::get<int64_t>(v);
  return {i, i};
}
} // namespace

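// Accessors for the quantization parameters attached to a quantized buffer.
// immQScale and immQZero expect the qscale/qzero expressions to simplify down
// to immediates (DoubleImm/LongImm); a buffer counts as quantized when both
// parameters are present.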
double immQScale(const BufHandle& qx) {
  TORCH_INTERNAL_ASSERT(
      qx.node()->qscale(), buildErrorMessage("Expects BufHandle with qscale"));
  return to<DoubleImm>(IRSimplifier::simplify(qx.node()->qscale()))->value();
}

int64_t immQZero(const BufHandle& qx) {
  TORCH_INTERNAL_ASSERT(
      qx.node()->qzero(), buildErrorMessage("Expects BufHandle with qzero"));
  return to<LongImm>(IRSimplifier::simplify(qx.node()->qzero()))->value();
}

ScalarType immQDType(const BufHandle& qx) {
  return qx.dtype().scalar_type();
}

bool isQuantized(const BufHandle& qx) {
  return qx.node()->qscale() && qx.node()->qzero();
}

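// The makeQBufHandle* helpers construct an output buffer that carries the
// quantization parameters (qscale, qzero) as metadata and has its strides set
// to either channels-last or contiguous layout. The double/int64_t overloads
// simply wrap the scalars into DoubleImm/LongImm immediates.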
static BufHandle makeQBufHandleChannelsLast(
    const std::string& name,
    const std::vector<ExprHandle>& dims,
    Dtype dtype,
    const ExprPtr& qscale,
    const ExprPtr& qzero) {
  BufHandle ResultBuf(name, dims, dtype);
  ResultBuf.node()->set_qscale(qscale);
  ResultBuf.node()->set_qzero(qzero);
  ResultBuf.node()->set_strides(make_channels_last_strides(dims));
  return ResultBuf;
}

static BufHandle makeQBufHandleChannelsLast(
    const std::string& name,
    const std::vector<ExprHandle>& dims,
    Dtype dtype,
    const double qscale,
    const int64_t qzero) {
  return makeQBufHandleChannelsLast(
      name,
      dims,
      dtype,
      DoubleImm::make(qscale).node(),
      LongImm::make(qzero).node());
}

static BufHandle makeQBufHandleContiguous(
    const std::string& name,
    const std::vector<ExprHandle>& dims,
    Dtype dtype,
    const ExprPtr& qscale,
    const ExprPtr& qzero) {
  BufHandle ResultBuf(name, dims, dtype);
  ResultBuf.node()->set_qscale(qscale);
  ResultBuf.node()->set_qzero(qzero);
  ResultBuf.node()->set_strides(make_contiguous_strides(dims));
  return ResultBuf;
}

static BufHandle makeQBufHandleContiguous(
    const std::string& name,
    const std::vector<ExprHandle>& dims,
    Dtype dtype,
    const double qscale,
    const int64_t qzero) {
  return makeQBufHandleContiguous(
      name,
      dims,
      dtype,
      DoubleImm::make(qscale).node(),
      LongImm::make(qzero).node());
}

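// Heuristic channels-last (NHWC) check: for a rank >= 3 buffer whose logical
// dims are (N, C, ...), the innermost stride equals C and the stride of the
// channel dimension is 1. Both the dims and strides involved must simplify to
// integer constants for the check to succeed.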
static bool isChannelsLast(const BufHandle& buf) {
  const auto& strides = buf.node()->strides();
  const auto& dims = buf.node()->dims();
  const auto rank = dims.size();
  if (rank < 3) {
    return false;
  }
  auto dimsC = to<LongImm>(IRSimplifier::simplify(dims[1]))->value();
  auto stridesC = to<LongImm>(IRSimplifier::simplify(strides[1]))->value();
  auto stridesLast =
      to<LongImm>(IRSimplifier::simplify(strides[rank - 1]))->value();

  return ((stridesLast == dimsC) && (stridesC == 1));
}

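// Affine quantization: q = round(x / qscale) + qzero, cast to the quantized
// out_dtype. The rounding here is approximated by adding 0.5 before the
// truncating cast, which matches round-half-up only when x / qscale + qzero
// is non-negative.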
static ExprHandle quant(
    const ExprHandle& x,
    Dtype out_dtype,
    ExprHandle qscale,
    ExprHandle qzero) {
  auto promoted_qscale =
      promoteToDtype(std::move(qscale), x.dtype().scalar_type());
  auto promoted_qzero =
      promoteToDtype(std::move(qzero), x.dtype().scalar_type());
  return promoteToDtype(
      x / promoted_qscale + promoted_qzero + FloatImm::make(0.5f),
      out_dtype.scalar_type());
}

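// Affine dequantization: x = (q - qzero) * qscale, with all operands promoted
// to the (typically float) out_dtype before the arithmetic.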
static ExprHandle dequant(
    ExprHandle qx,
    Dtype out_dtype,
    ExprHandle qscale,
    ExprHandle qzero) {
  auto qx_promoted = promoteToDtype(std::move(qx), out_dtype.scalar_type());
  auto qscale_promoted =
      promoteToDtype(std::move(qscale), out_dtype.scalar_type());
  auto qzero_promoted =
      promoteToDtype(std::move(qzero), out_dtype.scalar_type());
  return promoteToDtype(
      (qx_promoted - qzero_promoted) * qscale_promoted,
      out_dtype.scalar_type());
}

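// Lowers aten::quantize_per_tensor as an inlined pointwise expression: one
// loop variable is created per output dimension, quant() is applied to each
// input element, and qscale/qzero are attached to the resulting output
// buffer. Only QInt8 and QUInt8 output dtypes are accepted.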
Tensor computeQuantizePerTensor(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>&,
    at::Device) {
  std::vector<VarPtr> vars;
  std::vector<ExprHandle> indices;
  for (const auto& os : outputShape) {
    auto var = alloc<Var>("", os.node()->dtype());
    vars.push_back(var);
    indices.push_back(VarHandle(var));
  }

  ExprHandle qscale = constant(inputs[1]);
  ExprHandle qzero = constant(inputs[2]);

  const auto dtype = [](auto qdtype) {
    if (static_cast<int64_t>(ScalarType::QInt8) == qdtype) {
      return Dtype(ScalarType::QInt8);
    } else if (static_cast<int64_t>(ScalarType::QUInt8) == qdtype) {
      return Dtype(ScalarType::QUInt8);
    }
    throw malformed_input("Expected quantized dtype");
  }(std::get<int64_t>(inputs[3]));

  ExprHandle e =
      quant(tensorOrConstant(inputs[0], indices), dtype, qscale, qzero);

  BufPtr buf = alloc<Buf>(
      "quantize_per_tensor",
      ExprHandleVectorToExprVector(outputShape),
      dtype,
      nullptr,
      std::nullopt,
      qscale.node(),
      qzero.node());
  return Tensor(buf, vars, e.node());
}

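// Inlined lowering of quantized add: dequantize both operands to float, add
// them, then requantize with the output scale/zero point supplied in
// inputs[2]/inputs[3]. The output layout follows the first operand
// (channels-last if QA is channels-last, contiguous otherwise).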
Tensor computeQuantizedAdd(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device) {
  const BufHandle& QA = std::get<BufHandle>(inputs[0]);
  const BufHandle& QB = std::get<BufHandle>(inputs[1]);
  auto qa_scale = ExprHandle(QA.node()->qscale());
  auto qa_zero = ExprHandle(QA.node()->qzero());
  auto qb_scale = ExprHandle(QB.node()->qscale());
  auto qb_zero = ExprHandle(QB.node()->qzero());
  ExprHandle out_qscale = DoubleImm::make(std::get<double>(inputs[2]));
  ExprHandle out_qzero = LongImm::make(std::get<int64_t>(inputs[3]));
  Dtype dequant_dtype = kFloat;
  Dtype out_dtype = outputType ? Dtype(*outputType) : QA.dtype();
  std::vector<VarPtr> vars;
  std::vector<ExprHandle> indices;
  for (const auto& os : outputShape) {
    auto var = alloc<Var>("", os.node()->dtype());
    vars.push_back(var);
    indices.push_back(VarHandle(var));
  }
  auto lhs = tensorOrConstant(inputs[0], indices);
  auto rhs = tensorOrConstant(inputs[1], indices);
  ExprHandle exprHandle = quant(
      dequant(lhs, dequant_dtype, qa_scale, qa_zero) +
          dequant(rhs, dequant_dtype, qb_scale, qb_zero),
      out_dtype,
      out_qscale,
      out_qzero);
  BufPtr buf = alloc<Buf>(
      "quantized_add",
      ExprHandleVectorToExprVector(outputShape),
      out_dtype,
      nullptr,
      isChannelsLast(QA) ? make_channels_last_strides(outputShape)
                         : make_contiguous_strides(outputShape),
      out_qscale.node(),
      out_qzero.node());
  return Tensor(buf, vars, exprHandle.node());
}

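// External-call lowering of aten::quantize_per_tensor. Like most of the
// remaining ops in this file, it allocates a result buffer carrying the
// quantization metadata (layout matched to the input) and emits an
// ExternalCall to the corresponding nnc_aten_* kernel, passing the
// quantization parameters as scalar extra args.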
Tensor computeQuantizePerTensorExternalCall(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device) {
  const BufHandle& x = std::get<BufHandle>(inputs[0]);
  const auto qscale = std::get<double>(inputs[1]);
  const auto qzero = std::get<int64_t>(inputs[2]);
  const auto qdtype = std::get<int64_t>(inputs[3]);

  const auto dtype = [](auto qdtype) {
    if (static_cast<int64_t>(ScalarType::QInt8) == qdtype) {
      return Dtype(ScalarType::QInt8);
    } else if (static_cast<int64_t>(ScalarType::QUInt8) == qdtype) {
      return Dtype(ScalarType::QUInt8);
    }
    throw malformed_input("Expected quantized dtype");
  }(qdtype);
  auto ResultBuf = [&]() {
    if (isChannelsLast(x)) {
      return makeQBufHandleChannelsLast(
          "quantize_per_tensor", outputShape, dtype, qscale, qzero);
    }
    return makeQBufHandleContiguous(
        "quantize_per_tensor", outputShape, dtype, qscale, qzero);
  }();
  StmtPtr s = ExternalCall::make(
      ResultBuf, "nnc_aten_quantize_per_tensor", {x}, {qscale, qzero, qdtype});
  return Tensor(ResultBuf.node(), s);
}

Tensor computeDequantizeExternalCall(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device) {
  Dtype dtype = kFloat;
  if (outputType) {
    dtype = Dtype(*outputType);
  }

  const BufHandle& qx = std::get<BufHandle>(inputs[0]);
  const int64_t qdtype = (int64_t)immQDType(qx);

  BufHandle ResultBuf("dequantize", outputShape, dtype);
  StmtPtr s = ExternalCall::make(
      ResultBuf,
      "nnc_aten_dequantize",
      {qx},
      {ExprHandle(IRSimplifier::simplify(qx.node()->qscale())),
       ExprHandle(IRSimplifier::simplify(qx.node()->qzero())),
       qdtype});
  return Tensor(ResultBuf.node(), s);
}

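// Lowers quantized::conv2d_prepack via an external call. The quantized
// weight's scale/zero/dtype plus the conv hyperparameters (stride, padding,
// dilation, groups) are forwarded as scalar args; both qscale and qzero must
// be present on the weight buffer.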
Tensor computeQuantizedConv2dPrepack(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device) {
  Dtype dtype = kFloat;
  if (outputType) {
    dtype = Dtype(*outputType);
  }

  BufHandle ResultBuf("quantized_conv2d_prepack", outputShape, dtype);
  const BufHandle& qw = std::get<BufHandle>(inputs[0]);
  const BufHandle& b = std::get<BufHandle>(inputs[1]);
  auto strides = _pair_int(inputs[2]);
  auto padding = _pair_int(inputs[3]);
  auto dilation = _pair_int(inputs[4]);
  auto groups = std::get<int64_t>(inputs[5]);
  TORCH_INTERNAL_ASSERT(
      qw.node()->qscale(),
      buildErrorMessage(
          "quantized_conv2d_prepack: Expects quantized weights, qscale is missing"));
  TORCH_INTERNAL_ASSERT(
      qw.node()->qzero(),
      buildErrorMessage(
          "quantized_conv2d_prepack: Expects quantized weights, qzero is missing"));
  StmtPtr s = ExternalCall::make(
      ResultBuf,
      "nnc_aten_quantized_conv2d_prepack",
      {qw, b},
      {strides[0],
       strides[1],
       padding[0],
       padding[1],
       dilation[0],
       dilation[1],
       groups,
       immQScale(qw),
       immQZero(qw),
       (int64_t)immQDType(qw)});
  return Tensor(ResultBuf.node(), s);
}

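// The quantized conv1d/conv2d/conv2d_relu and linear/linear_relu lowerings
// below all follow the same shape: the quantized input and the prepacked
// weight buffer are handed to the corresponding nnc_aten_* kernel together
// with the input's (scale, zero, dtype) and the requested output scale/zero.
// Conv outputs are built channels-last and linear outputs contiguous; the
// output dtype currently mirrors the input's quantized dtype.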
Tensor computeQuantizedConv1d(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  const BufHandle& qx = std::get<BufHandle>(inputs[0]);
  const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
  const auto out_qscale = std::get<double>(inputs[2]);
  const auto out_qzero = std::get<int64_t>(inputs[3]);
  // Change to dtype based on outputType when dtype propagation is implemented
  const auto out_qdtype = immQDType(qx);
  auto ResultBuf = makeQBufHandleChannelsLast(
      "quantized_conv1d",
      outputShape,
      Dtype(out_qdtype),
      out_qscale,
      out_qzero);
  StmtPtr s = ExternalCall::make(
      ResultBuf,
      "nnc_aten_quantized_conv1d",
      {qx, prepacked},
      {immQScale(qx),
       immQZero(qx),
       (int64_t)immQDType(qx),
       out_qscale,
       out_qzero});
  return Tensor(ResultBuf.node(), s);
}

Tensor computeQuantizedConv2d(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  const BufHandle& qx = std::get<BufHandle>(inputs[0]);
  const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
  const auto out_qscale = std::get<double>(inputs[2]);
  const auto out_qzero = std::get<int64_t>(inputs[3]);
  // Change to dtype based on outputType when dtype propagation is implemented
  const auto out_qdtype = immQDType(qx);
  auto ResultBuf = makeQBufHandleChannelsLast(
      "quantized_conv2d",
      outputShape,
      Dtype(out_qdtype),
      out_qscale,
      out_qzero);
  StmtPtr s = ExternalCall::make(
      ResultBuf,
      "nnc_aten_quantized_conv2d",
      {qx, prepacked},
      {immQScale(qx),
       immQZero(qx),
       (int64_t)immQDType(qx),
       out_qscale,
       out_qzero});
  return Tensor(ResultBuf.node(), s);
}

Tensor computeQuantizedConv2dRelu(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  const BufHandle& qx = std::get<BufHandle>(inputs[0]);
  const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
  const auto out_qscale = std::get<double>(inputs[2]);
  const auto out_qzero = std::get<int64_t>(inputs[3]);
  // Change to dtype based on outputType when dtype propagation is implemented
  const auto out_qdtype = immQDType(qx);
  auto ResultBuf = makeQBufHandleChannelsLast(
      "quantized_conv2d_relu",
      outputShape,
      Dtype(out_qdtype),
      out_qscale,
      out_qzero);
  StmtPtr s = ExternalCall::make(
      ResultBuf,
      "nnc_aten_quantized_conv2d_relu",
      {qx, prepacked},
      {immQScale(qx),
       immQZero(qx),
       (int64_t)immQDType(qx),
       out_qscale,
       out_qzero});
  return Tensor(ResultBuf.node(), s);
}

Tensor computeQuantizedLinear(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  const BufHandle& qx = std::get<BufHandle>(inputs[0]);
  const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
  const auto out_qscale = std::get<double>(inputs[2]);
  const auto out_qzero = std::get<int64_t>(inputs[3]);
  // Change to dtype based on outputType when dtype propagation is implemented
  const auto out_qdtype = immQDType(qx);
  auto ResultBuf = makeQBufHandleContiguous(
      "quantized_linear",
      outputShape,
      Dtype(out_qdtype),
      out_qscale,
      out_qzero);
  StmtPtr s = ExternalCall::make(
      ResultBuf,
      "nnc_aten_quantized_linear",
      {qx, prepacked},
      {immQScale(qx),
       immQZero(qx),
       (int64_t)immQDType(qx),
       out_qscale,
       out_qzero});
  return Tensor(ResultBuf.node(), s);
}

Tensor computeQuantizedLinearRelu(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  const BufHandle& qx = std::get<BufHandle>(inputs[0]);
  const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
  const auto out_qscale = std::get<double>(inputs[2]);
  const auto out_qzero = std::get<int64_t>(inputs[3]);
  // Change to dtype based on outputType when dtype propagation is implemented
  const auto out_qdtype = immQDType(qx);
  auto ResultBuf = makeQBufHandleContiguous(
      "quantized_linear_relu",
      outputShape,
      Dtype(out_qdtype),
      out_qscale,
      out_qzero);
  StmtPtr s = ExternalCall::make(
      ResultBuf,
      "nnc_aten_quantized_linear_relu",
      {qx, prepacked},
      {immQScale(qx),
       immQZero(qx),
       (int64_t)immQDType(qx),
       out_qscale,
       out_qzero});
  return Tensor(ResultBuf.node(), s);
}

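// External-call variant of quantized add: both operands' quantization
// parameters plus the output scale/zero are forwarded to
// nnc_aten_quantized_add. The result is channels-last if either input is.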
Tensor computeQuantizedAddExternalCall(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  const BufHandle& qa = std::get<BufHandle>(inputs[0]);
  const BufHandle& qb = std::get<BufHandle>(inputs[1]);
  const auto out_qscale = std::get<double>(inputs[2]);
  const auto out_qzero = std::get<int64_t>(inputs[3]);
  // Change to dtype based on outputType when dtype propagation is implemented
  const auto out_qdtype = immQDType(qa);
  const bool isQAChannelsLast = isChannelsLast(qa);
  const bool isQBChannelsLast = isChannelsLast(qb);
  auto ResultBuf = (isQAChannelsLast || isQBChannelsLast)
      ? makeQBufHandleChannelsLast(
            "quantized_add",
            outputShape,
            Dtype(out_qdtype),
            out_qscale,
            out_qzero)
      : makeQBufHandleContiguous(
            "quantized_add",
            outputShape,
            Dtype(out_qdtype),
            out_qscale,
            out_qzero);
  StmtPtr s = ExternalCall::make(
      ResultBuf,
      "nnc_aten_quantized_add",
      {qa, qb},
      {immQScale(qa),
       immQZero(qa),
       (int64_t)immQDType(qa),
       immQScale(qb),
       immQZero(qb),
       (int64_t)immQDType(qb),
       out_qscale,
       out_qzero});
  return Tensor(ResultBuf.node(), s);
}

Tensor computeQuantizedMul(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  const BufHandle& qa = std::get<BufHandle>(inputs[0]);
  const BufHandle& qb = std::get<BufHandle>(inputs[1]);
  const auto out_qscale = std::get<double>(inputs[2]);
  const auto out_qzero = std::get<int64_t>(inputs[3]);
  // Change to dtype based on outputType when dtype propagation is implemented
  const auto out_qdtype = immQDType(qa);
  auto ResultBuf = makeQBufHandleContiguous(
      "quantized_mul", outputShape, Dtype(out_qdtype), out_qscale, out_qzero);
  StmtPtr s = ExternalCall::make(
      ResultBuf,
      "nnc_aten_quantized_mul",
      {qa, qb},
      {immQScale(qa),
       immQZero(qa),
       (int64_t)immQDType(qa),
       immQScale(qb),
       immQZero(qb),
       (int64_t)immQDType(qb),
       out_qscale,
       out_qzero});
  return Tensor(ResultBuf.node(), s);
}

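// quantized::mul with a scalar operand: the output buffer's scale is the
// input scale multiplied by the scalar while the zero point is reused
// unchanged, which effectively expresses the multiplication as a rescaling of
// the quantization parameters.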
Tensor computeQuantizedMulScalar(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    // NOLINTNEXTLINE
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  const BufHandle& qa = std::get<BufHandle>(inputs[0]);
  const auto scalar = std::get<double>(inputs[1]);
  // Change to dtype based on outputType when dtype propagation is implemented
  const auto out_qdtype = immQDType(qa);
  double scale1 = immQScale(qa);
  auto ResultBuf = makeQBufHandleContiguous(
      "quantized_mul_scalar",
      outputShape,
      Dtype(out_qdtype),
      scale1 * scalar,
      immQZero(qa));
  StmtPtr s = ExternalCall::make(
      ResultBuf,
      "nnc_aten_quantized_mul_scalar",
      {qa},
      {scale1, immQZero(qa), (int64_t)immQDType(qa), scalar});
  return Tensor(ResultBuf.node(), s);
}

Tensor computeQuantizedRelu(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  const BufHandle& qa = std::get<BufHandle>(inputs[0]);
  const auto out_qdtype = immQDType(qa);
  const bool isQAChannelsLast = isChannelsLast(qa);
  auto ResultBuf = isQAChannelsLast ? makeQBufHandleChannelsLast(
                                          "quantized_relu",
                                          outputShape,
                                          Dtype(out_qdtype),
                                          immQScale(qa),
                                          immQZero(qa))
                                    : makeQBufHandleContiguous(
                                          "quantized_relu",
                                          outputShape,
                                          Dtype(out_qdtype),
                                          immQScale(qa),
                                          immQZero(qa));
  StmtPtr s = ExternalCall::make(
      ResultBuf,
      "nnc_aten_quantized_relu",
      {qa},
      {immQScale(qa), immQZero(qa), (int64_t)immQDType(qa)});
  return Tensor(ResultBuf.node(), s);
}

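// quantized::cat: every input buffer contributes its (scale, zero point,
// dtype) triple to the external call's extra args, followed by the
// concatenation dim and the output scale/zero. The output dtype is taken
// from the first input.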
Tensor computeQuantizedCat(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    // NOLINTNEXTLINE
    const std::optional<ScalarType>& outputType,
    // NOLINTNEXTLINE
    at::Device device) {
  // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
  auto inputList = std::get<BufList>(inputs[0]);
  auto argDim = std::get<int64_t>(inputs[1]);
  auto n = inputList.size();
  // TODO: handle optional out_qscale, out_qzero
  const auto out_qscale = std::get<double>(inputs[2]);
  const auto out_qzero = std::get<int64_t>(inputs[3]);

  std::vector<BufHandle> args;
  std::vector<ExprHandle> extra_args;
  for (const auto i : c10::irange(n)) {
    const BufHandle& bh = inputList[i];
    args.emplace_back(bh);
    extra_args.emplace_back(immQScale(bh));
    extra_args.emplace_back(immQZero(bh));
    extra_args.emplace_back((int64_t)immQDType(bh));
  }
  extra_args.emplace_back(argDim);
  extra_args.emplace_back(out_qscale);
  extra_args.emplace_back(out_qzero);
  auto ResultBuf = makeQBufHandleContiguous(
      "quantized_cat",
      outputShape,
      Dtype(immQDType(inputList[0])),
      out_qscale,
      out_qzero);
  StmtPtr s =
      ExternalCall::make(ResultBuf, "nnc_aten_quantized_cat", args, extra_args);
  return Tensor(ResultBuf.node(), s);
}

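// Inlined (non-external-call) dequantize lowering: generates a pointwise loop
// nest that applies dequant() to each element, using the qscale/qzero stored
// on the input buffer. Both parameters must be present.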
Tensor computeDequantize(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device) {
  Dtype dtype = kFloat;
  if (outputType) {
    dtype = Dtype(*outputType);
  }
  auto qx = std::get<BufHandle>(inputs[0]);
  TORCH_INTERNAL_ASSERT(
      qx.node()->qscale(),
      buildErrorMessage("Missing quantized scale for dequantize"));
  TORCH_INTERNAL_ASSERT(
      qx.node()->qzero(),
      buildErrorMessage("Missing quantized zero point for dequantize"));
  auto qscale = ExprHandle(qx.node()->qscale());
  auto qzero = ExprHandle(qx.node()->qzero());
  std::vector<VarPtr> vars;
  std::vector<ExprHandle> indices;
  for (const auto& os : outputShape) {
    auto var = alloc<Var>("", os.node()->dtype());
    vars.push_back(var);
    indices.push_back(VarHandle(var));
  }
  auto y = dequant(tensorOrConstant(inputs[0], indices), dtype, qscale, qzero);
  BufPtr buf = alloc<Buf>(
      "dequantize", ExprHandleVectorToExprVector(outputShape), dtype);
  return Tensor(buf, vars, y.node());
}

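// Inlined nearest-neighbor 2d upsampling. The per-axis scale is
// input_size / output_size (computed in double), and each output index maps
// to the source index min(floor(dst_index * scale), input_size - 1). The
// input's quantization parameters are propagated to the output buffer.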
Tensor computeUpsampleNearest2d(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device) {
  auto A = std::get<BufHandle>(inputs[0]);
  const auto& output_height = outputShape[2];
  const auto& output_width = outputShape[3];
  auto input_height = ExprHandle(A.dim(2));
  auto input_width = ExprHandle(A.dim(3));

  std::vector<VarHandle> args = create_index_vars(outputShape);
  // TODO: handle explicitly specified scales separately, as in
  // 'scalar_t compute_scales_value' in UpSample.h?
  auto scale_h =
      promoteToDtype(input_height, ScalarType::Double) / output_height;
  auto scale_w = promoteToDtype(input_width, ScalarType::Double) / output_width;
  // TODO: will the repeated conditional in the index calculation be hoisted
  // out of the loop?
  auto compute_nearest_idx = [](const ExprHandle& scale,
                                const ExprHandle& dst_index,
                                const ExprHandle& input_size) {
    return Min::make(
        promoteToDtype(floor(dst_index * scale), ScalarType::Long),
        input_size - 1,
        true);
  };
  auto body_func = [&](std::vector<VarHandle> axes) {
    std::vector<ExprHandle> newAxes(axes.begin(), axes.end());
    newAxes[2] = compute_nearest_idx(scale_h, axes[2], input_height);
    newAxes[3] = compute_nearest_idx(scale_w, axes[3], input_width);
    return A.load(newAxes);
  };
  auto e = body_func(args);
  auto strides = isChannelsLast(A) ? make_channels_last_strides(outputShape)
                                   : make_contiguous_strides(outputShape);
  BufHandle buf = Buf::make(
      "upsample_nearest2d",
      outputShape,
      Dtype(*outputType),
      std::nullopt, // initializer
      fmap(strides, [&](const ExprPtr& stride) { return ExprHandle(stride); }),
      ExprHandle(A.node()->qscale()),
      ExprHandle(A.node()->qzero()));
  return Tensor(buf, args, e);
}

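// External-call upsample_nearest2d. Output sizes and scale factors are
// optional and default to -1 when absent; quantization parameters are
// forwarded only when the input is quantized, in which case the output is
// built channels-last with the same scale/zero point.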
Tensor computeUpsampleNearest2dExternalCall(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device) {
  Dtype dtype = kFloat;
  if (outputType) {
    dtype = Dtype(*outputType);
  }
  int64_t output_size_h = -1;
  int64_t output_size_w = -1;
  if (auto output_sizes = std::get_if<IntList>(&inputs[1])) {
    output_size_h = (*output_sizes)[0];
    output_size_w = (*output_sizes)[1];
  }

  double scale_factor_h = -1.f;
  double scale_factor_w = -1.f;
  if (auto scale_factors = std::get_if<DoubleList>(&inputs[2])) {
    scale_factor_h = (*scale_factors)[0];
    scale_factor_w = (*scale_factors)[1];
  }
  const BufHandle& x = std::get<BufHandle>(inputs[0]);
  double qx_qscale = -1.f;
  int64_t qx_qzero = -1l;
  int64_t qx_qdtype = -1l;
  if (isQuantized(x)) {
    qx_qscale = immQScale(x);
    qx_qzero = immQZero(x);
    qx_qdtype = (int64_t)immQDType(x);
  }

  BufHandle ResultBuf = [&]() {
    if (isQuantized(x)) {
      return makeQBufHandleChannelsLast(
          "upsample_nearest2d",
          outputShape,
          Dtype(immQDType(x)),
          qx_qscale,
          qx_qzero);
    }
    return BufHandle("upsample_nearest2d", outputShape, dtype);
  }();

  StmtPtr s = ExternalCall::make(
      ResultBuf,
      "nnc_aten_upsample_nearest2d",
      {x},
      {qx_qscale,
       qx_qzero,
       qx_qdtype,
       output_size_h,
       output_size_w,
       scale_factor_h,
       scale_factor_w});
  return Tensor(ResultBuf.node(), s);
}

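// quantized::sigmoid via external call. The output quantization parameters
// are fixed: scale 1/256 with zero point -128 (QInt8) or 0 (QUInt8), which
// maps the sigmoid's [0, 1) range onto the full quantized domain.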
Tensor computeQuantizedSigmoidExternalCall(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device) {
  const BufHandle& qx = std::get<BufHandle>(inputs[0]);

  const auto out_qdtype = immQDType(qx);
  const double out_qscale = 1.0f / 256.0f;
  const int64_t out_qzero = (out_qdtype == ScalarType::QInt8) ? -128 : 0;

  auto ResultBuf = isChannelsLast(qx) ? makeQBufHandleChannelsLast(
                                            "quantized_sigmoid",
                                            outputShape,
                                            Dtype(out_qdtype),
                                            out_qscale,
                                            out_qzero)
                                      : makeQBufHandleContiguous(
                                            "quantized_sigmoid",
                                            outputShape,
                                            Dtype(out_qdtype),
                                            out_qscale,
                                            out_qzero);
  StmtPtr s = ExternalCall::make(
      ResultBuf,
      "nnc_aten_quantized_sigmoid",
      {qx},
      {immQScale(qx),
       immQZero(qx),
       (int64_t)immQDType(qx),
       out_qscale,
       out_qzero});
  return Tensor(ResultBuf.node(), s);
}

} // namespace torch::jit::tensorexpr