1 #include <torch/csrc/jit/frontend/function_schema_parser.h>
2 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
3 #include <torch/csrc/jit/tensorexpr/lowerings.h>
4 #include <torch/csrc/jit/tensorexpr/operators/operators.h>
5
6 #include <ATen/native/Activation.h>
7 #include <ATen/native/mkldnn/Common.h>
8
9 namespace torch::jit::tensorexpr {
10
getNNCLoweringRegistry()11 FunctionSchemaMap<NNCLoweringFunction>& getNNCLoweringRegistry() {
12 static FunctionSchemaMap<NNCLoweringFunction> lowering_registry_;
13 return lowering_registry_;
14 }
15
RegisterNNCLoweringsFunction(const std::vector<std::string> & schemas,const NNCLoweringFunction & fn)16 RegisterNNCLoweringsFunction::RegisterNNCLoweringsFunction(
17 const std::vector<std::string>& schemas,
18 const NNCLoweringFunction& fn) {
19 for (const auto& schema_str : schemas) {
20 getNNCLoweringRegistry().insert(parseSchema(schema_str), fn);
21 }
22 }
23
24 namespace {
25 // NOLINTNEXTLINE
nnc_lowerings_lazy_registration()26 int nnc_lowerings_lazy_registration() {
27 RegisterNNCLoweringsFunction aten_dropout(
28 {"aten::dropout(Tensor input, float p, bool train) -> (Tensor)"},
29 computeNoop);
30 RegisterNNCLoweringsFunction aten_contiguous(
31 {"aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> (Tensor(a))"},
32 computeNoop);
33
34 #ifdef USE_XNNPACK
35 // TODO: add a test
36 RegisterNNCLoweringsFunction prepacked_conv2d_clamp_run(
37 {"prepacked::conv2d_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.Conv2dOpContext W_prepack) -> (Tensor Y)"},
38 computePrepackedConv2dClampRun);
39
40 // TODO: add a test
41 RegisterNNCLoweringsFunction prepacked_linear_clamp_run(
42 {"prepacked::linear_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.LinearOpContext W_prepack) -> (Tensor Y)"},
43 computePrepackedLinearClampRun);
44 #endif
45
46 #if AT_MKLDNN_ENABLED()
47 RegisterNNCLoweringsFunction mkldnn_prepacked_conv2d_run(
48 {"mkldnn_prepacked::conv2d_run(Tensor X, __torch__.torch.classes.mkldnn.ConvOpContext W_prepack) -> (Tensor Y)"},
49 computeMkldnnPrepackedConvRun);
50 #endif // AT_MKLDNN_ENABLED()
51
52 RegisterNNCLoweringsFunction aten_sub(
53 {"aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)",
54 "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> (Tensor)"},
55 [](const std::vector<ArgValue>& inputs,
56 const std::vector<ExprHandle>& outputShape,
57 const std::vector<ExprHandle>& outputStrides,
58 const std::optional<ScalarType>& outputType,
59 at::Device device) {
60 auto sub_lambda = [](const ExprHandle& lhs, const ExprHandle& rhs) {
61 // NB: sub isn't supported on boolean, no need to promote to integer.
62 return lhs - rhs;
63 };
64 TORCH_INTERNAL_ASSERT(
65 inputs.size() == 2 || inputs.size() == 3,
66 buildErrorMessage("Invalid number of input operands"));
67 return (inputs.size() > 2) ? computeTwoOperandWithAlpha(
68 "aten_sub",
69 inputs,
70 outputShape,
71 outputStrides,
72 outputType,
73 sub_lambda)
74 : computeTwoOperand(
75 "aten_sub",
76 inputs,
77 outputShape,
78 outputStrides,
79 outputType,
80 sub_lambda);
81 });
82
83 RegisterNNCLoweringsFunction aten_mul(
84 {"aten::mul.Scalar(Tensor self, Scalar other) -> (Tensor)",
85 "aten::mul.Tensor(Tensor self, Tensor other) -> (Tensor)"},
86 [](const std::vector<ArgValue>& inputs,
87 const std::vector<ExprHandle>& outputShape,
88 const std::vector<ExprHandle>& outputStrides,
89 const std::optional<ScalarType>& outputType,
90 at::Device device) {
91 return computeTwoOperand(
92 "aten_mul",
93 inputs,
94 outputShape,
95 outputStrides,
96 outputType,
97 [](const ExprHandle& lhs, const ExprHandle& rhs) {
98 return boolToInteger(lhs) * boolToInteger(rhs);
99 });
100 });
101
102 #define DEFINE_BINARY_SCALAR_OP_LOWERING(op_name, op) \
103 RegisterNNCLoweringsFunction aten_##op_name##_scalar( \
104 {"aten::" #op_name ".int(int a, int b) -> (int)", \
105 "aten::" #op_name ".int_float(int a, float b) -> (float)", \
106 "aten::" #op_name ".float_int(float a, int b) -> (float)", \
107 "aten::" #op_name ".float(float a, float b) -> (float)"}, \
108 [](const std::vector<ArgValue>& inputs, \
109 const std::vector<ExprHandle>& outputShape, \
110 const std::vector<ExprHandle>& outputStrides, \
111 const std::optional<ScalarType>& outputType, \
112 at::Device device) { \
113 return computeScalar( \
114 "aten_#op_name", \
115 inputs, \
116 outputShape, \
117 outputStrides, \
118 outputType, \
119 [](const ExprHandle& a, const ExprHandle& b) { return op; }); \
120 });
121 DEFINE_BINARY_SCALAR_OP_LOWERING(mul, a * b)
122 DEFINE_BINARY_SCALAR_OP_LOWERING(add, a + b)
123 DEFINE_BINARY_SCALAR_OP_LOWERING(sub, a - b)
124 #undef DEFINE_BINARY_SCALAR_OP_LOWERING
125 RegisterNNCLoweringsFunction aten_div_scalar(
126 {"aten::div(Scalar a, Scalar b) -> (float)",
127 "aten::div.int(int a, int b) -> (float)",
128 "aten::div.int_float(int a, float b) -> (float)",
129 "aten::div.float_int(float a, int b) -> (float)",
130 "aten::div.float(float a, float b) -> (float)"},
131 [](const std::vector<ArgValue>& inputs,
132 const std::vector<ExprHandle>& outputShape,
133 const std::vector<ExprHandle>& outputStrides,
134 const std::optional<ScalarType>& outputType,
135 at::Device device) {
136 return computeScalar(
137 "aten_div",
138 inputs,
139 outputShape,
140 outputStrides,
141 outputType,
142 [](const ExprHandle& a, const ExprHandle& b) {
143 return promoteIntegerToDefaultType(a) /
144 promoteIntegerToDefaultType(b);
145 });
146 });
147
148 #define DEFINE_COMPARISON_SCALAR_OP_LOWERING(op_name, op) \
149 RegisterNNCLoweringsFunction aten_##op_name##_scalar( \
150 {"aten::" #op_name ".bool(bool a, bool b) -> (bool)", \
151 "aten::" #op_name ".int(int a, int b) -> (bool)", \
152 "aten::" #op_name ".int_float(int a, float b) -> (bool)", \
153 "aten::" #op_name ".float_int(float a, int b) -> (bool)", \
154 "aten::" #op_name ".float(float a, float b) -> (bool)"}, \
155 [](const std::vector<ArgValue>& inputs, \
156 const std::vector<ExprHandle>& outputShape, \
157 const std::vector<ExprHandle>& outputStrides, \
158 const std::optional<ScalarType>& outputType, \
159 at::Device device) { \
160 return computeScalar( \
161 "aten_#op_name", \
162 inputs, \
163 outputShape, \
164 outputStrides, \
165 outputType, \
166 [](const ExprHandle& a, const ExprHandle& b) { return op; }); \
167 });
168 DEFINE_COMPARISON_SCALAR_OP_LOWERING(lt, cast<bool>(a < b))
169 DEFINE_COMPARISON_SCALAR_OP_LOWERING(le, cast<bool>(a <= b))
170 DEFINE_COMPARISON_SCALAR_OP_LOWERING(eq, cast<bool>(a == b))
171 DEFINE_COMPARISON_SCALAR_OP_LOWERING(ne, cast<bool>(a != b))
172 DEFINE_COMPARISON_SCALAR_OP_LOWERING(gt, cast<bool>(a > b))
173 DEFINE_COMPARISON_SCALAR_OP_LOWERING(ge, cast<bool>(a >= b))
174 #undef DEFINE_COMPARISON_SCALAR_OP_LOWERING
175
176 #define DEFINE_BITWISE_SCALAR_OP_LOWERING(op_name, op) \
177 RegisterNNCLoweringsFunction aten_##op_name##_int_scalar( \
178 {"aten::" #op_name ".int(int a, int b) -> (int)"}, \
179 [](const std::vector<ArgValue>& inputs, \
180 const std::vector<ExprHandle>& outputShape, \
181 const std::vector<ExprHandle>& outputStrides, \
182 const std::optional<ScalarType>& outputType, \
183 at::Device device) { \
184 return computeScalar( \
185 "aten_#op_name", \
186 inputs, \
187 outputShape, \
188 outputStrides, \
189 outputType, \
190 [](const ExprHandle& a, const ExprHandle& b) { return op; }); \
191 });
192 DEFINE_BITWISE_SCALAR_OP_LOWERING(
193 __and__, boolToInteger(a) & boolToInteger(b))
194 DEFINE_BITWISE_SCALAR_OP_LOWERING(__or__, boolToInteger(a) | boolToInteger(b))
195 DEFINE_BITWISE_SCALAR_OP_LOWERING(
196 __xor__, boolToInteger(a) ^ boolToInteger(b))
197 DEFINE_BITWISE_SCALAR_OP_LOWERING(__lshift__, a << b)
198 DEFINE_BITWISE_SCALAR_OP_LOWERING(__rshift__, a >> b)
199 #undef DEFINE_BITWISE_SCALAR_OP_LOWERING
200
201 #define DEFINE_LOGICAL_SCALAR_OP_LOWERING(op_name, op) \
202 RegisterNNCLoweringsFunction aten_##op_name##_bool_scalar( \
203 {"aten::" #op_name ".bool(bool a, bool b) -> (bool)"}, \
204 [](const std::vector<ArgValue>& inputs, \
205 const std::vector<ExprHandle>& outputShape, \
206 const std::vector<ExprHandle>& outputStrides, \
207 const std::optional<ScalarType>& outputType, \
208 at::Device device) { \
209 return computeScalar( \
210 "aten_#op_name", \
211 inputs, \
212 outputShape, \
213 outputStrides, \
214 outputType, \
215 [](const ExprHandle& a, const ExprHandle& b) { return op; }); \
216 });
217 DEFINE_LOGICAL_SCALAR_OP_LOWERING(__and__, a && b)
218 DEFINE_LOGICAL_SCALAR_OP_LOWERING(__or__, a || b)
219 DEFINE_LOGICAL_SCALAR_OP_LOWERING(__xor__, a != b)
220 #undef DEFINE_LOGICAL_SCALAR_OP_LOWERING
221
222 RegisterNNCLoweringsFunction aten_div(
223 {"aten::div.Scalar(Tensor self, Scalar other) -> (Tensor)",
224 "aten::div.Tensor(Tensor self, Tensor other) -> (Tensor)"},
225 [](const std::vector<ArgValue>& inputs,
226 const std::vector<ExprHandle>& outputShape,
227 const std::vector<ExprHandle>& outputStrides,
228 const std::optional<ScalarType>& outputType,
229 at::Device device) {
230 return computeTwoOperand(
231 "aten_div",
232 inputs,
233 outputShape,
234 outputStrides,
235 outputType,
236 [](const ExprHandle& lhs, const ExprHandle& rhs) {
237 return promoteIntegerToDefaultType(lhs) /
238 promoteIntegerToDefaultType(rhs);
239 });
240 });
241
242 RegisterNNCLoweringsFunction aten___and__(
243 {"aten::__and__.Scalar(Tensor self, Scalar other) -> (Tensor)",
244 "aten::__and__.Tensor(Tensor self, Tensor other) -> (Tensor)"},
245 [](const std::vector<ArgValue>& inputs,
246 const std::vector<ExprHandle>& outputShape,
247 const std::vector<ExprHandle>& outputStrides,
248 const std::optional<ScalarType>& outputType,
249 at::Device device) {
250 return computeTwoOperand(
251 "aten_and",
252 inputs,
253 outputShape,
254 outputStrides,
255 outputType,
256 [](const ExprHandle& lhs, const ExprHandle& rhs) {
257 return boolToInteger(lhs) & boolToInteger(rhs);
258 });
259 });
260
261 RegisterNNCLoweringsFunction aten___or__(
262 {"aten::__or__.Scalar(Tensor self, Scalar other) -> (Tensor)",
263 "aten::__or__.Tensor(Tensor self, Tensor other) -> (Tensor)"},
264 [](const std::vector<ArgValue>& inputs,
265 const std::vector<ExprHandle>& outputShape,
266 const std::vector<ExprHandle>& outputStrides,
267 const std::optional<ScalarType>& outputType,
268 at::Device device) {
269 return computeTwoOperand(
270 "aten_or",
271 inputs,
272 outputShape,
273 outputStrides,
274 outputType,
275 [](const ExprHandle& lhs, const ExprHandle& rhs) {
276 return boolToInteger(lhs) | boolToInteger(rhs);
277 });
278 });
279
280 RegisterNNCLoweringsFunction aten___xor__(
281 {"aten::__xor__.Scalar(Tensor self, Scalar other) -> (Tensor)",
282 "aten::__xor__.Tensor(Tensor self, Tensor other) -> (Tensor)"},
283 [](const std::vector<ArgValue>& inputs,
284 const std::vector<ExprHandle>& outputShape,
285 const std::vector<ExprHandle>& outputStrides,
286 const std::optional<ScalarType>& outputType,
287 at::Device device) {
288 return computeTwoOperand(
289 "aten_xor",
290 inputs,
291 outputShape,
292 outputStrides,
293 outputType,
294 [](const ExprHandle& lhs, const ExprHandle& rhs) {
295 return boolToInteger(lhs) ^ boolToInteger(rhs);
296 });
297 });
298
299 RegisterNNCLoweringsFunction aten___lshift__(
300 {"aten::__lshift__.Scalar(Tensor self, Scalar other) -> (Tensor)",
301 "aten::__lshift__.Tensor(Tensor self, Tensor other) -> (Tensor)"},
302 [](const std::vector<ArgValue>& inputs,
303 const std::vector<ExprHandle>& outputShape,
304 const std::vector<ExprHandle>& outputStrides,
305 const std::optional<ScalarType>& outputType,
306 at::Device device) {
307 return computeTwoOperand(
308 "aten_lshift",
309 inputs,
310 outputShape,
311 outputStrides,
312 outputType,
313 [](const ExprHandle& lhs, const ExprHandle& rhs) {
314 return lhs << rhs;
315 });
316 });
317
318 RegisterNNCLoweringsFunction aten___rshift__(
319 {"aten::__rshift__.Scalar(Tensor self, Scalar other) -> (Tensor)",
320 "aten::__rshift__.Tensor(Tensor self, Tensor other) -> (Tensor)"},
321 [](const std::vector<ArgValue>& inputs,
322 const std::vector<ExprHandle>& outputShape,
323 const std::vector<ExprHandle>& outputStrides,
324 const std::optional<ScalarType>& outputType,
325 at::Device device) {
326 return computeTwoOperand(
327 "aten_rshift",
328 inputs,
329 outputShape,
330 outputStrides,
331 outputType,
332 [](const ExprHandle& lhs, const ExprHandle& rhs) {
333 return lhs >> rhs;
334 });
335 });
336
337 RegisterNNCLoweringsFunction aten_eq(
338 {"aten::eq.Scalar(Tensor self, Scalar other) -> (Tensor)",
339 "aten::eq.Tensor(Tensor self, Tensor other) -> (Tensor)"},
340 [](const std::vector<ArgValue>& inputs,
341 const std::vector<ExprHandle>& outputShape,
342 const std::vector<ExprHandle>& outputStrides,
343 const std::optional<ScalarType>& outputType,
344 at::Device device) {
345 return computeTwoOperand(
346 "aten_eq",
347 inputs,
348 outputShape,
349 outputStrides,
350 outputType,
351 [](const ExprHandle& lhs, const ExprHandle& rhs) {
352 return cast<bool>(lhs == rhs);
353 });
354 });
355
356 RegisterNNCLoweringsFunction aten_ne(
357 {"aten::ne.Scalar(Tensor self, Scalar other) -> (Tensor)",
358 "aten::ne.Tensor(Tensor self, Tensor other) -> (Tensor)"},
359 [](const std::vector<ArgValue>& inputs,
360 const std::vector<ExprHandle>& outputShape,
361 const std::vector<ExprHandle>& outputStrides,
362 const std::optional<ScalarType>& outputType,
363 at::Device device) {
364 return computeTwoOperand(
365 "aten_ne",
366 inputs,
367 outputShape,
368 outputStrides,
369 outputType,
370 [](const ExprHandle& lhs, const ExprHandle& rhs) {
371 return cast<bool>(lhs != rhs);
372 });
373 });
374
375 RegisterNNCLoweringsFunction aten_ge(
376 {"aten::ge.Scalar(Tensor self, Scalar other) -> (Tensor)",
377 "aten::ge.Tensor(Tensor self, Tensor other) -> (Tensor)"},
378 [](const std::vector<ArgValue>& inputs,
379 const std::vector<ExprHandle>& outputShape,
380 const std::vector<ExprHandle>& outputStrides,
381 const std::optional<ScalarType>& outputType,
382 at::Device device) {
383 return computeTwoOperand(
384 "aten_ge",
385 inputs,
386 outputShape,
387 outputStrides,
388 outputType,
389 [](const ExprHandle& lhs, const ExprHandle& rhs) {
390 return cast<bool>(lhs >= rhs);
391 });
392 });
393
394 RegisterNNCLoweringsFunction aten_gt(
395 {"aten::gt.Scalar(Tensor self, Scalar other) -> (Tensor)",
396 "aten::gt.Tensor(Tensor self, Tensor other) -> (Tensor)"},
397 [](const std::vector<ArgValue>& inputs,
398 const std::vector<ExprHandle>& outputShape,
399 const std::vector<ExprHandle>& outputStrides,
400 const std::optional<ScalarType>& outputType,
401 at::Device device) {
402 return computeTwoOperand(
403 "aten_gt",
404 inputs,
405 outputShape,
406 outputStrides,
407 outputType,
408 [](const ExprHandle& lhs, const ExprHandle& rhs) {
409 return cast<bool>(lhs > rhs);
410 });
411 });
412
413 RegisterNNCLoweringsFunction aten_le(
414 {"aten::le.Scalar(Tensor self, Scalar other) -> (Tensor)",
415 "aten::le.Tensor(Tensor self, Tensor other) -> (Tensor)"},
416 [](const std::vector<ArgValue>& inputs,
417 const std::vector<ExprHandle>& outputShape,
418 const std::vector<ExprHandle>& outputStrides,
419 const std::optional<ScalarType>& outputType,
420 at::Device device) {
421 return computeTwoOperand(
422 "aten_le",
423 inputs,
424 outputShape,
425 outputStrides,
426 outputType,
427 [](const ExprHandle& lhs, const ExprHandle& rhs) {
428 return cast<bool>(lhs <= rhs);
429 });
430 });
431
432 RegisterNNCLoweringsFunction aten_lt(
433 {"aten::lt.Scalar(Tensor self, Scalar other) -> (Tensor)",
434 "aten::lt.Tensor(Tensor self, Tensor other) -> (Tensor)"},
435 [](const std::vector<ArgValue>& inputs,
436 const std::vector<ExprHandle>& outputShape,
437 const std::vector<ExprHandle>& outputStrides,
438 const std::optional<ScalarType>& outputType,
439 at::Device device) {
440 return computeTwoOperand(
441 "aten_lt",
442 inputs,
443 outputShape,
444 outputStrides,
445 outputType,
446 [](const ExprHandle& lhs, const ExprHandle& rhs) {
447 return cast<bool>(lhs < rhs);
448 });
449 });
450
451 RegisterNNCLoweringsFunction aten_min_pointwise(
452 {"aten::min.other(Tensor self, Tensor other) -> (Tensor)"},
453 [](const std::vector<ArgValue>& inputs,
454 const std::vector<ExprHandle>& outputShape,
455 const std::vector<ExprHandle>& outputStrides,
456 const std::optional<ScalarType>& outputType,
457 at::Device device) {
458 return computeTwoOperand(
459 "aten_min",
460 inputs,
461 outputShape,
462 outputStrides,
463 outputType,
464 [](const ExprHandle& lhs, const ExprHandle& rhs) {
465 return Min::make(boolToInteger(lhs), boolToInteger(rhs), false);
466 });
467 });
468
469 RegisterNNCLoweringsFunction aten_max_pointwise(
470 {"aten::max.other(Tensor self, Tensor other) -> (Tensor)"},
471 [](const std::vector<ArgValue>& inputs,
472 const std::vector<ExprHandle>& outputShape,
473 const std::vector<ExprHandle>& outputStrides,
474 const std::optional<ScalarType>& outputType,
475 at::Device device) {
476 return computeTwoOperand(
477 "aten_max",
478 inputs,
479 outputShape,
480 outputStrides,
481 outputType,
482 [](const ExprHandle& lhs, const ExprHandle& rhs) {
483 return Max::make(boolToInteger(lhs), boolToInteger(rhs), false);
484 });
485 });
486
487 RegisterNNCLoweringsFunction aten_masked_fill(
488 {"aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> (Tensor)",
489 "aten::masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> (Tensor)"},
490 [](const std::vector<ArgValue>& inputs,
491 const std::vector<ExprHandle>& outputShape,
492 const std::vector<ExprHandle>& outputStrides,
493 const std::optional<ScalarType>& outputType,
494 at::Device device) {
495 return computeThreeOperand(
496 "aten_masked_fill",
497 inputs,
498 outputShape,
499 outputStrides,
500 outputType,
501 [](const ExprHandle& input,
502 const ExprHandle& mask,
503 const ExprHandle& value) {
504 // value needs to promote to input, not vice versa
505 auto val = promoteToDtype(value, input.dtype().scalar_type());
506 return ifThenElse(mask, val, input);
507 },
508 /*promote_inputs*/ false);
509 });
510 RegisterNNCLoweringsFunction aten_clamp(
511 {"aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> (Tensor)",
512 "aten::clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> (Tensor)"},
513 [](const std::vector<ArgValue>& inputs,
514 const std::vector<ExprHandle>& outputShape,
515 const std::vector<ExprHandle>& outputStrides,
516 const std::optional<ScalarType>& outputType,
517 at::Device device) {
518 bool noMin = false;
519 bool noMax = false;
520 if (std::get_if<ArgNone>(&inputs[1])) {
521 noMin = true;
522 }
523
524 if (std::get_if<ArgNone>(&inputs[2])) {
525 noMax = true;
526 }
527
528 return computeThreeOperand(
529 "aten_clamp",
530 inputs,
531 outputShape,
532 outputStrides,
533 outputType,
534 [noMin, noMax](
535 const ExprHandle& in,
536 const ExprHandle& min,
537 const ExprHandle& max) {
538 auto cast = [&](const ExprHandle& e) {
539 return Cast::make(in.dtype(), e);
540 };
541
542 if (noMin && noMax) {
543 return in;
544 } else if (noMin) {
545 auto cmax = cast(max);
546 return CompareSelect::make(in, cmax, cmax, in, kGT);
547 } else if (noMax) {
548 auto cmin = cast(min);
549 return CompareSelect::make(in, cmin, cmin, in, kLT);
550 } else {
551 auto cmax = cast(max);
552 auto cmin = cast(min);
553 return clamp(cmin, cmax, in);
554 }
555 },
556 false /* promote_inputs */);
557 });
558
559 RegisterNNCLoweringsFunction aten_addcmul(
560 {"aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> (Tensor)"},
561 [](const std::vector<ArgValue>& inputs,
562 const std::vector<ExprHandle>& outputShape,
563 const std::vector<ExprHandle>& outputStrides,
564 const std::optional<ScalarType>& outputType,
565 at::Device device) {
566 return computeFourOperand(
567 "aten_addcmul",
568 inputs,
569 outputShape,
570 outputStrides,
571 outputType,
572 [](const ExprHandle& a0,
573 const ExprHandle& a1,
574 const ExprHandle& a2,
575 const ExprHandle& a3) { return a0 + a3 * a1 * a2; });
576 });
577
578 RegisterNNCLoweringsFunction aten_sigmoid(
579 {"aten::sigmoid(Tensor self) -> (Tensor)"},
580 [](const std::vector<ArgValue>& inputs,
581 const std::vector<ExprHandle>& outputShape,
582 const std::vector<ExprHandle>& outputStrides,
583 const std::optional<ScalarType>& outputType,
584 at::Device device) {
585 // check if the activation is quantized
586 const BufHandle& x = std::get<BufHandle>(inputs[0]);
587 if (x.node()->qscale()) {
588 return computeQuantizedSigmoidExternalCall(
589 inputs, outputShape, outputStrides, outputType, device);
590 }
591 return computeOneOperand(
592 "aten_sigmoid",
593 inputs,
594 outputShape,
595 outputStrides,
596 outputType,
597 [](const ExprHandle& a) {
598 return sigmoid(promoteIntegerToDefaultType(a));
599 });
600 });
601
602 RegisterNNCLoweringsFunction aten_silu(
603 {"aten::silu(Tensor self) -> (Tensor)"},
604 [](const std::vector<ArgValue>& inputs,
605 const std::vector<ExprHandle>& outputShape,
606 const std::vector<ExprHandle>& outputStrides,
607 const std::optional<ScalarType>& outputType,
608 at::Device device) {
609 return computeOneOperand(
610 "aten_silu",
611 inputs,
612 outputShape,
613 outputStrides,
614 outputType,
615 [](const ExprHandle& a) { return a * sigmoid(a); });
616 });
617
618 RegisterNNCLoweringsFunction aten_reciprocal(
619 {"aten::reciprocal(Tensor self) -> (Tensor)"},
620 [](const std::vector<ArgValue>& inputs,
621 const std::vector<ExprHandle>& outputShape,
622 const std::vector<ExprHandle>& outputStrides,
623 const std::optional<ScalarType>& outputType,
624 at::Device device) {
625 return computeOneOperand(
626 "aten_reciprocal",
627 inputs,
628 outputShape,
629 outputStrides,
630 outputType,
631 [](const ExprHandle& a) { return ExprHandle(1.0f) / a; });
632 });
633
634 RegisterNNCLoweringsFunction aten_neg(
635 {"aten::neg(Tensor self) -> (Tensor)"},
636 [](const std::vector<ArgValue>& inputs,
637 const std::vector<ExprHandle>& outputShape,
638 const std::vector<ExprHandle>& outputStrides,
639 const std::optional<ScalarType>& outputType,
640 at::Device device) {
641 return computeOneOperand(
642 "aten_neg",
643 inputs,
644 outputShape,
645 outputStrides,
646 outputType,
647 [](const ExprHandle& a) { return ExprHandle(-0) - a; });
648 });
649
650 RegisterNNCLoweringsFunction aten_isnan(
651 {"aten::isnan(Tensor self) -> (Tensor)"},
652 [](const std::vector<ArgValue>& inputs,
653 const std::vector<ExprHandle>& outputShape,
654 const std::vector<ExprHandle>& outputStrides,
655 const std::optional<ScalarType>& outputType,
656 at::Device device) {
657 return computeOneOperand(
658 "aten_isnan",
659 inputs,
660 outputShape,
661 outputStrides,
662 outputType,
663 [](const ExprHandle& a) {
664 if (!a.dtype().is_floating_point()) {
665 return IntImm::make(0);
666 }
667 return isnan(a);
668 });
669 });
670
671 RegisterNNCLoweringsFunction aten_relu(
672 {"aten::relu(Tensor self) -> (Tensor)"},
673 [](const std::vector<ArgValue>& inputs,
674 const std::vector<ExprHandle>& outputShape,
675 const std::vector<ExprHandle>& outputStrides,
676 const std::optional<ScalarType>& outputType,
677 at::Device device) {
678 auto A = std::get<BufHandle>(inputs[0]);
679 if (A.node()->qscale()) {
680 return computeQuantizedRelu(
681 inputs, outputShape, outputStrides, outputType, device);
682 }
683 return computeOneOperand(
684 "aten_relu",
685 inputs,
686 outputShape,
687 outputStrides,
688 outputType,
689 [](const ExprHandle& a) {
690 auto zero = Cast::make(a.dtype(), 0);
691 return CompareSelect::make(a, zero, zero, a, kLT);
692 });
693 });
694
695 RegisterNNCLoweringsFunction aten_leaky_relu(
696 {"aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> (Tensor)"},
697 [](const std::vector<ArgValue>& inputs,
698 const std::vector<ExprHandle>& outputShape,
699 const std::vector<ExprHandle>& outputStrides,
700 const std::optional<ScalarType>& outputType,
701 at::Device device) {
702 return computeTwoOperand(
703 "aten_leaky_relu",
704 inputs,
705 outputShape,
706 outputStrides,
707 outputType,
708 [](const ExprHandle& a, const ExprHandle& negative_slope) {
709 auto neg_slope = Cast::make(a.dtype(), negative_slope);
710 auto zero = Cast::make(a.dtype(), 0);
711 auto one = Cast::make(a.dtype(), 1);
712 auto cs = CompareSelect::make(a, zero, one, neg_slope, kGT);
713 return a * cs;
714 });
715 });
716
717 RegisterNNCLoweringsFunction aten_relu6(
718 {"aten::relu6(Tensor self) -> (Tensor)"},
719 [](const std::vector<ArgValue>& inputs,
720 const std::vector<ExprHandle>& outputShape,
721 const std::vector<ExprHandle>& outputStrides,
722 const std::optional<ScalarType>& outputType,
723 at::Device device) {
724 return computeOneOperand(
725 "aten_relu6",
726 inputs,
727 outputShape,
728 outputStrides,
729 outputType,
730 [](const ExprHandle& a) {
731 auto zero = Cast::make(a.dtype(), 0);
732 auto six = Cast::make(a.dtype(), 6.);
733 return clamp(zero, six, a);
734 });
735 });
736
737 RegisterNNCLoweringsFunction aten_gelu(
738 {"aten::gelu(Tensor self, *, str approximate='none') -> (Tensor)"},
739 [](const std::vector<ArgValue>& inputs,
740 const std::vector<ExprHandle>& outputShape,
741 const std::vector<ExprHandle>& outputStrides,
742 const std::optional<ScalarType>& outputType,
743 at::Device device) {
744 const auto& kApproximate = std::get<std::string>(inputs[1]);
745 std::vector<ArgValue> operands = {inputs.front()};
746 if (at::native::get_gelutype_enum(kApproximate) ==
747 at::native::GeluType::Tanh) {
748 // approximate == 'tanh'
749 return computeOneOperand(
750 "aten_tanh_gelu",
751 operands,
752 outputShape,
753 outputStrides,
754 outputType,
755 [](const ExprHandle& a) {
756 auto one = Cast::make(a.dtype(), 1.);
757 auto point_five = Cast::make(a.dtype(), .5);
758 auto beta = Cast::make(a.dtype(), M_SQRT2 * M_2_SQRTPI * 0.5);
759 auto kappa = Cast::make(a.dtype(), 0.044715);
760 auto a_cube = a * a * a;
761 auto inner = beta * (a + kappa * a_cube);
762 return point_five * a * (one + tanh(inner));
763 });
764 } else {
765 // approximate == 'none'
766 return computeOneOperand(
767 "aten_gelu",
768 operands,
769 outputShape,
770 outputStrides,
771 outputType,
772 [](const ExprHandle& a) {
773 auto m_sqrt1_2 = Cast::make(a.dtype(), M_SQRT1_2);
774 auto one = Cast::make(a.dtype(), 1.);
775 auto point_five = Cast::make(a.dtype(), .5);
776 return a * point_five * (one + erf(a * m_sqrt1_2));
777 });
778 }
779 });
780
781 RegisterNNCLoweringsFunction aten_batch_norm(
782 {"aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor)"},
783 computeBatchNorm);
784
785 RegisterNNCLoweringsFunction aten_log(
786 {"aten::log(Tensor self) -> (Tensor)"},
787 [](const std::vector<ArgValue>& inputs,
788 const std::vector<ExprHandle>& outputShape,
789 const std::vector<ExprHandle>& outputStrides,
790 const std::optional<ScalarType>& outputType,
791 at::Device device) {
792 return computeOneOperand(
793 "aten_log",
794 inputs,
795 outputShape,
796 outputStrides,
797 outputType,
798 [](const ExprHandle& a) {
799 return log(promoteIntegerToDefaultType(a));
800 });
801 });
802
803 RegisterNNCLoweringsFunction aten_log10(
804 {"aten::log10(Tensor self) -> (Tensor)"},
805 [](const std::vector<ArgValue>& inputs,
806 const std::vector<ExprHandle>& outputShape,
807 const std::vector<ExprHandle>& outputStrides,
808 const std::optional<ScalarType>& outputType,
809 at::Device device) {
810 return computeOneOperand(
811 "aten_log10",
812 inputs,
813 outputShape,
814 outputStrides,
815 outputType,
816 [](const ExprHandle& a) {
817 return log10(promoteIntegerToDefaultType(a));
818 });
819 });
820
821 RegisterNNCLoweringsFunction aten_log1p(
822 {"aten::log1p(Tensor self) -> (Tensor)"},
823 [](const std::vector<ArgValue>& inputs,
824 const std::vector<ExprHandle>& outputShape,
825 const std::vector<ExprHandle>& outputStrides,
826 const std::optional<ScalarType>& outputType,
827 at::Device device) {
828 return computeOneOperand(
829 "aten_log1p",
830 inputs,
831 outputShape,
832 outputStrides,
833 outputType,
834 [](const ExprHandle& a) {
835 return log1p(promoteIntegerToDefaultType(a));
836 });
837 });
838
839 RegisterNNCLoweringsFunction aten_log2(
840 {"aten::log2(Tensor self) -> (Tensor)"},
841 [](const std::vector<ArgValue>& inputs,
842 const std::vector<ExprHandle>& outputShape,
843 const std::vector<ExprHandle>& outputStrides,
844 const std::optional<ScalarType>& outputType,
845 at::Device device) {
846 return computeOneOperand(
847 "aten_log2",
848 inputs,
849 outputShape,
850 outputStrides,
851 outputType,
852 [](const ExprHandle& a) {
853 return log2(promoteIntegerToDefaultType(a));
854 });
855 });
856
857 RegisterNNCLoweringsFunction aten_exp(
858 {"aten::exp(Tensor self) -> (Tensor)"},
859 [](const std::vector<ArgValue>& inputs,
860 const std::vector<ExprHandle>& outputShape,
861 const std::vector<ExprHandle>& outputStrides,
862 const std::optional<ScalarType>& outputType,
863 at::Device device) {
864 return computeOneOperand(
865 "aten_exp",
866 inputs,
867 outputShape,
868 outputStrides,
869 outputType,
870 [](const ExprHandle& a) {
871 return exp(promoteIntegerToDefaultType(a));
872 });
873 });
874
875 RegisterNNCLoweringsFunction aten_expm1(
876 {"aten::expm1(Tensor self) -> (Tensor)"},
877 [](const std::vector<ArgValue>& inputs,
878 const std::vector<ExprHandle>& outputShape,
879 const std::vector<ExprHandle>& outputStrides,
880 const std::optional<ScalarType>& outputType,
881 at::Device device) {
882 return computeOneOperand(
883 "aten_expm1",
884 inputs,
885 outputShape,
886 outputStrides,
887 outputType,
888 [](const ExprHandle& a) {
889 return expm1(promoteIntegerToDefaultType(a));
890 });
891 });
892
893 RegisterNNCLoweringsFunction aten_erf(
894 {"aten::erf(Tensor self) -> (Tensor)"},
895 [](const std::vector<ArgValue>& inputs,
896 const std::vector<ExprHandle>& outputShape,
897 const std::vector<ExprHandle>& outputStrides,
898 const std::optional<ScalarType>& outputType,
899 at::Device device) {
900 return computeOneOperand(
901 "aten_erf",
902 inputs,
903 outputShape,
904 outputStrides,
905 outputType,
906 [](const ExprHandle& a) {
907 return erf(promoteIntegerToDefaultType(a));
908 });
909 });
910
911 RegisterNNCLoweringsFunction aten_erfc(
912 {"aten::erfc(Tensor self) -> (Tensor)"},
913 [](const std::vector<ArgValue>& inputs,
914 const std::vector<ExprHandle>& outputShape,
915 const std::vector<ExprHandle>& outputStrides,
916 const std::optional<ScalarType>& outputType,
917 at::Device device) {
918 return computeOneOperand(
919 "aten_erfc",
920 inputs,
921 outputShape,
922 outputStrides,
923 outputType,
924 [](const ExprHandle& a) {
925 return erfc(promoteIntegerToDefaultType(a));
926 });
927 });
928
929 RegisterNNCLoweringsFunction aten_cos(
930 {"aten::cos(Tensor self) -> (Tensor)"},
931 [](const std::vector<ArgValue>& inputs,
932 const std::vector<ExprHandle>& outputShape,
933 const std::vector<ExprHandle>& outputStrides,
934 const std::optional<ScalarType>& outputType,
935 at::Device device) {
936 return computeOneOperand(
937 "aten_cos",
938 inputs,
939 outputShape,
940 outputStrides,
941 outputType,
942 [](const ExprHandle& a) {
943 return cos(promoteIntegerToDefaultType(a));
944 });
945 });
946
947 RegisterNNCLoweringsFunction aten_sin(
948 {"aten::sin(Tensor self) -> (Tensor)"},
949 [](const std::vector<ArgValue>& inputs,
950 const std::vector<ExprHandle>& outputShape,
951 const std::vector<ExprHandle>& outputStrides,
952 const std::optional<ScalarType>& outputType,
953 at::Device device) {
954 return computeOneOperand(
955 "aten_sin",
956 inputs,
957 outputShape,
958 outputStrides,
959 outputType,
960 [](const ExprHandle& a) {
961 return sin(promoteIntegerToDefaultType(a));
962 });
963 });
964
965 RegisterNNCLoweringsFunction aten_tan(
966 {"aten::tan(Tensor self) -> (Tensor)"},
967 [](const std::vector<ArgValue>& inputs,
968 const std::vector<ExprHandle>& outputShape,
969 const std::vector<ExprHandle>& outputStrides,
970 const std::optional<ScalarType>& outputType,
971 at::Device device) {
972 return computeOneOperand(
973 "aten_tan",
974 inputs,
975 outputShape,
976 outputStrides,
977 outputType,
978 [](const ExprHandle& a) {
979 return tan(promoteIntegerToDefaultType(a));
980 });
981 });
982
  // type_as(self, other): cast `self` elementwise to the dtype of `other`.
  // The target dtype is read from the second input buffer.
  RegisterNNCLoweringsFunction aten_type_as(
      {"aten::type_as(Tensor self, Tensor other) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        const BufHandle& rhs = std::get<BufHandle>(inputs[1]);
        auto dtype = rhs.dtype();
        return computeOneOperand(
            "aten_type_as",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [dtype](const ExprHandle& lhs) { return Cast::make(dtype, lhs); });
      });
1000
  // pow: general case emits pow(lhs, rhs); when the exponent simplifies to a
  // constant, common small exponents are strength-reduced to cheaper
  // multiply/sqrt/reciprocal forms.
  RegisterNNCLoweringsFunction aten_pow(
      {"aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> (Tensor)",
       "aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> (Tensor)",
       "aten::pow.Scalar(Scalar self, Tensor exponent) -> Tensor"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeTwoOperand(
            "aten_pow",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& lhs, const ExprHandle& rhs) {
              // Non-constant exponent: no simplification possible.
              if (!rhs.node()->isConstant()) {
                return pow(lhs, rhs);
              }
              double val =
                  immediateAs<double>(IRSimplifier::simplify(rhs.node()));

              if (val == 1.0f) {
                return lhs;
              } else if (val == 2.0f) { // NOLINT
                return lhs * lhs;
              } else if (val == 3.0f) { // NOLINT
                return (lhs * lhs) * lhs;
              } else if (val == 4.0f) { // NOLINT
                ExprHandle tmp = lhs * lhs;
                return tmp * tmp;
              } else if (val == 0.5f) { // NOLINT
                return sqrt(lhs);
              } else if (val == 0.0f) {
                // x^0 == 1 for all x under this lowering.
                return ExprHandle(1.0f);
              } else if (val == -0.5f) { // NOLINT
                return rsqrt(lhs);
              } else if (val == -1.0f) {
                return ExprHandle(1.0f) / lhs;
              } else if (val == -2.0f) { // NOLINT
                return ExprHandle(1.0f) / (lhs * lhs);
              }
              // Constant but not one of the special-cased exponents.
              return pow(lhs, rhs);
            });
      });
1046
  // fmod: C-style remainder. Half-precision operands are promoted to float
  // before the fmod call.
  RegisterNNCLoweringsFunction aten_fmod(
      {"aten::fmod.Scalar(Tensor self, Scalar other) -> (Tensor)",
       "aten::fmod.Tensor(Tensor self, Tensor other) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeTwoOperand(
            "aten_fmod",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& lhs, const ExprHandle& rhs) {
              return fmod(promoteHalfToFloat(lhs), promoteHalfToFloat(rhs));
            });
      });
