1 #include <c10/core/ScalarType.h>
2 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
3 #include <torch/csrc/jit/tensorexpr/operators/misc.h>
4 #include <torch/csrc/jit/tensorexpr/operators/pointwise.h>
5 #include <torch/csrc/jit/tensorexpr/operators/quantization.h>
6
7 using namespace torch::jit::tensorexpr;
8
9 namespace torch::jit::tensorexpr {
10 namespace {
_pair_int(ArgValue v)11 std::vector<int64_t> _pair_int(ArgValue v) {
12 if (auto t = std::get_if<IntList>(&v)) {
13 return {(*t)[0], (*t)[1]};
14 }
15 auto i = std::get<int64_t>(v);
16 return {i, i};
17 }
18 } // namespace
19
immQScale(const BufHandle & qx)20 double immQScale(const BufHandle& qx) {
21 TORCH_INTERNAL_ASSERT(
22 qx.node()->qscale(), buildErrorMessage("Expects BufHandle with qscale"));
23 return to<DoubleImm>(IRSimplifier::simplify(qx.node()->qscale()))->value();
24 }
25
immQZero(const BufHandle & qx)26 int64_t immQZero(const BufHandle& qx) {
27 TORCH_INTERNAL_ASSERT(
28 qx.node()->qzero(), buildErrorMessage("Expects BufHandle with qzero"));
29 return to<LongImm>(IRSimplifier::simplify(qx.node()->qzero()))->value();
30 }
31
immQDType(const BufHandle & qx)32 ScalarType immQDType(const BufHandle& qx) {
33 return qx.dtype().scalar_type();
34 }
35
isQuantized(const BufHandle & qx)36 bool isQuantized(const BufHandle& qx) {
37 return qx.node()->qscale() && qx.node()->qzero();
38 }
39
makeQBufHandleChannelsLast(const std::string & name,const std::vector<ExprHandle> & dims,Dtype dtype,const ExprPtr & qscale,const ExprPtr & qzero)40 static BufHandle makeQBufHandleChannelsLast(
41 const std::string& name,
42 const std::vector<ExprHandle>& dims,
43 Dtype dtype,
44 const ExprPtr& qscale,
45 const ExprPtr& qzero) {
46 BufHandle ResultBuf(name, dims, dtype);
47 ResultBuf.node()->set_qscale(qscale);
48 ResultBuf.node()->set_qzero(qzero);
49 ResultBuf.node()->set_strides(make_channels_last_strides(dims));
50 return ResultBuf;
51 }
52
makeQBufHandleChannelsLast(const std::string & name,const std::vector<ExprHandle> & dims,Dtype dtype,const double qscale,const int64_t qzero)53 static BufHandle makeQBufHandleChannelsLast(
54 const std::string& name,
55 const std::vector<ExprHandle>& dims,
56 Dtype dtype,
57 const double qscale,
58 const int64_t qzero) {
59 return makeQBufHandleChannelsLast(
60 name,
61 dims,
62 dtype,
63 DoubleImm::make(qscale).node(),
64 LongImm::make(qzero).node());
65 }
66
makeQBufHandleContiguous(const std::string & name,const std::vector<ExprHandle> & dims,Dtype dtype,const ExprPtr & qscale,const ExprPtr & qzero)67 static BufHandle makeQBufHandleContiguous(
68 const std::string& name,
69 const std::vector<ExprHandle>& dims,
70 Dtype dtype,
71 const ExprPtr& qscale,
72 const ExprPtr& qzero) {
73 BufHandle ResultBuf(name, dims, dtype);
74 ResultBuf.node()->set_qscale(qscale);
75 ResultBuf.node()->set_qzero(qzero);
76 ResultBuf.node()->set_strides(make_contiguous_strides(dims));
77 return ResultBuf;
78 }
79
makeQBufHandleContiguous(const std::string & name,const std::vector<ExprHandle> & dims,Dtype dtype,const double qscale,const int64_t qzero)80 static BufHandle makeQBufHandleContiguous(
81 const std::string& name,
82 const std::vector<ExprHandle>& dims,
83 Dtype dtype,
84 const double qscale,
85 const int64_t qzero) {
86 return makeQBufHandleContiguous(
87 name,
88 dims,
89 dtype,
90 DoubleImm::make(qscale).node(),
91 LongImm::make(qzero).node());
92 }
93
isChannelsLast(const BufHandle & buf)94 static bool isChannelsLast(const BufHandle& buf) {
95 const auto& strides = buf.node()->strides();
96 const auto& dims = buf.node()->dims();
97 const auto rank = dims.size();
98 if (rank < 3) {
99 return false;
100 }
101 auto dimsC = to<LongImm>(IRSimplifier::simplify(dims[1]))->value();
102 auto stridesC = to<LongImm>(IRSimplifier::simplify(strides[1]))->value();
103 auto stridesLast =
104 to<LongImm>(IRSimplifier::simplify(strides[rank - 1]))->value();
105
106 return ((stridesLast == dimsC) && (stridesC == 1));
107 }
108
// Builds IR for affine quantization:
//   q = cast<out_dtype>(x / qscale + qzero + 0.5f)
// qscale/qzero are first promoted to x's dtype so the arithmetic is
// well-typed; the +0.5f nudges the final truncating cast toward
// round-half-up behavior.
static ExprHandle quant(
    const ExprHandle& x,
    Dtype out_dtype,
    ExprHandle qscale,
    ExprHandle qzero) {
  const auto in_type = x.dtype().scalar_type();
  auto scale = promoteToDtype(std::move(qscale), in_type);
  auto zero = promoteToDtype(std::move(qzero), in_type);
  return promoteToDtype(
      x / scale + zero + FloatImm::make(0.5f), out_dtype.scalar_type());
}
122
dequant(ExprHandle qx,Dtype out_dtype,ExprHandle qscale,ExprHandle qzero)123 static ExprHandle dequant(
124 ExprHandle qx,
125 Dtype out_dtype,
126 ExprHandle qscale,
127 ExprHandle qzero) {
128 auto qx_promoted = promoteToDtype(std::move(qx), out_dtype.scalar_type());
129 auto qscale_promoted =
130 promoteToDtype(std::move(qscale), out_dtype.scalar_type());
131 auto qzero_promoted =
132 promoteToDtype(std::move(qzero), out_dtype.scalar_type());
133 return promoteToDtype(
134 (qx_promoted - qzero_promoted) * qscale_promoted,
135 out_dtype.scalar_type());
136 }
137
computeQuantizePerTensor(const std::vector<ArgValue> & inputs,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> &,at::Device)138 Tensor computeQuantizePerTensor(
139 const std::vector<ArgValue>& inputs,
140 const std::vector<ExprHandle>& outputShape,
141 const std::vector<ExprHandle>& outputStrides,
142 const std::optional<ScalarType>&,
143 at::Device) {
144 std::vector<VarPtr> vars;
145 std::vector<ExprHandle> indices;
146 for (const auto& os : outputShape) {
147 auto var = alloc<Var>("", os.node()->dtype());
148 vars.push_back(var);
149 indices.push_back(VarHandle(var));
150 }
151
152 ExprHandle qscale = constant(inputs[1]);
153 ExprHandle qzero = constant(inputs[2]);
154
155 const auto dtype = [](auto qdtype) {
156 if (static_cast<int64_t>(ScalarType::QInt8) == qdtype) {
157 return Dtype(ScalarType::QInt8);
158 } else if (static_cast<int64_t>(ScalarType::QUInt8) == qdtype) {
159 return Dtype(ScalarType::QUInt8);
160 }
161 throw malformed_input("Expected quantized dtype");
162 }(std::get<int64_t>(inputs[3]));
163
164 ExprHandle e =
165 quant(tensorOrConstant(inputs[0], indices), dtype, qscale, qzero);
166
167 BufPtr buf = alloc<Buf>(
168 "quantize_per_tensor",
169 ExprHandleVectorToExprVector(outputShape),
170 dtype,
171 nullptr,
172 std::nullopt,
173 qscale.node(),
174 qzero.node());
175 return Tensor(buf, vars, e.node());
176 }
177
computeQuantizedAdd(const std::vector<ArgValue> & inputs,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,at::Device)178 Tensor computeQuantizedAdd(
179 const std::vector<ArgValue>& inputs,
180 const std::vector<ExprHandle>& outputShape,
181 const std::vector<ExprHandle>& outputStrides,
182 const std::optional<ScalarType>& outputType,
183 at::Device) {
184 const BufHandle& QA = std::get<BufHandle>(inputs[0]);
185 const BufHandle& QB = std::get<BufHandle>(inputs[1]);
186 auto qa_scale = ExprHandle(QA.node()->qscale());
187 auto qa_zero = ExprHandle(QA.node()->qzero());
188 auto qb_scale = ExprHandle(QB.node()->qscale());
189 auto qb_zero = ExprHandle(QB.node()->qzero());
190 ExprHandle out_qscale = DoubleImm::make(std::get<double>(inputs[2]));
191 ExprHandle out_qzero = LongImm::make(std::get<int64_t>(inputs[3]));
192 Dtype dequant_dtype = kFloat;
193 Dtype out_dtype = outputType ? Dtype(*outputType) : QA.dtype();
194 std::vector<VarPtr> vars;
195 std::vector<ExprHandle> indices;
196 for (const auto& os : outputShape) {
197 auto var = alloc<Var>("", os.node()->dtype());
198 vars.push_back(var);
199 indices.push_back(VarHandle(var));
200 }
201 auto lhs = tensorOrConstant(inputs[0], indices);
202 auto rhs = tensorOrConstant(inputs[1], indices);
203 ExprHandle exprHandle = quant(
204 dequant(lhs, dequant_dtype, qa_scale, qa_zero) +
205 dequant(rhs, dequant_dtype, qb_scale, qb_zero),
206 out_dtype,
207 out_qscale,
208 out_qzero);
209 BufPtr buf = alloc<Buf>(
210 "quantized_add",
211 ExprHandleVectorToExprVector(outputShape),
212 out_dtype,
213 nullptr,
214 isChannelsLast(QA) ? make_channels_last_strides(outputShape)
215 : make_contiguous_strides(outputShape),
216 out_qscale.node(),
217 out_qzero.node());
218 return Tensor(buf, vars, exprHandle.node());
219 }
220
computeQuantizePerTensorExternalCall(const std::vector<ArgValue> & inputs,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,at::Device)221 Tensor computeQuantizePerTensorExternalCall(
222 const std::vector<ArgValue>& inputs,
223 const std::vector<ExprHandle>& outputShape,
224 const std::vector<ExprHandle>& outputStrides,
225 const std::optional<ScalarType>& outputType,
226 at::Device) {
227 const BufHandle& x = std::get<BufHandle>(inputs[0]);
228 const auto qscale = std::get<double>(inputs[1]);
229 const auto qzero = std::get<int64_t>(inputs[2]);
230 const auto qdtype = std::get<int64_t>(inputs[3]);
231
232 const auto dtype = [](auto qdtype) {
233 if (static_cast<int64_t>(ScalarType::QInt8) == qdtype) {
234 return Dtype(ScalarType::QInt8);
235 } else if (static_cast<int64_t>(ScalarType::QUInt8) == qdtype) {
236 return Dtype(ScalarType::QUInt8);
237 }
238 throw malformed_input("Expected quantized dtype");
239 }(qdtype);
240 auto ResultBuf = [&]() {
241 if (isChannelsLast(x)) {
242 return makeQBufHandleChannelsLast(
243 "quantize_per_tensor", outputShape, dtype, qscale, qzero);
244 }
245 return makeQBufHandleContiguous(
246 "quantize_per_tensor", outputShape, dtype, qscale, qzero);
247 }();
248 StmtPtr s = ExternalCall::make(
249 ResultBuf, "nnc_aten_quantize_per_tensor", {x}, {qscale, qzero, qdtype});
250 return Tensor(ResultBuf.node(), s);
251 }
252
computeDequantizeExternalCall(const std::vector<ArgValue> & inputs,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,at::Device)253 Tensor computeDequantizeExternalCall(
254 const std::vector<ArgValue>& inputs,
255 const std::vector<ExprHandle>& outputShape,
256 const std::vector<ExprHandle>& outputStrides,
257 const std::optional<ScalarType>& outputType,
258 at::Device) {
259 Dtype dtype = kFloat;
260 if (outputType) {
261 dtype = Dtype(*outputType);
262 }
263
264 const BufHandle& qx = std::get<BufHandle>(inputs[0]);
265 const int64_t qdtype = (int64_t)immQDType(qx);
266
267 BufHandle ResultBuf("dequantize", outputShape, dtype);
268 StmtPtr s = ExternalCall::make(
269 ResultBuf,
270 "nnc_aten_dequantize",
271 {qx},
272 {ExprHandle(IRSimplifier::simplify(qx.node()->qscale())),
273 ExprHandle(IRSimplifier::simplify(qx.node()->qzero())),
274 qdtype});
275 return Tensor(ResultBuf.node(), s);
276 }
277
computeQuantizedConv2dPrepack(const std::vector<ArgValue> & inputs,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,at::Device)278 Tensor computeQuantizedConv2dPrepack(
279 const std::vector<ArgValue>& inputs,
280 const std::vector<ExprHandle>& outputShape,
281 const std::vector<ExprHandle>& outputStrides,
282 const std::optional<ScalarType>& outputType,
283 at::Device) {
284 Dtype dtype = kFloat;
285 if (outputType) {
286 dtype = Dtype(*outputType);
287 }
288
289 BufHandle ResultBuf("quantized_conv2d_prepack", outputShape, dtype);
290 const BufHandle& qw = std::get<BufHandle>(inputs[0]);
291 const BufHandle& b = std::get<BufHandle>(inputs[1]);
292 auto strides = _pair_int(inputs[2]);
293 auto padding = _pair_int(inputs[3]);
294 auto dilation = _pair_int(inputs[4]);
295 auto groups = std::get<int64_t>(inputs[5]);
296 TORCH_INTERNAL_ASSERT(
297 qw.node()->qscale(),
298 buildErrorMessage(
299 "quantized_conv2d_prepack: Expects quantized weights, qscale is missing"));
300 TORCH_INTERNAL_ASSERT(
301 qw.node()->qzero(),
302 buildErrorMessage(
303 "quantized_conv2d_prepack: Expects quantized weights, qzero is missing"));
304 StmtPtr s = ExternalCall::make(
305 ResultBuf,
306 "nnc_aten_quantized_conv2d_prepack",
307 {qw, b},
308 {strides[0],
309 strides[1],
310 padding[0],
311 padding[1],
312 dilation[0],
313 dilation[1],
314 groups,
315 immQScale(qw),
316 immQZero(qw),
317 (int64_t)immQDType(qw)});
318 return Tensor(ResultBuf.node(), s);
319 }
320
computeQuantizedConv1d(const std::vector<ArgValue> & inputs,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,at::Device device)321 Tensor computeQuantizedConv1d(
322 const std::vector<ArgValue>& inputs,
323 const std::vector<ExprHandle>& outputShape,
324 const std::vector<ExprHandle>& outputStrides,
325 const std::optional<ScalarType>& outputType,
326 at::Device device) {
327 const BufHandle& qx = std::get<BufHandle>(inputs[0]);
328 const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
329 const auto out_qscale = std::get<double>(inputs[2]);
330 const auto out_qzero = std::get<int64_t>(inputs[3]);
331 // Change to dtype based on outputType when dtype propagation implemented
332 const auto out_qdtype = immQDType(qx);
333 auto ResultBuf = makeQBufHandleChannelsLast(
334 "quantized_conv1d",
335 outputShape,
336 Dtype(out_qdtype),
337 out_qscale,
338 out_qzero);
339 StmtPtr s = ExternalCall::make(
340 ResultBuf,
341 "nnc_aten_quantized_conv1d",
342 {qx, prepacked},
343 {immQScale(qx),
344 immQZero(qx),
345 (int64_t)immQDType(qx),
346 out_qscale,
347 out_qzero});
348 return Tensor(ResultBuf.node(), s);
349 }
350
computeQuantizedConv2d(const std::vector<ArgValue> & inputs,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,at::Device device)351 Tensor computeQuantizedConv2d(
352 const std::vector<ArgValue>& inputs,
353 const std::vector<ExprHandle>& outputShape,
354 const std::vector<ExprHandle>& outputStrides,
355 const std::optional<ScalarType>& outputType,
356 at::Device device) {
357 const BufHandle& qx = std::get<BufHandle>(inputs[0]);
358 const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
359 const auto out_qscale = std::get<double>(inputs[2]);
360 const auto out_qzero = std::get<int64_t>(inputs[3]);
361 // Change to dtype based on outputType when dtype propagation implemented
362 const auto out_qdtype = immQDType(qx);
363 auto ResultBuf = makeQBufHandleChannelsLast(
364 "quantized_conv2d",
365 outputShape,
366 Dtype(out_qdtype),
367 out_qscale,
368 out_qzero);
369 StmtPtr s = ExternalCall::make(
370 ResultBuf,
371 "nnc_aten_quantized_conv2d",
372 {qx, prepacked},
373 {immQScale(qx),
374 immQZero(qx),
375 (int64_t)immQDType(qx),
376 out_qscale,
377 out_qzero});
378 return Tensor(ResultBuf.node(), s);
379 }
380
computeQuantizedConv2dRelu(const std::vector<ArgValue> & inputs,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,at::Device device)381 Tensor computeQuantizedConv2dRelu(
382 const std::vector<ArgValue>& inputs,
383 const std::vector<ExprHandle>& outputShape,
384 const std::vector<ExprHandle>& outputStrides,
385 const std::optional<ScalarType>& outputType,
386 at::Device device) {
387 const BufHandle& qx = std::get<BufHandle>(inputs[0]);
388 const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
389 const auto out_qscale = std::get<double>(inputs[2]);
390 const auto out_qzero = std::get<int64_t>(inputs[3]);
391 // Change to dtype based on outputType when dtype propagation implemented
392 const auto out_qdtype = immQDType(qx);
393 auto ResultBuf = makeQBufHandleChannelsLast(
394 "quantized_conv2d_relu",
395 outputShape,
396 Dtype(out_qdtype),
397 out_qscale,
398 out_qzero);
399 StmtPtr s = ExternalCall::make(
400 ResultBuf,
401 "nnc_aten_quantized_conv2d_relu",
402 {qx, prepacked},
403 {immQScale(qx),
404 immQZero(qx),
405 (int64_t)immQDType(qx),
406 out_qscale,
407 out_qzero});
408 return Tensor(ResultBuf.node(), s);
409 }
410
computeQuantizedLinear(const std::vector<ArgValue> & inputs,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,at::Device device)411 Tensor computeQuantizedLinear(
412 const std::vector<ArgValue>& inputs,
413 const std::vector<ExprHandle>& outputShape,
414 const std::vector<ExprHandle>& outputStrides,
415 const std::optional<ScalarType>& outputType,
416 at::Device device) {
417 const BufHandle& qx = std::get<BufHandle>(inputs[0]);
418 const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
419 const auto out_qscale = std::get<double>(inputs[2]);
420 const auto out_qzero = std::get<int64_t>(inputs[3]);
421 // Change to dtype based on outputType when dtype propagation implemented
422 const auto out_qdtype = immQDType(qx);
423 auto ResultBuf = makeQBufHandleContiguous(
424 "quantized_linear",
425 outputShape,
426 Dtype(out_qdtype),
427 out_qscale,
428 out_qzero);
429 StmtPtr s = ExternalCall::make(
430 ResultBuf,
431 "nnc_aten_quantized_linear",
432 {qx, prepacked},
433 {immQScale(qx),
434 immQZero(qx),
435 (int64_t)immQDType(qx),
436 out_qscale,
437 out_qzero});
438 return Tensor(ResultBuf.node(), s);
439 }
440
computeQuantizedLinearRelu(const std::vector<ArgValue> & inputs,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,at::Device device)441 Tensor computeQuantizedLinearRelu(
442 const std::vector<ArgValue>& inputs,
443 const std::vector<ExprHandle>& outputShape,
444 const std::vector<ExprHandle>& outputStrides,
445 const std::optional<ScalarType>& outputType,
446 at::Device device) {
447 const BufHandle& qx = std::get<BufHandle>(inputs[0]);
448 const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
449 const auto out_qscale = std::get<double>(inputs[2]);
450 const auto out_qzero = std::get<int64_t>(inputs[3]);
451 // Change to dtype based on outputType when dtype propagation implemented
452 const auto out_qdtype = immQDType(qx);
453 auto ResultBuf = makeQBufHandleContiguous(
454 "quantized_linear_relu",
455 outputShape,
456 Dtype(out_qdtype),
457 out_qscale,
458 out_qzero);
459 StmtPtr s = ExternalCall::make(
460 ResultBuf,
461 "nnc_aten_quantized_linear_relu",
462 {qx, prepacked},
463 {immQScale(qx),
464 immQZero(qx),
465 (int64_t)immQDType(qx),
466 out_qscale,
467 out_qzero});
468 return Tensor(ResultBuf.node(), s);
469 }
470
computeQuantizedAddExternalCall(const std::vector<ArgValue> & inputs,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,at::Device device)471 Tensor computeQuantizedAddExternalCall(
472 const std::vector<ArgValue>& inputs,
473 const std::vector<ExprHandle>& outputShape,
474 const std::vector<ExprHandle>& outputStrides,
475 const std::optional<ScalarType>& outputType,
476 at::Device device) {
477 const BufHandle& qa = std::get<BufHandle>(inputs[0]);
478 const BufHandle& qb = std::get<BufHandle>(inputs[1]);
479 const auto out_qscale = std::get<double>(inputs[2]);
480 const auto out_qzero = std::get<int64_t>(inputs[3]);
481 // Change to dtype based on outputType when dtype propagation implemented
482 const auto out_qdtype = immQDType(qa);
483 const bool isQAChannelsLast = isChannelsLast(qa);
484 const bool isQBChannelsLast = isChannelsLast(qb);
485 auto ResultBuf = (isQAChannelsLast || isQBChannelsLast)
486 ? makeQBufHandleChannelsLast(
487 "quantized_add",
488 outputShape,
489 Dtype(out_qdtype),
490 out_qscale,
491 out_qzero)
492 : makeQBufHandleContiguous(
493 "quantized_add",
494 outputShape,
495 Dtype(out_qdtype),
496 out_qscale,
497 out_qzero);
498 StmtPtr s = ExternalCall::make(
499 ResultBuf,
500 "nnc_aten_quantized_add",
501 {qa, qb},
502 {immQScale(qa),
503 immQZero(qa),
504 (int64_t)immQDType(qa),
505 immQScale(qb),
506 immQZero(qb),
507 (int64_t)immQDType(qb),
508 out_qscale,
509 out_qzero});
510 return Tensor(ResultBuf.node(), s);
511 }
512
computeQuantizedMul(const std::vector<ArgValue> & inputs,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,at::Device device)513 Tensor computeQuantizedMul(
514 const std::vector<ArgValue>& inputs,
515 const std::vector<ExprHandle>& outputShape,
516 const std::vector<ExprHandle>& outputStrides,
517 const std::optional<ScalarType>& outputType,
518 at::Device device) {
519 const BufHandle& qa = std::get<BufHandle>(inputs[0]);
520 const BufHandle& qb = std::get<BufHandle>(inputs[1]);
521 const auto out_qscale = std::get<double>(inputs[2]);
522 const auto out_qzero = std::get<int64_t>(inputs[3]);
523 // Change to dtype based on outputType when dtype propagation implemented
524 const auto out_qdtype = immQDType(qa);
525 auto ResultBuf = makeQBufHandleContiguous(
526 "quantized_mul", outputShape, Dtype(out_qdtype), out_qscale, out_qzero);
527 StmtPtr s = ExternalCall::make(
528 ResultBuf,
529 "nnc_aten_quantized_mul",
530 {qa, qb},
531 {immQScale(qa),
532 immQZero(qa),
533 (int64_t)immQDType(qa),
534 immQScale(qb),
535 immQZero(qb),
536 (int64_t)immQDType(qb),
537 out_qscale,
538 out_qzero});
539 return Tensor(ResultBuf.node(), s);
540 }
541
computeQuantizedMulScalar(const std::vector<ArgValue> & inputs,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,at::Device device)542 Tensor computeQuantizedMulScalar(
543 const std::vector<ArgValue>& inputs,
544 const std::vector<ExprHandle>& outputShape,
545 const std::vector<ExprHandle>& outputStrides,
546 // NOLINTNEXTLINE
547 const std::optional<ScalarType>& outputType,
548 at::Device device) {
549 const BufHandle& qa = std::get<BufHandle>(inputs[0]);
550 const auto scalar = std::get<double>(inputs[1]);
551 // Change to dtype based on outputType when dtype propagation implemented
552 const auto out_qdtype = immQDType(qa);
553 double scale1 = immQScale(qa);
554 auto ResultBuf = makeQBufHandleContiguous(
555 "quantized_mul_scalar",
556 outputShape,
557 Dtype(out_qdtype),
558 scale1 * scalar,
559 immQZero(qa));
560 StmtPtr s = ExternalCall::make(
561 ResultBuf,
562 "nnc_aten_quantized_mul_scalar",
563 {qa},
564 {scale1, immQZero(qa), (int64_t)immQDType(qa), scalar});
565 return Tensor(ResultBuf.node(), s);
566 }
567
computeQuantizedRelu(const std::vector<ArgValue> & inputs,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,at::Device device)568 Tensor computeQuantizedRelu(
569 const std::vector<ArgValue>& inputs,
570 const std::vector<ExprHandle>& outputShape,
571 const std::vector<ExprHandle>& outputStrides,
572 const std::optional<ScalarType>& outputType,
573 at::Device device) {
574 const BufHandle& qa = std::get<BufHandle>(inputs[0]);
575 const auto out_qdtype = immQDType(qa);
576 const bool isQAChannelsLast = isChannelsLast(qa);
577 auto ResultBuf = isQAChannelsLast ? makeQBufHandleChannelsLast(
578 "quantized_relu",
579 outputShape,
580 Dtype(out_qdtype),
581 immQScale(qa),
582 immQZero(qa))
583 : makeQBufHandleContiguous(
584 "quantized_relu",
585 outputShape,
586 Dtype(out_qdtype),
587 immQScale(qa),
588 immQZero(qa));
589 StmtPtr s = ExternalCall::make(
590 ResultBuf,
591 "nnc_aten_quantized_relu",
592 {qa},
593 {immQScale(qa), immQZero(qa), (int64_t)immQDType(qa)});
594 return Tensor(ResultBuf.node(), s);
595 }
596
computeQuantizedCat(const std::vector<ArgValue> & inputs,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,at::Device device)597 Tensor computeQuantizedCat(
598 const std::vector<ArgValue>& inputs,
599 const std::vector<ExprHandle>& outputShape,
600 const std::vector<ExprHandle>& outputStrides,
601 // NOLINTNEXTLINE
602 const std::optional<ScalarType>& outputType,
603 // NOLINTNEXTLINE
604 at::Device device) {
605 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
606 auto inputList = std::get<BufList>(inputs[0]);
607 auto argDim = std::get<int64_t>(inputs[1]);
608 auto n = inputList.size();
609 // TODO: handle optional out_qscale, out_qzero
610 const auto out_qscale = std::get<double>(inputs[2]);
611 const auto out_qzero = std::get<int64_t>(inputs[3]);
612
613 std::vector<BufHandle> args;
614 std::vector<ExprHandle> extra_args;
615 for (const auto i : c10::irange(n)) {
616 const BufHandle& bh = inputList[i];
617 args.emplace_back(bh);
618 extra_args.emplace_back(immQScale(bh));
619 extra_args.emplace_back(immQZero(bh));
620 extra_args.emplace_back((int64_t)immQDType(bh));
621 }
622 extra_args.emplace_back(argDim);
623 extra_args.emplace_back(out_qscale);
624 extra_args.emplace_back(out_qzero);
625 auto ResultBuf = makeQBufHandleContiguous(
626 "quantized_cat",
627 outputShape,
628 Dtype(immQDType(inputList[0])),
629 out_qscale,
630 out_qzero);
631 StmtPtr s =
632 ExternalCall::make(ResultBuf, "nnc_aten_quantized_cat", args, extra_args);
633 return Tensor(ResultBuf.node(), s);
634 }
635
computeDequantize(const std::vector<ArgValue> & inputs,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,at::Device)636 Tensor computeDequantize(
637 const std::vector<ArgValue>& inputs,
638 const std::vector<ExprHandle>& outputShape,
639 const std::vector<ExprHandle>& outputStrides,
640 const std::optional<ScalarType>& outputType,
641 at::Device) {
642 Dtype dtype = kFloat;
643 if (outputType) {
644 dtype = Dtype(*outputType);
645 }
646 auto qx = std::get<BufHandle>(inputs[0]);
647 TORCH_INTERNAL_ASSERT(
648 qx.node()->qscale(),
649 buildErrorMessage("Missing quantized scale for dequantize"));
650 TORCH_INTERNAL_ASSERT(
651 qx.node()->qzero(),
652 buildErrorMessage("Missing quantized zero point for dequantize"));
653 auto qscale = ExprHandle(qx.node()->qscale());
654 auto qzero = ExprHandle(qx.node()->qzero());
655 std::vector<VarPtr> vars;
656 std::vector<ExprHandle> indices;
657 for (const auto& os : outputShape) {
658 auto var = alloc<Var>("", os.node()->dtype());
659 vars.push_back(var);
660 indices.push_back(VarHandle(var));
661 }
662 auto y = dequant(tensorOrConstant(inputs[0], indices), dtype, qscale, qzero);
663 BufPtr buf = alloc<Buf>(
664 "dequantize", ExprHandleVectorToExprVector(outputShape), dtype);
665 return Tensor(buf, vars, y.node());
666 }
667
computeUpsampleNearest2d(const std::vector<ArgValue> & inputs,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,at::Device)668 Tensor computeUpsampleNearest2d(
669 const std::vector<ArgValue>& inputs,
670 const std::vector<ExprHandle>& outputShape,
671 const std::vector<ExprHandle>& outputStrides,
672 const std::optional<ScalarType>& outputType,
673 at::Device) {
674 auto A = std::get<BufHandle>(inputs[0]);
675 const auto& output_height = outputShape[2];
676 const auto& output_width = outputShape[3];
677 auto input_height = ExprHandle(A.dim(2));
678 auto input_width = ExprHandle(A.dim(3));
679
680 std::vector<VarHandle> args = create_index_vars(outputShape);
681 // Handle separately when scale is specified? as in 'scalar_t
682 // compute_scales_value' in UpSample.h
683 auto scale_h =
684 promoteToDtype(input_height, ScalarType::Double) / output_height;
685 auto scale_w = promoteToDtype(input_width, ScalarType::Double) / output_width;
686 // TODO: will repetitive if in idx calculation will be taken out of the loop?
687 auto compute_nearest_idx = [](const ExprHandle& scale,
688 const ExprHandle& dst_index,
689 const ExprHandle& input_size) {
690 return Min::make(
691 promoteToDtype(floor(dst_index * scale), ScalarType::Long),
692 input_size - 1,
693 true);
694 };
695 auto body_func = [&](std::vector<VarHandle> axes) {
696 std::vector<ExprHandle> newAxes(axes.begin(), axes.end());
697 newAxes[2] = compute_nearest_idx(scale_h, axes[2], input_height);
698 newAxes[3] = compute_nearest_idx(scale_w, axes[3], input_width);
699 return A.load(newAxes);
700 };
701 auto e = body_func(args);
702 auto strides = isChannelsLast(A) ? make_channels_last_strides(outputShape)
703 : make_contiguous_strides(outputShape);
704 BufHandle buf = Buf::make(
705 "upsample_nearest2d",
706 outputShape,
707 Dtype(*outputType),
708 std::nullopt, // initializer
709 fmap(strides, [&](const ExprPtr& stride) { return ExprHandle(stride); }),
710 ExprHandle(A.node()->qscale()),
711 ExprHandle(A.node()->qzero()));
712 return Tensor(buf, args, e);
713 }
714
computeUpsampleNearest2dExternalCall(const std::vector<ArgValue> & inputs,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,at::Device)715 Tensor computeUpsampleNearest2dExternalCall(
716 const std::vector<ArgValue>& inputs,
717 const std::vector<ExprHandle>& outputShape,
718 const std::vector<ExprHandle>& outputStrides,
719 const std::optional<ScalarType>& outputType,
720 at::Device) {
721 Dtype dtype = kFloat;
722 if (outputType) {
723 dtype = Dtype(*outputType);
724 }
725 int64_t output_size_h = -1;
726 int64_t output_size_w = -1;
727 if (auto output_sizes = std::get_if<IntList>(&inputs[1])) {
728 output_size_h = (*output_sizes)[0];
729 output_size_w = (*output_sizes)[1];
730 }
731
732 double scale_factor_h = -1.f;
733 double scale_factor_w = -1.f;
734 if (auto scale_factors = std::get_if<DoubleList>(&inputs[2])) {
735 scale_factor_h = (*scale_factors)[0];
736 scale_factor_w = (*scale_factors)[1];
737 }
738 const BufHandle& x = std::get<BufHandle>(inputs[0]);
739 double qx_qscale = -1.f;
740 int64_t qx_qzero = -1l;
741 int64_t qx_qdtype = -1l;
742 if (isQuantized(x)) {
743 qx_qscale = immQScale(x);
744 qx_qzero = immQZero(x);
745 qx_qdtype = (int64_t)immQDType(x);
746 }
747
748 BufHandle ResultBuf = [&]() {
749 if (isQuantized(x)) {
750 return makeQBufHandleChannelsLast(
751 "upsample_nearest2d",
752 outputShape,
753 Dtype(immQDType(x)),
754 qx_qscale,
755 qx_qzero);
756 }
757 return BufHandle("upsample_nearest2d", outputShape, dtype);
758 }();
759
760 StmtPtr s = ExternalCall::make(
761 ResultBuf,
762 "nnc_aten_upsample_nearest2d",
763 {x},
764 {qx_qscale,
765 qx_qzero,
766 qx_qdtype,
767 output_size_h,
768 output_size_w,
769 scale_factor_h,
770 scale_factor_w});
771 return Tensor(ResultBuf.node(), s);
772 }
773
computeQuantizedSigmoidExternalCall(const std::vector<ArgValue> & inputs,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,at::Device)774 Tensor computeQuantizedSigmoidExternalCall(
775 const std::vector<ArgValue>& inputs,
776 const std::vector<ExprHandle>& outputShape,
777 const std::vector<ExprHandle>& outputStrides,
778 const std::optional<ScalarType>& outputType,
779 at::Device) {
780 const BufHandle& qx = std::get<BufHandle>(inputs[0]);
781
782 const auto out_qdtype = immQDType(qx);
783 const double out_qscale = 1.0f / 256.0f;
784 const int64_t out_qzero = (out_qdtype == ScalarType::QInt8) ? -128 : 0;
785
786 auto ResultBuf = isChannelsLast(qx) ? makeQBufHandleChannelsLast(
787 "quantized_sigmoid",
788 outputShape,
789 Dtype(out_qdtype),
790 out_qscale,
791 out_qzero)
792 : makeQBufHandleContiguous(
793 "quantized_sigmoid",
794 outputShape,
795 Dtype(out_qdtype),
796 out_qscale,
797 out_qzero);
798 StmtPtr s = ExternalCall::make(
799 ResultBuf,
800 "nnc_aten_quantized_sigmoid",
801 {qx},
802 {immQScale(qx),
803 immQZero(qx),
804 (int64_t)immQDType(qx),
805 out_qscale,
806 out_qzero});
807 return Tensor(ResultBuf.node(), s);
808 }
809
810 } // namespace torch::jit::tensorexpr
811