1065
  // lerp(start, end, weight) = start + weight * (end - start).
  RegisterNNCLoweringsFunction aten_lerp(
      {"aten::lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> (Tensor)",
       "aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeThreeOperand(
            "aten_lerp",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a,
               const ExprHandle& end,
               const ExprHandle& weight) { return a + weight * (end - a); });
      });
1084
  // remainder: Python-style modulo. All-integer inputs use Mod directly;
  // otherwise fmod(rhs + fmod(lhs, rhs), rhs) is used so a non-zero result
  // takes the sign of the divisor (torch.remainder semantics).
  RegisterNNCLoweringsFunction aten_remainder(
      {"aten::remainder.Scalar(Tensor self, Scalar other) -> (Tensor)",
       "aten::remainder.Scalar_Tensor(Scalar self, Tensor other) -> (Tensor)",
       "aten::remainder.Tensor(Tensor self, Tensor other) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        auto imodImpl = [](const ExprHandle& lhs, const ExprHandle& rhs) {
          return Mod::make(lhs, rhs);
        };
        auto fmodImpl = [](const ExprHandle& lhs, const ExprHandle& rhs) {
          // Half operands are promoted to float before fmod.
          auto lhs_t = promoteHalfToFloat(lhs);
          auto rhs_t = promoteHalfToFloat(rhs);
          return fmod((rhs_t + fmod(lhs_t, rhs_t)), rhs_t);
        };
        {
          auto const& shape =
              broadcastShapes(valueShape(inputs[0]), valueShape(inputs[1]));
          return Compute(
              "aten_remainder", shape, [&](const std::vector<VarHandle>& axes) {
                std::vector<ExprHandle> indices(axes.begin(), axes.end());
                std::vector<ExprHandle> exprInputs = {
                    tensorOrConstant(inputs[0], indices),
                    tensorOrConstant(inputs[1], indices),
                };

                promoteInputs(exprInputs);
                // Choose the integer path only if no promoted input is
                // floating point.
                bool allInt = true;
                for (auto& e : exprInputs) {
                  if (e.dtype().is_floating_point()) {
                    allInt = false;
                    break;
                  }
                }
                if (allInt) {
                  return demoteOutput(
                      imodImpl(exprInputs[0], exprInputs[1]), outputType);
                } else {
                  return demoteOutput(
                      fmodImpl(exprInputs[0], exprInputs[1]), outputType);
                }
              });
        }
      });
1131
  // prim::ConstantChunk is lowered by the dedicated computeChunk helper.
  RegisterNNCLoweringsFunction prim_ConstantChunk(
      {"prim::ConstantChunk(...) -> (...)"}, computeChunk);
1134
1135 RegisterNNCLoweringsFunction aten_acos(
1136 {"aten::acos(Tensor self) -> (Tensor)"},
1137 [](const std::vector<ArgValue>& inputs,
1138 const std::vector<ExprHandle>& outputShape,
1139 const std::vector<ExprHandle>& outputStrides,
1140 const std::optional<ScalarType>& outputType,
1141 at::Device device) {
1142 return computeOneOperand(
1143 "aten_acos",
1144 inputs,
1145 outputShape,
1146 outputStrides,
1147 outputType,
1148 [](const ExprHandle& a) {
1149 return acos(promoteIntegerToDefaultType(a));
1150 });
1151 });
1152
1153 RegisterNNCLoweringsFunction aten_asin(
1154 {"aten::asin(Tensor self) -> (Tensor)"},
1155 [](const std::vector<ArgValue>& inputs,
1156 const std::vector<ExprHandle>& outputShape,
1157 const std::vector<ExprHandle>& outputStrides,
1158 const std::optional<ScalarType>& outputType,
1159 at::Device device) {
1160 return computeOneOperand(
1161 "aten_asin",
1162 inputs,
1163 outputShape,
1164 outputStrides,
1165 outputType,
1166 [](const ExprHandle& a) {
1167 return asin(promoteIntegerToDefaultType(a));
1168 });
1169 });
1170
1171 RegisterNNCLoweringsFunction aten_cosh(
1172 {"aten::cosh(Tensor self) -> (Tensor)"},
1173 [](const std::vector<ArgValue>& inputs,
1174 const std::vector<ExprHandle>& outputShape,
1175 const std::vector<ExprHandle>& outputStrides,
1176 const std::optional<ScalarType>& outputType,
1177 at::Device device) {
1178 return computeOneOperand(
1179 "aten_cosh",
1180 inputs,
1181 outputShape,
1182 outputStrides,
1183 outputType,
1184 [](const ExprHandle& a) {
1185 return cosh(promoteIntegerToDefaultType(a));
1186 });
1187 });
1188
1189 RegisterNNCLoweringsFunction aten_sinh(
1190 {"aten::sinh(Tensor self) -> (Tensor)"},
1191 [](const std::vector<ArgValue>& inputs,
1192 const std::vector<ExprHandle>& outputShape,
1193 const std::vector<ExprHandle>& outputStrides,
1194 const std::optional<ScalarType>& outputType,
1195 at::Device device) {
1196 return computeOneOperand(
1197 "aten_sinh",
1198 inputs,
1199 outputShape,
1200 outputStrides,
1201 outputType,
1202 [](const ExprHandle& a) {
1203 return sinh(promoteIntegerToDefaultType(a));
1204 });
1205 });
1206
1207 RegisterNNCLoweringsFunction aten_atan(
1208 {"aten::atan(Tensor self) -> (Tensor)"},
1209 [](const std::vector<ArgValue>& inputs,
1210 const std::vector<ExprHandle>& outputShape,
1211 const std::vector<ExprHandle>& outputStrides,
1212 const std::optional<ScalarType>& outputType,
1213 at::Device device) {
1214 return computeOneOperand(
1215 "aten_atan",
1216 inputs,
1217 outputShape,
1218 outputStrides,
1219 outputType,
1220 [](const ExprHandle& a) {
1221 return atan(promoteIntegerToDefaultType(a));
1222 });
1223 });
1224
1225 RegisterNNCLoweringsFunction aten_atan2(
1226 {"aten::atan2(Tensor self, Tensor other) -> (Tensor)"},
1227 [](const std::vector<ArgValue>& inputs,
1228 const std::vector<ExprHandle>& outputShape,
1229 const std::vector<ExprHandle>& outputStrides,
1230 const std::optional<ScalarType>& outputType,
1231 at::Device device) {
1232 return computeTwoOperand(
1233 "aten_atan2",
1234 inputs,
1235 outputShape,
1236 outputStrides,
1237 outputType,
1238 [](const ExprHandle& lhs, const ExprHandle& rhs) {
1239 return atan2(
1240 promoteIntegerToDefaultType(lhs),
1241 promoteIntegerToDefaultType(rhs));
1242 });
1243 });
1244
1245 RegisterNNCLoweringsFunction aten_tanh(
1246 {"aten::tanh(Tensor self) -> (Tensor)"},
1247 [](const std::vector<ArgValue>& inputs,
1248 const std::vector<ExprHandle>& outputShape,
1249 const std::vector<ExprHandle>& outputStrides,
1250 const std::optional<ScalarType>& outputType,
1251 at::Device device) {
1252 return computeOneOperand(
1253 "aten_tanh",
1254 inputs,
1255 outputShape,
1256 outputStrides,
1257 outputType,
1258 [](const ExprHandle& a) {
1259 return tanh(promoteIntegerToDefaultType(a));
1260 });
1261 });
1262
  // hardtanh: clamp(a, min_val, max_val), built from two CompareSelects
  // (first clamp below to min_val, then clamp above to max_val).
  RegisterNNCLoweringsFunction aten_hardtanh(
      {"aten::hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeThreeOperand(
            "aten_hardtanh",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a,
               const ExprHandle& min_val,
               const ExprHandle& max_val) {
              auto mm = CompareSelect::make(a, min_val, min_val, a, kLT);
              return CompareSelect::make(mm, max_val, max_val, mm, kGT);
            });
      });
1283
  // softplus: log1p(exp(beta * a)) / beta, with the linear passthrough `a`
  // when beta * a exceeds the threshold (numerical-stability cutoff).
  RegisterNNCLoweringsFunction aten_softplus(
      {"aten::softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeThreeOperand(
            "aten_softplus",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a,
               const ExprHandle& beta,
               const ExprHandle& threshold) {
              // Scalars are cast to the input dtype before use.
              auto beta_promoted = Cast::make(a.dtype(), beta);
              auto threshold_promoted = Cast::make(a.dtype(), threshold);
              auto beta_a = beta_promoted * a;
              return CompareSelect::make(
                  beta_a,
                  threshold_promoted,
                  a,
                  log1p(exp(beta_a)) / beta_promoted,
                  kGT);
            });
      });
1311
  // mish: a * tanh(softplus(a)), where softplus is written as log1p(exp(a)).
  RegisterNNCLoweringsFunction aten_mish(
      {"aten::mish(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_mish",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              auto default_type_a = promoteIntegerToDefaultType(a);
              return default_type_a * tanh(log1p(exp(default_type_a)));
            });
      });
1330
  // elu: scale * a for a > 0, otherwise alpha * scale * (exp(input_scale*a) - 1).
  RegisterNNCLoweringsFunction aten_elu(
      {"aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeFourOperand(
            "aten_elu",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a,
               const ExprHandle& alpha,
               const ExprHandle& scale,
               const ExprHandle& input_scale) {
              auto zero = Cast::make(a.dtype(), 0);
              auto one = Cast::make(a.dtype(), 1);

              // Coefficients are cast to the input dtype; the negative-branch
              // coefficient folds alpha and scale together.
              auto poscoef = Cast::make(a.dtype(), scale);
              auto negiptcoef = Cast::make(a.dtype(), input_scale);
              auto negcoef = Cast::make(a.dtype(), alpha) * poscoef;

              return CompareSelect::make(
                  a,
                  zero,
                  a * poscoef,
                  (exp(a * negiptcoef) - one) * negcoef,
                  kGT);
            });
      });
1363
  // hardsigmoid: clamp(a + 3, 0, 6) / 6.
  RegisterNNCLoweringsFunction aten_hardsigmoid(
      {"aten::hardsigmoid(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_hardsigmoid",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              auto zero = Cast::make(a.dtype(), 0.0);
              auto three = Cast::make(a.dtype(), 3.0);
              auto six = Cast::make(a.dtype(), 6.0);
              return clamp(zero, six, a + three) / six;
            });
      });
1384
  // hardswish: x * clamp(x + 3, 0, 6) / 6.
  RegisterNNCLoweringsFunction aten_hardswish(
      {"aten::hardswish(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_hardswish",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              // x * torch.clamp(x + 3.0, 0.0, 6.0) / 6.0
              auto zero = Cast::make(a.dtype(), 0.);
              auto three = Cast::make(a.dtype(), 3.);
              auto six = Cast::make(a.dtype(), 6.);

              return a * clamp(zero, six, a + three) / six;
            });
      });
1407
  // hardshrink: pass `a` through when a < -lambd or a > lambd, else 0.
  RegisterNNCLoweringsFunction aten_hardshrink(
      {"aten::hardshrink(Tensor self, Scalar lambd=0.5) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeTwoOperand(
            "aten_hardshrink",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a, const ExprHandle& lambd) {
              auto pos_clambd = Cast::make(a.dtype(), lambd);
              // neg_clambd == -lambd, constructed as 0 - lambd in the input
              // dtype (ExprHandle(-0) is an integer zero).
              auto neg_clambd =
                  Cast::make(a.dtype(), ExprHandle(-0)) - pos_clambd;
              auto zero = Cast::make(a.dtype(), 0);
              auto mm = CompareSelect::make(a, neg_clambd, a, zero, kLT);
              return CompareSelect::make(a, pos_clambd, a, mm, kGT);
            });
      });
1430
1431 RegisterNNCLoweringsFunction aten_sqrt(
1432 {"aten::sqrt(Tensor self) -> (Tensor)"},
1433 [](const std::vector<ArgValue>& inputs,
1434 const std::vector<ExprHandle>& outputShape,
1435 const std::vector<ExprHandle>& outputStrides,
1436 const std::optional<ScalarType>& outputType,
1437 at::Device device) {
1438 return computeOneOperand(
1439 "aten_sqrt",
1440 inputs,
1441 outputShape,
1442 outputStrides,
1443 outputType,
1444 [](const ExprHandle& a) {
1445 return tensorexpr::sqrt(promoteIntegerToDefaultType(a));
1446 });
1447 });
1448
1449 RegisterNNCLoweringsFunction aten_rsqrt(
1450 {"aten::rsqrt(Tensor self) -> (Tensor)"},
1451 [](const std::vector<ArgValue>& inputs,
1452 const std::vector<ExprHandle>& outputShape,
1453 const std::vector<ExprHandle>& outputStrides,
1454 const std::optional<ScalarType>& outputType,
1455 at::Device device) {
1456 return computeOneOperand(
1457 "aten_rsqrt",
1458 inputs,
1459 outputShape,
1460 outputStrides,
1461 outputType,
1462 [](const ExprHandle& a) {
1463 return rsqrt(promoteIntegerToDefaultType(a));
1464 });
1465 });
1466
  // abs: half is promoted to float; the trailing mask explicitly enables
  // integral, floating-point and bool inputs for this lowering.
  RegisterNNCLoweringsFunction aten_abs(
      {"aten::abs(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_abs",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              return tensorexpr::abs(promoteHalfToFloat(a));
            },
            kIntegralTypes | kFloatingPointTypes | kBoolType);
      });
1485
  // sign: delegated to computeSign (strides are not forwarded here).
  RegisterNNCLoweringsFunction aten_sign(
      {"aten::sign(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) { return computeSign(inputs, outputShape); });
1493
1494 RegisterNNCLoweringsFunction aten_ceil(
1495 {"aten::ceil(Tensor self) -> (Tensor)"},
1496 [](const std::vector<ArgValue>& inputs,
1497 const std::vector<ExprHandle>& outputShape,
1498 const std::vector<ExprHandle>& outputStrides,
1499 const std::optional<ScalarType>& outputType,
1500 at::Device device) {
1501 return computeOneOperand(
1502 "aten_ceil",
1503 inputs,
1504 outputShape,
1505 outputStrides,
1506 outputType,
1507 [](const ExprHandle& a) { return ceil(a); });
1508 });
1509
1510 RegisterNNCLoweringsFunction aten_floor(
1511 {"aten::floor(Tensor self) -> (Tensor)"},
1512 [](const std::vector<ArgValue>& inputs,
1513 const std::vector<ExprHandle>& outputShape,
1514 const std::vector<ExprHandle>& outputStrides,
1515 const std::optional<ScalarType>& outputType,
1516 at::Device device) {
1517 return computeOneOperand(
1518 "aten_floor",
1519 inputs,
1520 outputShape,
1521 outputStrides,
1522 outputType,
1523 [](const ExprHandle& a) { return floor(a); });
1524 });
1525
1526 RegisterNNCLoweringsFunction aten_round(
1527 {"aten::round(Tensor self) -> (Tensor)"},
1528 [](const std::vector<ArgValue>& inputs,
1529 const std::vector<ExprHandle>& outputShape,
1530 const std::vector<ExprHandle>& outputStrides,
1531 const std::optional<ScalarType>& outputType,
1532 at::Device device) {
1533 return computeOneOperand(
1534 "aten_round",
1535 inputs,
1536 outputShape,
1537 outputStrides,
1538 outputType,
1539 [](const ExprHandle& a) { return round(a); });
1540 });
1541
1542 RegisterNNCLoweringsFunction aten_trunc(
1543 {"aten::trunc(Tensor self) -> (Tensor)"},
1544 [](const std::vector<ArgValue>& inputs,
1545 const std::vector<ExprHandle>& outputShape,
1546 const std::vector<ExprHandle>& outputStrides,
1547 const std::optional<ScalarType>& outputType,
1548 at::Device device) {
1549 return computeOneOperand(
1550 "aten_trunc",
1551 inputs,
1552 outputShape,
1553 outputStrides,
1554 outputType,
1555 [](const ExprHandle& a) { return trunc(a); });
1556 });
1557
  // _cast_Float: unconditional elementwise cast to float; the non_blocking
  // flag is ignored by this lowering.
  RegisterNNCLoweringsFunction aten__cast_Float(
      {"aten::_cast_Float(Tensor self, bool non_blocking=False) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_cast_float",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) { return cast<float>(a); });
      });
1573
  // to / autocast conversions: lowered as a plain cast of the first input to
  // the already-resolved output dtype; all other schema arguments are ignored.
  RegisterNNCLoweringsFunction aten_to(
      {"aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))",
       "aten::to.dtype_layout(Tensor(a) self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))",
       "aten::to.device(Tensor(a) self, Device device, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))",
       "aten::to.prim_Device(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)",
       "aten::to.prim_dtype(Tensor(a) self, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)",
       "aten::_autocast_to_reduced_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) -> Tensor(a)",
       "aten::_autocast_to_full_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled) -> Tensor(a)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        // see handling of aten::to in tensorexpr_fuser.cpp for why we only
        // need to handle the first input
        return computeOneOperand(
            "aten_to",
            {inputs[0]},
            outputShape,
            outputStrides,
            outputType,
            [outputType](const ExprHandle& a) {
              // The fuser must have resolved the target dtype by this point.
              TORCH_INTERNAL_ASSERT(
                  outputType, buildErrorMessage("Output type is null."));
              return Cast::make(ToDtype(*outputType), a);
            });
      });
1601
  // threshold: replace `a` with `value` wherever a <= threshold.
  RegisterNNCLoweringsFunction aten_threshold(
      {"aten::threshold(Tensor self, Scalar threshold, Scalar value) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeThreeOperand(
            "aten_threshold",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a,
               const ExprHandle& threshold,
               const ExprHandle& value) {
              return ifThenElse(
                  CompareSelect::make(a, threshold, kLE), value, a);
            });
      });
1622
  // where(cond, self, other): elementwise select driven by the condition.
  RegisterNNCLoweringsFunction aten_where(
      {"aten::where.ScalarOther(Tensor condition, Tensor self, Scalar other) -> (Tensor)",
       "aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> (Tensor)",
       "aten::where.self(Tensor condition, Tensor self, Tensor other) -> (Tensor)",
       "aten::where.Scalar(Tensor condition, Scalar self, Scalar other) -> Tensor"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeConditionWithTwoOperand(
            "aten_where",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a0,
               const ExprHandle& a1,
               const ExprHandle& a2) { return ifThenElse(a0, a1, a2); });
      });
1643
  // frac: fractional part, a - floor(a); restricted to floating-point inputs
  // by the trailing type mask, half promoted to float for the computation.
  RegisterNNCLoweringsFunction aten_frac(
      {"aten::frac(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_frac",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              auto aa = promoteHalfToFloat(a);
              return aa - floor(aa);
            },
            kFloatingPointTypes);
      });
1663
  // lgamma: log of the absolute value of the gamma function, with
  // integer-to-default-type promotion.
  RegisterNNCLoweringsFunction aten_lgamma(
      {"aten::lgamma(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_lgamma",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              return lgamma(promoteIntegerToDefaultType(a));
            });
      });
1681
1682 // TODO: convert to schema, add a test
1683 // RegisterNNCLoweringsFunction aten_rand_like(
1684 // {"aten::rand_like"},
1685 // [](const std::vector<ArgValue>& inputs,
1686 // const std::vector<ExprHandle>& outputShape,
1687 // const std::optional<ScalarType>& outputType,
1688 // at::Device device) {
1689 // return computeOneOperand(
1690 // "aten_rand_like",
1691 // inputs,
1692 // outputShape,
1693 // outputType,
1694 // [](const ExprHandle& a) {
1695 // return Intrinsics::make(IntrinsicsOp::kRand, a.dtype());
1696 // });
1697 // });
1698
1699 // TODO: convert to schema, add a test
1700 // RegisterNNCLoweringsFunction aten_slice(
1701 // {"aten::slice"},
1702 // [](const std::vector<ArgValue>& inputs,
1703 // const std::vector<ExprHandle>& outputShape,
1704 // const std::optional<ScalarType>& outputType,
1705 // at::Device device) {
1706 // return Compute(
1707 // "aten_slice",
1708 // outputShape,
1709 // [&](const std::vector<VarHandle>& axes) {
1710 // int64_t dim =
1711 // at::maybe_wrap_dim(std::get<int64_t>(inputs[1]),
1712 // axes.size());
1713 // ExprHandle start = constant(inputs[2]);
1714 // ExprHandle stride = constant(inputs[4]);
1715
1716 // std::vector<ExprHandle> newAxes(axes.begin(), axes.end());
1717 // newAxes[dim] = stride * newAxes[dim] + start;
1718 // return tensorOrConstant(inputs[0], newAxes);
1719 // });
1720 // });
  // unsqueeze: index the input by all output axes except the inserted one.
  RegisterNNCLoweringsFunction aten_unsqueeze(
      {"aten::unsqueeze(Tensor(a) self, int dim) -> (Tensor(a))"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return Compute(
            "aten_unsqueeze",
            outputShape,
            outputStrides,
            [&](const std::vector<VarHandle>& axes) {
              int64_t dim = std::get<int64_t>(inputs[1]);
              // Normalize a negative dim against the output rank.
              if (dim < 0) {
                if (axes.empty()) {
                  throw malformed_input("axes are zero handling unsqueeze");
                }
                dim += axes.size();
              }
              // To construct an expression for an 'unsqueezed' tensor we need
              // to drop the DIM-th axis, i.e.
              //   unsqueezed_v[i,j,k,l] = v[i,j,l] # dim = 2 - drop index 'k'
              //                 0 1 2 3
              std::vector<ExprHandle> indices;
              int64_t i = 0;
              for (const auto& a : axes) {
                if (i++ != dim) {
                  indices.emplace_back(a.node());
                }
              }

              return broadcast(std::get<BufHandle>(inputs[0]), indices);
            });
      });
  // t: 2-d transpose, implemented by forwarding to the generic transpose
  // lowering with the dim pair (1, 0) injected into the argument list.
  RegisterNNCLoweringsFunction aten_t(
      {"aten::t(Tensor(a) self) -> (Tensor(a))"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeTranspose(
            {inputs[0], (int64_t)1, (int64_t)0},
            outputShape,
            outputStrides,
            outputType,
            device);
      });
  RegisterNNCLoweringsFunction aten_transpose(
      {"aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> (Tensor(a))"},
      computeTranspose);
  // permute: reorder the output axes back into input order before loading.
  // Quantization parameters (qscale/qzero) of the input buffer, when present,
  // are propagated to the result buffer.
  RegisterNNCLoweringsFunction aten_permute(
      {"aten::permute(Tensor(a) self, int[] dims) -> (Tensor(a))"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        auto A = std::get<BufHandle>(inputs[0]);
        // Trivial case of 0-dim tensors: just a copy of the input
        if (A.ndim() == 0) {
          auto tensor = Compute(
              "aten_permute",
              outputShape,
              outputStrides,
              [&](const std::vector<VarHandle>& axes) {
                std::vector<ExprHandle> empty_indices;
                return A.load(empty_indices);
              });
          if (A.node()->qscale()) {
            tensor.buf()->set_qscale(A.node()->qscale());
            tensor.buf()->set_qzero(A.node()->qzero());
          }
          return tensor;
        }
        auto permute_dims = std::get<IntList>(inputs[1]);
        // NOTE(review): unlike the 0-dim branch, this Compute call does not
        // pass outputStrides — confirm this is intentional.
        auto tensor = Compute(
            "aten_permute",
            outputShape,
            [&](const std::vector<VarHandle>& axes) {
              // new_axes[permute_dims[i]] = axes[i] inverts the permutation so
              // output axis i reads input axis permute_dims[i].
              std::vector<VarHandle> new_axes;
              new_axes.resize(axes.size());
              assert(permute_dims.size() == axes.size());
              for (unsigned i = 0; i < axes.size(); i++) {
                auto new_dim = at::maybe_wrap_dim(permute_dims[i], A.ndim());
                new_axes[new_dim] = axes[i];
              }
              return A.load(new_axes);
            });
        if (A.node()->qscale()) {
          tensor.buf()->set_qscale(A.node()->qscale());
          tensor.buf()->set_qzero(A.node()->qzero());
        }
        return tensor;
      });
  // Shape-manipulation ops delegated to dedicated compute* helpers.
  RegisterNNCLoweringsFunction aten_expand(
      {"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> (Tensor(a))",
       "aten::expand_as(Tensor(a) self, Tensor other) -> (Tensor(a))"},
      computeExpand);

  // TODO: add a test
  RegisterNNCLoweringsFunction aten_flatten(
      {"aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> (Tensor(a))"},
      computeFlatten);
  RegisterNNCLoweringsFunction aten_view(
      {"aten::reshape(Tensor(a) self, int[] shape) -> (Tensor(a))",
       "aten::reshape_as(Tensor(a) self, Tensor other) -> (Tensor(a))",
       "aten::view(Tensor(a) self, int[] size) -> (Tensor(a))",
       "aten::view_as(Tensor(a) self, Tensor other) -> (Tensor(a))"},
      computeReshape);

  // aten::mm is a subset of aten::matmul where both inputs are rank 2
  RegisterNNCLoweringsFunction aten_matmul(
      {"aten::mm(Tensor self, Tensor mat2) -> (Tensor)",
       "aten::matmul(Tensor self, Tensor other) -> (Tensor)"},
      computeMatmul);

  RegisterNNCLoweringsFunction aten_cat(
      {"aten::cat(Tensor[] tensors, int dim=0) -> (Tensor)"}, computeCat);

  // Reductions.
  RegisterNNCLoweringsFunction aten_sum(
      {"aten::sum(Tensor self, *, int? dtype=None) -> (Tensor)",
       "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)"},
      computeSum);

  // softmax / log_softmax share one helper, toggled by the final bool
  // (log_softmax == true).
  RegisterNNCLoweringsFunction aten_softmax(
      {"aten::softmax.int(Tensor self, int dim, int? dtype=None) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeSoftmax(inputs, outputShape, outputStrides, false);
      });

  RegisterNNCLoweringsFunction aten_log_softmax(
      {"aten::log_softmax.int(Tensor self, int dim, int? dtype=None) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeSoftmax(inputs, outputShape, outputStrides, true);
      });

  // Convolution and linear-algebra ops delegated to helpers.
  RegisterNNCLoweringsFunction aten_conv1d(
      {"aten::conv1d(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, int[1] padding=0, int[1] dilation=1, int groups=1) -> (Tensor)"},
      computeConv1d);
  RegisterNNCLoweringsFunction aten_conv2d(
      {"aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=[1, 1], int[2] padding=[0, 0], int[2] dilation=[1, 1], int groups=1) -> (Tensor)"},
      computeConv2d);

  RegisterNNCLoweringsFunction aten_addmm(
      {"aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> (Tensor)"},
      computeAddMM);

  RegisterNNCLoweringsFunction aten_mean(
      {"aten::mean(Tensor self, *, int? dtype=None) -> (Tensor)",
       "aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)"},
      computeMean);
  RegisterNNCLoweringsFunction aten_max_reduction(
      {"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)"},
      computeMax);

  RegisterNNCLoweringsFunction aten_adaptive_avg_pool2d(
      {"aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> (Tensor)"},
      computeAdaptiveAvgPool2d);
1888
  // add: bools are widened to integers before the sum. With three operands
  // the alpha-scaling variant is used; with two, the plain variant.
  RegisterNNCLoweringsFunction aten_add(
      {"aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)",
       "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        auto add_lambda = [](const ExprHandle& lhs, const ExprHandle& rhs) {
          return boolToInteger(lhs) + boolToInteger(rhs);
        };
        TORCH_INTERNAL_ASSERT(
            inputs.size() == 2 || inputs.size() == 3,
            buildErrorMessage("Invalid number of input operands"));
        return (inputs.size() > 2) ? computeTwoOperandWithAlpha(
                                         "aten_add",
                                         inputs,
                                         outputShape,
                                         outputStrides,
                                         outputType,
                                         add_lambda)
                                   : computeTwoOperand(
                                         "aten_add",
                                         inputs,
                                         outputShape,
                                         outputStrides,
                                         outputType,
                                         add_lambda);
      });
  RegisterNNCLoweringsFunction aten_embedding(
      {"aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor"},
      computeEmbedding);
1921
1922 #define NNC_QUANTIZATION_EXPR_QUANT 1
1923 #define NNC_QUANTIZATION_EXPR_DEQUANT 1
1924
1925 RegisterNNCLoweringsFunction aten_quantize_per_tensor(
1926 {"aten::quantize_per_tensor(Tensor self, float scale, int zero_point, int dtype) -> (Tensor)",
1927 "aten::quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, int dtype) -> (Tensor)",
1928 "aten::quantize_per_tensor.tensors(Tensor[] tensors, Tensor scales, Tensor zero_points, int dtype) -> (Tensor[])"},
1929 #if NNC_QUANTIZATION_EXPR_QUANT == 1
1930 computeQuantizePerTensor
1931 #else
1932 computeQuantizePerTensorExternalCall
1933 #endif
1934 );
1935
1936 RegisterNNCLoweringsFunction aten_dequantize(
1937 {"aten::dequantize.self(Tensor self) -> (Tensor)"},
1938 #if NNC_QUANTIZATION_EXPR_DEQUANT == 1
1939 computeDequantize
1940 #else
1941 computeDequantizeExternalCall
1942 #endif
1943 );
1944 RegisterNNCLoweringsFunction quantized_conv1d(
1945 {"quantized::conv1d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> (Tensor)"},
1946 computeQuantizedConv1d);
1947
1948 RegisterNNCLoweringsFunction quantized_conv2d(
1949 {"quantized::conv2d.new(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> (Tensor)"},
1950 computeQuantizedConv2d);
1951
1952 RegisterNNCLoweringsFunction quantized_conv2d_relu(
1953 {"quantized::conv2d_relu.new(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> (Tensor)"},
1954 computeQuantizedConv2dRelu);
1955
1956 RegisterNNCLoweringsFunction quantized_linear(
1957 {"quantized::linear(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> (Tensor Y)"},
1958 computeQuantizedLinear);
1959
1960 RegisterNNCLoweringsFunction quantized_linear_relu(
1961 {"quantized::linear_relu(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> (Tensor Y)"},
1962 computeQuantizedLinear);
1963
1964 RegisterNNCLoweringsFunction quantized_add(
1965 {"quantized::add(Tensor qa, Tensor qb, float scale, int zero_point) -> (Tensor qc)"},
1966 computeQuantizedAdd);
1967
1968 RegisterNNCLoweringsFunction quantized_mul(
1969 {"quantized::mul(Tensor qa, Tensor qb, float scale, int zero_point) -> (Tensor qc)"},
1970 computeQuantizedMul);
1971
1972 RegisterNNCLoweringsFunction quantized_mul_scalar(
1973 {"quantized::mul.Scalar(Tensor qa, Scalar b) -> (Tensor qc)"},
1974 computeQuantizedMulScalar);
1975
1976 RegisterNNCLoweringsFunction quantized_conv2d_prepack(
1977 {"quantized::conv2d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> (__torch__.torch.classes.quantized.Conv2dPackedParamsBase)"},
1978 computeQuantizedConv2dPrepack);
1979
1980 RegisterNNCLoweringsFunction quantized_cat(
1981 {"quantized::cat(Tensor[] qx, int dim, float? scale, int? zero_point) -> (Tensor)"},
1982 computeQuantizedCat);
1983
1984 RegisterNNCLoweringsFunction aten_upsample_nearest2d(
1985 {"aten::upsample_nearest2d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)"},
1986 computeUpsampleNearest2dExternalCall);
1987
1988 return 0;
1989 }
1990 } // namespace
1991
getStandardLoweringFor(const std::string & schema_str)1992 NNCLoweringFunction getStandardLoweringFor(const std::string& schema_str) {
1993 C10_UNUSED static const int once = nnc_lowerings_lazy_registration();
1994 const auto& lowerings = getNNCLoweringRegistry();
1995 if (auto l = lowerings.find(parseSchema(schema_str))) {
1996 return *l;
1997 }
1998 return nullptr;
1999 }
2000
2001 } // namespace torch::jit::tensorexpr
2002