// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/lowerings.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/frontend/function_schema_parser.h>
2 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
3 #include <torch/csrc/jit/tensorexpr/lowerings.h>
4 #include <torch/csrc/jit/tensorexpr/operators/operators.h>
5 
6 #include <ATen/native/Activation.h>
7 #include <ATen/native/mkldnn/Common.h>
8 
9 namespace torch::jit::tensorexpr {
10 
getNNCLoweringRegistry()11 FunctionSchemaMap<NNCLoweringFunction>& getNNCLoweringRegistry() {
12   static FunctionSchemaMap<NNCLoweringFunction> lowering_registry_;
13   return lowering_registry_;
14 }
15 
RegisterNNCLoweringsFunction(const std::vector<std::string> & schemas,const NNCLoweringFunction & fn)16 RegisterNNCLoweringsFunction::RegisterNNCLoweringsFunction(
17     const std::vector<std::string>& schemas,
18     const NNCLoweringFunction& fn) {
19   for (const auto& schema_str : schemas) {
20     getNNCLoweringRegistry().insert(parseSchema(schema_str), fn);
21   }
22 }
23 
24 namespace {
25 // NOLINTNEXTLINE
nnc_lowerings_lazy_registration()26 int nnc_lowerings_lazy_registration() {
27   RegisterNNCLoweringsFunction aten_dropout(
28       {"aten::dropout(Tensor input, float p, bool train) -> (Tensor)"},
29       computeNoop);
30   RegisterNNCLoweringsFunction aten_contiguous(
31       {"aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> (Tensor(a))"},
32       computeNoop);
33 
34 #ifdef USE_XNNPACK
35   // TODO: add a test
36   RegisterNNCLoweringsFunction prepacked_conv2d_clamp_run(
37       {"prepacked::conv2d_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.Conv2dOpContext W_prepack) -> (Tensor Y)"},
38       computePrepackedConv2dClampRun);
39 
40   // TODO: add a test
41   RegisterNNCLoweringsFunction prepacked_linear_clamp_run(
42       {"prepacked::linear_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.LinearOpContext W_prepack) -> (Tensor Y)"},
43       computePrepackedLinearClampRun);
44 #endif
45 
46 #if AT_MKLDNN_ENABLED()
47   RegisterNNCLoweringsFunction mkldnn_prepacked_conv2d_run(
48       {"mkldnn_prepacked::conv2d_run(Tensor X, __torch__.torch.classes.mkldnn.ConvOpContext W_prepack) -> (Tensor Y)"},
49       computeMkldnnPrepackedConvRun);
50 #endif // AT_MKLDNN_ENABLED()
51 
52   RegisterNNCLoweringsFunction aten_sub(
53       {"aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)",
54        "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> (Tensor)"},
55       [](const std::vector<ArgValue>& inputs,
56          const std::vector<ExprHandle>& outputShape,
57          const std::vector<ExprHandle>& outputStrides,
58          const std::optional<ScalarType>& outputType,
59          at::Device device) {
60         auto sub_lambda = [](const ExprHandle& lhs, const ExprHandle& rhs) {
61           // NB: sub isn't supported on boolean, no need to promote to integer.
62           return lhs - rhs;
63         };
64         TORCH_INTERNAL_ASSERT(
65             inputs.size() == 2 || inputs.size() == 3,
66             buildErrorMessage("Invalid number of input operands"));
67         return (inputs.size() > 2) ? computeTwoOperandWithAlpha(
68                                          "aten_sub",
69                                          inputs,
70                                          outputShape,
71                                          outputStrides,
72                                          outputType,
73                                          sub_lambda)
74                                    : computeTwoOperand(
75                                          "aten_sub",
76                                          inputs,
77                                          outputShape,
78                                          outputStrides,
79                                          outputType,
80                                          sub_lambda);
81       });
82 
83   RegisterNNCLoweringsFunction aten_mul(
84       {"aten::mul.Scalar(Tensor self, Scalar other) -> (Tensor)",
85        "aten::mul.Tensor(Tensor self, Tensor other) -> (Tensor)"},
86       [](const std::vector<ArgValue>& inputs,
87          const std::vector<ExprHandle>& outputShape,
88          const std::vector<ExprHandle>& outputStrides,
89          const std::optional<ScalarType>& outputType,
90          at::Device device) {
91         return computeTwoOperand(
92             "aten_mul",
93             inputs,
94             outputShape,
95             outputStrides,
96             outputType,
97             [](const ExprHandle& lhs, const ExprHandle& rhs) {
98               return boolToInteger(lhs) * boolToInteger(rhs);
99             });
100       });
101 
102 #define DEFINE_BINARY_SCALAR_OP_LOWERING(op_name, op)                     \
103   RegisterNNCLoweringsFunction aten_##op_name##_scalar(                   \
104       {"aten::" #op_name ".int(int a, int b) -> (int)",                   \
105        "aten::" #op_name ".int_float(int a, float b) -> (float)",         \
106        "aten::" #op_name ".float_int(float a, int b) -> (float)",         \
107        "aten::" #op_name ".float(float a, float b) -> (float)"},          \
108       [](const std::vector<ArgValue>& inputs,                             \
109          const std::vector<ExprHandle>& outputShape,                      \
110          const std::vector<ExprHandle>& outputStrides,                    \
111          const std::optional<ScalarType>& outputType,                     \
112          at::Device device) {                                             \
113         return computeScalar(                                             \
114             "aten_#op_name",                                              \
115             inputs,                                                       \
116             outputShape,                                                  \
117             outputStrides,                                                \
118             outputType,                                                   \
119             [](const ExprHandle& a, const ExprHandle& b) { return op; }); \
120       });
121   DEFINE_BINARY_SCALAR_OP_LOWERING(mul, a * b)
122   DEFINE_BINARY_SCALAR_OP_LOWERING(add, a + b)
123   DEFINE_BINARY_SCALAR_OP_LOWERING(sub, a - b)
124 #undef DEFINE_BINARY_SCALAR_OP_LOWERING
125   RegisterNNCLoweringsFunction aten_div_scalar(
126       {"aten::div(Scalar a, Scalar b) -> (float)",
127        "aten::div.int(int a, int b) -> (float)",
128        "aten::div.int_float(int a, float b) -> (float)",
129        "aten::div.float_int(float a, int b) -> (float)",
130        "aten::div.float(float a, float b) -> (float)"},
131       [](const std::vector<ArgValue>& inputs,
132          const std::vector<ExprHandle>& outputShape,
133          const std::vector<ExprHandle>& outputStrides,
134          const std::optional<ScalarType>& outputType,
135          at::Device device) {
136         return computeScalar(
137             "aten_div",
138             inputs,
139             outputShape,
140             outputStrides,
141             outputType,
142             [](const ExprHandle& a, const ExprHandle& b) {
143               return promoteIntegerToDefaultType(a) /
144                   promoteIntegerToDefaultType(b);
145             });
146       });
147 
148 #define DEFINE_COMPARISON_SCALAR_OP_LOWERING(op_name, op)                 \
149   RegisterNNCLoweringsFunction aten_##op_name##_scalar(                   \
150       {"aten::" #op_name ".bool(bool a, bool b) -> (bool)",               \
151        "aten::" #op_name ".int(int a, int b) -> (bool)",                  \
152        "aten::" #op_name ".int_float(int a, float b) -> (bool)",          \
153        "aten::" #op_name ".float_int(float a, int b) -> (bool)",          \
154        "aten::" #op_name ".float(float a, float b) -> (bool)"},           \
155       [](const std::vector<ArgValue>& inputs,                             \
156          const std::vector<ExprHandle>& outputShape,                      \
157          const std::vector<ExprHandle>& outputStrides,                    \
158          const std::optional<ScalarType>& outputType,                     \
159          at::Device device) {                                             \
160         return computeScalar(                                             \
161             "aten_#op_name",                                              \
162             inputs,                                                       \
163             outputShape,                                                  \
164             outputStrides,                                                \
165             outputType,                                                   \
166             [](const ExprHandle& a, const ExprHandle& b) { return op; }); \
167       });
168   DEFINE_COMPARISON_SCALAR_OP_LOWERING(lt, cast<bool>(a < b))
169   DEFINE_COMPARISON_SCALAR_OP_LOWERING(le, cast<bool>(a <= b))
170   DEFINE_COMPARISON_SCALAR_OP_LOWERING(eq, cast<bool>(a == b))
171   DEFINE_COMPARISON_SCALAR_OP_LOWERING(ne, cast<bool>(a != b))
172   DEFINE_COMPARISON_SCALAR_OP_LOWERING(gt, cast<bool>(a > b))
173   DEFINE_COMPARISON_SCALAR_OP_LOWERING(ge, cast<bool>(a >= b))
174 #undef DEFINE_COMPARISON_SCALAR_OP_LOWERING
175 
176 #define DEFINE_BITWISE_SCALAR_OP_LOWERING(op_name, op)                    \
177   RegisterNNCLoweringsFunction aten_##op_name##_int_scalar(               \
178       {"aten::" #op_name ".int(int a, int b) -> (int)"},                  \
179       [](const std::vector<ArgValue>& inputs,                             \
180          const std::vector<ExprHandle>& outputShape,                      \
181          const std::vector<ExprHandle>& outputStrides,                    \
182          const std::optional<ScalarType>& outputType,                     \
183          at::Device device) {                                             \
184         return computeScalar(                                             \
185             "aten_#op_name",                                              \
186             inputs,                                                       \
187             outputShape,                                                  \
188             outputStrides,                                                \
189             outputType,                                                   \
190             [](const ExprHandle& a, const ExprHandle& b) { return op; }); \
191       });
192   DEFINE_BITWISE_SCALAR_OP_LOWERING(
193       __and__, boolToInteger(a) & boolToInteger(b))
194   DEFINE_BITWISE_SCALAR_OP_LOWERING(__or__, boolToInteger(a) | boolToInteger(b))
195   DEFINE_BITWISE_SCALAR_OP_LOWERING(
196       __xor__, boolToInteger(a) ^ boolToInteger(b))
197   DEFINE_BITWISE_SCALAR_OP_LOWERING(__lshift__, a << b)
198   DEFINE_BITWISE_SCALAR_OP_LOWERING(__rshift__, a >> b)
199 #undef DEFINE_BITWISE_SCALAR_OP_LOWERING
200 
201 #define DEFINE_LOGICAL_SCALAR_OP_LOWERING(op_name, op)                    \
202   RegisterNNCLoweringsFunction aten_##op_name##_bool_scalar(              \
203       {"aten::" #op_name ".bool(bool a, bool b) -> (bool)"},              \
204       [](const std::vector<ArgValue>& inputs,                             \
205          const std::vector<ExprHandle>& outputShape,                      \
206          const std::vector<ExprHandle>& outputStrides,                    \
207          const std::optional<ScalarType>& outputType,                     \
208          at::Device device) {                                             \
209         return computeScalar(                                             \
210             "aten_#op_name",                                              \
211             inputs,                                                       \
212             outputShape,                                                  \
213             outputStrides,                                                \
214             outputType,                                                   \
215             [](const ExprHandle& a, const ExprHandle& b) { return op; }); \
216       });
217   DEFINE_LOGICAL_SCALAR_OP_LOWERING(__and__, a && b)
218   DEFINE_LOGICAL_SCALAR_OP_LOWERING(__or__, a || b)
219   DEFINE_LOGICAL_SCALAR_OP_LOWERING(__xor__, a != b)
220 #undef DEFINE_LOGICAL_SCALAR_OP_LOWERING
221 
222   RegisterNNCLoweringsFunction aten_div(
223       {"aten::div.Scalar(Tensor self, Scalar other) -> (Tensor)",
224        "aten::div.Tensor(Tensor self, Tensor other) -> (Tensor)"},
225       [](const std::vector<ArgValue>& inputs,
226          const std::vector<ExprHandle>& outputShape,
227          const std::vector<ExprHandle>& outputStrides,
228          const std::optional<ScalarType>& outputType,
229          at::Device device) {
230         return computeTwoOperand(
231             "aten_div",
232             inputs,
233             outputShape,
234             outputStrides,
235             outputType,
236             [](const ExprHandle& lhs, const ExprHandle& rhs) {
237               return promoteIntegerToDefaultType(lhs) /
238                   promoteIntegerToDefaultType(rhs);
239             });
240       });
241 
242   RegisterNNCLoweringsFunction aten___and__(
243       {"aten::__and__.Scalar(Tensor self, Scalar other) -> (Tensor)",
244        "aten::__and__.Tensor(Tensor self, Tensor other) -> (Tensor)"},
245       [](const std::vector<ArgValue>& inputs,
246          const std::vector<ExprHandle>& outputShape,
247          const std::vector<ExprHandle>& outputStrides,
248          const std::optional<ScalarType>& outputType,
249          at::Device device) {
250         return computeTwoOperand(
251             "aten_and",
252             inputs,
253             outputShape,
254             outputStrides,
255             outputType,
256             [](const ExprHandle& lhs, const ExprHandle& rhs) {
257               return boolToInteger(lhs) & boolToInteger(rhs);
258             });
259       });
260 
261   RegisterNNCLoweringsFunction aten___or__(
262       {"aten::__or__.Scalar(Tensor self, Scalar other) -> (Tensor)",
263        "aten::__or__.Tensor(Tensor self, Tensor other) -> (Tensor)"},
264       [](const std::vector<ArgValue>& inputs,
265          const std::vector<ExprHandle>& outputShape,
266          const std::vector<ExprHandle>& outputStrides,
267          const std::optional<ScalarType>& outputType,
268          at::Device device) {
269         return computeTwoOperand(
270             "aten_or",
271             inputs,
272             outputShape,
273             outputStrides,
274             outputType,
275             [](const ExprHandle& lhs, const ExprHandle& rhs) {
276               return boolToInteger(lhs) | boolToInteger(rhs);
277             });
278       });
279 
280   RegisterNNCLoweringsFunction aten___xor__(
281       {"aten::__xor__.Scalar(Tensor self, Scalar other) -> (Tensor)",
282        "aten::__xor__.Tensor(Tensor self, Tensor other) -> (Tensor)"},
283       [](const std::vector<ArgValue>& inputs,
284          const std::vector<ExprHandle>& outputShape,
285          const std::vector<ExprHandle>& outputStrides,
286          const std::optional<ScalarType>& outputType,
287          at::Device device) {
288         return computeTwoOperand(
289             "aten_xor",
290             inputs,
291             outputShape,
292             outputStrides,
293             outputType,
294             [](const ExprHandle& lhs, const ExprHandle& rhs) {
295               return boolToInteger(lhs) ^ boolToInteger(rhs);
296             });
297       });
298 
299   RegisterNNCLoweringsFunction aten___lshift__(
300       {"aten::__lshift__.Scalar(Tensor self, Scalar other) -> (Tensor)",
301        "aten::__lshift__.Tensor(Tensor self, Tensor other) -> (Tensor)"},
302       [](const std::vector<ArgValue>& inputs,
303          const std::vector<ExprHandle>& outputShape,
304          const std::vector<ExprHandle>& outputStrides,
305          const std::optional<ScalarType>& outputType,
306          at::Device device) {
307         return computeTwoOperand(
308             "aten_lshift",
309             inputs,
310             outputShape,
311             outputStrides,
312             outputType,
313             [](const ExprHandle& lhs, const ExprHandle& rhs) {
314               return lhs << rhs;
315             });
316       });
317 
318   RegisterNNCLoweringsFunction aten___rshift__(
319       {"aten::__rshift__.Scalar(Tensor self, Scalar other) -> (Tensor)",
320        "aten::__rshift__.Tensor(Tensor self, Tensor other) -> (Tensor)"},
321       [](const std::vector<ArgValue>& inputs,
322          const std::vector<ExprHandle>& outputShape,
323          const std::vector<ExprHandle>& outputStrides,
324          const std::optional<ScalarType>& outputType,
325          at::Device device) {
326         return computeTwoOperand(
327             "aten_rshift",
328             inputs,
329             outputShape,
330             outputStrides,
331             outputType,
332             [](const ExprHandle& lhs, const ExprHandle& rhs) {
333               return lhs >> rhs;
334             });
335       });
336 
337   RegisterNNCLoweringsFunction aten_eq(
338       {"aten::eq.Scalar(Tensor self, Scalar other) -> (Tensor)",
339        "aten::eq.Tensor(Tensor self, Tensor other) -> (Tensor)"},
340       [](const std::vector<ArgValue>& inputs,
341          const std::vector<ExprHandle>& outputShape,
342          const std::vector<ExprHandle>& outputStrides,
343          const std::optional<ScalarType>& outputType,
344          at::Device device) {
345         return computeTwoOperand(
346             "aten_eq",
347             inputs,
348             outputShape,
349             outputStrides,
350             outputType,
351             [](const ExprHandle& lhs, const ExprHandle& rhs) {
352               return cast<bool>(lhs == rhs);
353             });
354       });
355 
356   RegisterNNCLoweringsFunction aten_ne(
357       {"aten::ne.Scalar(Tensor self, Scalar other) -> (Tensor)",
358        "aten::ne.Tensor(Tensor self, Tensor other) -> (Tensor)"},
359       [](const std::vector<ArgValue>& inputs,
360          const std::vector<ExprHandle>& outputShape,
361          const std::vector<ExprHandle>& outputStrides,
362          const std::optional<ScalarType>& outputType,
363          at::Device device) {
364         return computeTwoOperand(
365             "aten_ne",
366             inputs,
367             outputShape,
368             outputStrides,
369             outputType,
370             [](const ExprHandle& lhs, const ExprHandle& rhs) {
371               return cast<bool>(lhs != rhs);
372             });
373       });
374 
375   RegisterNNCLoweringsFunction aten_ge(
376       {"aten::ge.Scalar(Tensor self, Scalar other) -> (Tensor)",
377        "aten::ge.Tensor(Tensor self, Tensor other) -> (Tensor)"},
378       [](const std::vector<ArgValue>& inputs,
379          const std::vector<ExprHandle>& outputShape,
380          const std::vector<ExprHandle>& outputStrides,
381          const std::optional<ScalarType>& outputType,
382          at::Device device) {
383         return computeTwoOperand(
384             "aten_ge",
385             inputs,
386             outputShape,
387             outputStrides,
388             outputType,
389             [](const ExprHandle& lhs, const ExprHandle& rhs) {
390               return cast<bool>(lhs >= rhs);
391             });
392       });
393 
394   RegisterNNCLoweringsFunction aten_gt(
395       {"aten::gt.Scalar(Tensor self, Scalar other) -> (Tensor)",
396        "aten::gt.Tensor(Tensor self, Tensor other) -> (Tensor)"},
397       [](const std::vector<ArgValue>& inputs,
398          const std::vector<ExprHandle>& outputShape,
399          const std::vector<ExprHandle>& outputStrides,
400          const std::optional<ScalarType>& outputType,
401          at::Device device) {
402         return computeTwoOperand(
403             "aten_gt",
404             inputs,
405             outputShape,
406             outputStrides,
407             outputType,
408             [](const ExprHandle& lhs, const ExprHandle& rhs) {
409               return cast<bool>(lhs > rhs);
410             });
411       });
412 
413   RegisterNNCLoweringsFunction aten_le(
414       {"aten::le.Scalar(Tensor self, Scalar other) -> (Tensor)",
415        "aten::le.Tensor(Tensor self, Tensor other) -> (Tensor)"},
416       [](const std::vector<ArgValue>& inputs,
417          const std::vector<ExprHandle>& outputShape,
418          const std::vector<ExprHandle>& outputStrides,
419          const std::optional<ScalarType>& outputType,
420          at::Device device) {
421         return computeTwoOperand(
422             "aten_le",
423             inputs,
424             outputShape,
425             outputStrides,
426             outputType,
427             [](const ExprHandle& lhs, const ExprHandle& rhs) {
428               return cast<bool>(lhs <= rhs);
429             });
430       });
431 
432   RegisterNNCLoweringsFunction aten_lt(
433       {"aten::lt.Scalar(Tensor self, Scalar other) -> (Tensor)",
434        "aten::lt.Tensor(Tensor self, Tensor other) -> (Tensor)"},
435       [](const std::vector<ArgValue>& inputs,
436          const std::vector<ExprHandle>& outputShape,
437          const std::vector<ExprHandle>& outputStrides,
438          const std::optional<ScalarType>& outputType,
439          at::Device device) {
440         return computeTwoOperand(
441             "aten_lt",
442             inputs,
443             outputShape,
444             outputStrides,
445             outputType,
446             [](const ExprHandle& lhs, const ExprHandle& rhs) {
447               return cast<bool>(lhs < rhs);
448             });
449       });
450 
451   RegisterNNCLoweringsFunction aten_min_pointwise(
452       {"aten::min.other(Tensor self, Tensor other) -> (Tensor)"},
453       [](const std::vector<ArgValue>& inputs,
454          const std::vector<ExprHandle>& outputShape,
455          const std::vector<ExprHandle>& outputStrides,
456          const std::optional<ScalarType>& outputType,
457          at::Device device) {
458         return computeTwoOperand(
459             "aten_min",
460             inputs,
461             outputShape,
462             outputStrides,
463             outputType,
464             [](const ExprHandle& lhs, const ExprHandle& rhs) {
465               return Min::make(boolToInteger(lhs), boolToInteger(rhs), false);
466             });
467       });
468 
469   RegisterNNCLoweringsFunction aten_max_pointwise(
470       {"aten::max.other(Tensor self, Tensor other) -> (Tensor)"},
471       [](const std::vector<ArgValue>& inputs,
472          const std::vector<ExprHandle>& outputShape,
473          const std::vector<ExprHandle>& outputStrides,
474          const std::optional<ScalarType>& outputType,
475          at::Device device) {
476         return computeTwoOperand(
477             "aten_max",
478             inputs,
479             outputShape,
480             outputStrides,
481             outputType,
482             [](const ExprHandle& lhs, const ExprHandle& rhs) {
483               return Max::make(boolToInteger(lhs), boolToInteger(rhs), false);
484             });
485       });
486 
487   RegisterNNCLoweringsFunction aten_masked_fill(
488       {"aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> (Tensor)",
489        "aten::masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> (Tensor)"},
490       [](const std::vector<ArgValue>& inputs,
491          const std::vector<ExprHandle>& outputShape,
492          const std::vector<ExprHandle>& outputStrides,
493          const std::optional<ScalarType>& outputType,
494          at::Device device) {
495         return computeThreeOperand(
496             "aten_masked_fill",
497             inputs,
498             outputShape,
499             outputStrides,
500             outputType,
501             [](const ExprHandle& input,
502                const ExprHandle& mask,
503                const ExprHandle& value) {
504               // value needs to promote to input, not vice versa
505               auto val = promoteToDtype(value, input.dtype().scalar_type());
506               return ifThenElse(mask, val, input);
507             },
508             /*promote_inputs*/ false);
509       });
510   RegisterNNCLoweringsFunction aten_clamp(
511       {"aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> (Tensor)",
512        "aten::clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> (Tensor)"},
513       [](const std::vector<ArgValue>& inputs,
514          const std::vector<ExprHandle>& outputShape,
515          const std::vector<ExprHandle>& outputStrides,
516          const std::optional<ScalarType>& outputType,
517          at::Device device) {
518         bool noMin = false;
519         bool noMax = false;
520         if (std::get_if<ArgNone>(&inputs[1])) {
521           noMin = true;
522         }
523 
524         if (std::get_if<ArgNone>(&inputs[2])) {
525           noMax = true;
526         }
527 
528         return computeThreeOperand(
529             "aten_clamp",
530             inputs,
531             outputShape,
532             outputStrides,
533             outputType,
534             [noMin, noMax](
535                 const ExprHandle& in,
536                 const ExprHandle& min,
537                 const ExprHandle& max) {
538               auto cast = [&](const ExprHandle& e) {
539                 return Cast::make(in.dtype(), e);
540               };
541 
542               if (noMin && noMax) {
543                 return in;
544               } else if (noMin) {
545                 auto cmax = cast(max);
546                 return CompareSelect::make(in, cmax, cmax, in, kGT);
547               } else if (noMax) {
548                 auto cmin = cast(min);
549                 return CompareSelect::make(in, cmin, cmin, in, kLT);
550               } else {
551                 auto cmax = cast(max);
552                 auto cmin = cast(min);
553                 return clamp(cmin, cmax, in);
554               }
555             },
556             false /* promote_inputs */);
557       });
558 
559   RegisterNNCLoweringsFunction aten_addcmul(
560       {"aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> (Tensor)"},
561       [](const std::vector<ArgValue>& inputs,
562          const std::vector<ExprHandle>& outputShape,
563          const std::vector<ExprHandle>& outputStrides,
564          const std::optional<ScalarType>& outputType,
565          at::Device device) {
566         return computeFourOperand(
567             "aten_addcmul",
568             inputs,
569             outputShape,
570             outputStrides,
571             outputType,
572             [](const ExprHandle& a0,
573                const ExprHandle& a1,
574                const ExprHandle& a2,
575                const ExprHandle& a3) { return a0 + a3 * a1 * a2; });
576       });
577 
578   RegisterNNCLoweringsFunction aten_sigmoid(
579       {"aten::sigmoid(Tensor self) -> (Tensor)"},
580       [](const std::vector<ArgValue>& inputs,
581          const std::vector<ExprHandle>& outputShape,
582          const std::vector<ExprHandle>& outputStrides,
583          const std::optional<ScalarType>& outputType,
584          at::Device device) {
585         // check if the activation is quantized
586         const BufHandle& x = std::get<BufHandle>(inputs[0]);
587         if (x.node()->qscale()) {
588           return computeQuantizedSigmoidExternalCall(
589               inputs, outputShape, outputStrides, outputType, device);
590         }
591         return computeOneOperand(
592             "aten_sigmoid",
593             inputs,
594             outputShape,
595             outputStrides,
596             outputType,
597             [](const ExprHandle& a) {
598               return sigmoid(promoteIntegerToDefaultType(a));
599             });
600       });
601 
602   RegisterNNCLoweringsFunction aten_silu(
603       {"aten::silu(Tensor self) -> (Tensor)"},
604       [](const std::vector<ArgValue>& inputs,
605          const std::vector<ExprHandle>& outputShape,
606          const std::vector<ExprHandle>& outputStrides,
607          const std::optional<ScalarType>& outputType,
608          at::Device device) {
609         return computeOneOperand(
610             "aten_silu",
611             inputs,
612             outputShape,
613             outputStrides,
614             outputType,
615             [](const ExprHandle& a) { return a * sigmoid(a); });
616       });
617 
618   RegisterNNCLoweringsFunction aten_reciprocal(
619       {"aten::reciprocal(Tensor self) -> (Tensor)"},
620       [](const std::vector<ArgValue>& inputs,
621          const std::vector<ExprHandle>& outputShape,
622          const std::vector<ExprHandle>& outputStrides,
623          const std::optional<ScalarType>& outputType,
624          at::Device device) {
625         return computeOneOperand(
626             "aten_reciprocal",
627             inputs,
628             outputShape,
629             outputStrides,
630             outputType,
631             [](const ExprHandle& a) { return ExprHandle(1.0f) / a; });
632       });
633 
634   RegisterNNCLoweringsFunction aten_neg(
635       {"aten::neg(Tensor self) -> (Tensor)"},
636       [](const std::vector<ArgValue>& inputs,
637          const std::vector<ExprHandle>& outputShape,
638          const std::vector<ExprHandle>& outputStrides,
639          const std::optional<ScalarType>& outputType,
640          at::Device device) {
641         return computeOneOperand(
642             "aten_neg",
643             inputs,
644             outputShape,
645             outputStrides,
646             outputType,
647             [](const ExprHandle& a) { return ExprHandle(-0) - a; });
648       });
649 
650   RegisterNNCLoweringsFunction aten_isnan(
651       {"aten::isnan(Tensor self) -> (Tensor)"},
652       [](const std::vector<ArgValue>& inputs,
653          const std::vector<ExprHandle>& outputShape,
654          const std::vector<ExprHandle>& outputStrides,
655          const std::optional<ScalarType>& outputType,
656          at::Device device) {
657         return computeOneOperand(
658             "aten_isnan",
659             inputs,
660             outputShape,
661             outputStrides,
662             outputType,
663             [](const ExprHandle& a) {
664               if (!a.dtype().is_floating_point()) {
665                 return IntImm::make(0);
666               }
667               return isnan(a);
668             });
669       });
670 
671   RegisterNNCLoweringsFunction aten_relu(
672       {"aten::relu(Tensor self) -> (Tensor)"},
673       [](const std::vector<ArgValue>& inputs,
674          const std::vector<ExprHandle>& outputShape,
675          const std::vector<ExprHandle>& outputStrides,
676          const std::optional<ScalarType>& outputType,
677          at::Device device) {
678         auto A = std::get<BufHandle>(inputs[0]);
679         if (A.node()->qscale()) {
680           return computeQuantizedRelu(
681               inputs, outputShape, outputStrides, outputType, device);
682         }
683         return computeOneOperand(
684             "aten_relu",
685             inputs,
686             outputShape,
687             outputStrides,
688             outputType,
689             [](const ExprHandle& a) {
690               auto zero = Cast::make(a.dtype(), 0);
691               return CompareSelect::make(a, zero, zero, a, kLT);
692             });
693       });
694 
695   RegisterNNCLoweringsFunction aten_leaky_relu(
696       {"aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> (Tensor)"},
697       [](const std::vector<ArgValue>& inputs,
698          const std::vector<ExprHandle>& outputShape,
699          const std::vector<ExprHandle>& outputStrides,
700          const std::optional<ScalarType>& outputType,
701          at::Device device) {
702         return computeTwoOperand(
703             "aten_leaky_relu",
704             inputs,
705             outputShape,
706             outputStrides,
707             outputType,
708             [](const ExprHandle& a, const ExprHandle& negative_slope) {
709               auto neg_slope = Cast::make(a.dtype(), negative_slope);
710               auto zero = Cast::make(a.dtype(), 0);
711               auto one = Cast::make(a.dtype(), 1);
712               auto cs = CompareSelect::make(a, zero, one, neg_slope, kGT);
713               return a * cs;
714             });
715       });
716 
717   RegisterNNCLoweringsFunction aten_relu6(
718       {"aten::relu6(Tensor self) -> (Tensor)"},
719       [](const std::vector<ArgValue>& inputs,
720          const std::vector<ExprHandle>& outputShape,
721          const std::vector<ExprHandle>& outputStrides,
722          const std::optional<ScalarType>& outputType,
723          at::Device device) {
724         return computeOneOperand(
725             "aten_relu6",
726             inputs,
727             outputShape,
728             outputStrides,
729             outputType,
730             [](const ExprHandle& a) {
731               auto zero = Cast::make(a.dtype(), 0);
732               auto six = Cast::make(a.dtype(), 6.);
733               return clamp(zero, six, a);
734             });
735       });
736 
  // aten::gelu — Gaussian Error Linear Unit. Selects between the exact
  // erf-based formula and the tanh approximation based on the `approximate`
  // string argument (decoded via at::native::get_gelutype_enum).
  RegisterNNCLoweringsFunction aten_gelu(
      {"aten::gelu(Tensor self, *, str approximate='none') -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        const auto& kApproximate = std::get<std::string>(inputs[1]);
        // Only the tensor input participates in the elementwise compute; the
        // approximate-mode string is consumed here at lowering time.
        std::vector<ArgValue> operands = {inputs.front()};
        if (at::native::get_gelutype_enum(kApproximate) ==
            at::native::GeluType::Tanh) {
          // approximate == 'tanh'
          return computeOneOperand(
              "aten_tanh_gelu",
              operands,
              outputShape,
              outputStrides,
              outputType,
              [](const ExprHandle& a) {
                // 0.5 * a * (1 + tanh(sqrt(2/pi) * (a + 0.044715 * a^3)))
                auto one = Cast::make(a.dtype(), 1.);
                auto point_five = Cast::make(a.dtype(), .5);
                auto beta = Cast::make(a.dtype(), M_SQRT2 * M_2_SQRTPI * 0.5);
                auto kappa = Cast::make(a.dtype(), 0.044715);
                auto a_cube = a * a * a;
                auto inner = beta * (a + kappa * a_cube);
                return point_five * a * (one + tanh(inner));
              });
        } else {
          // approximate == 'none'
          return computeOneOperand(
              "aten_gelu",
              operands,
              outputShape,
              outputStrides,
              outputType,
              [](const ExprHandle& a) {
                // 0.5 * a * (1 + erf(a / sqrt(2)))
                auto m_sqrt1_2 = Cast::make(a.dtype(), M_SQRT1_2);
                auto one = Cast::make(a.dtype(), 1.);
                auto point_five = Cast::make(a.dtype(), .5);
                return a * point_five * (one + erf(a * m_sqrt1_2));
              });
        }
      });
780 
  // aten::batch_norm — lowered by the dedicated computeBatchNorm helper.
  RegisterNNCLoweringsFunction aten_batch_norm(
      {"aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor)"},
      computeBatchNorm);
784 
  // Unary transcendental lowerings (log family, exp family, erf family,
  // trig). Each promotes integral/bool inputs to the default scalar type via
  // promoteIntegerToDefaultType before applying the intrinsic, matching
  // eager-mode type promotion for these ops.
  RegisterNNCLoweringsFunction aten_log(
      {"aten::log(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_log",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              return log(promoteIntegerToDefaultType(a));
            });
      });

  RegisterNNCLoweringsFunction aten_log10(
      {"aten::log10(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_log10",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              return log10(promoteIntegerToDefaultType(a));
            });
      });

  RegisterNNCLoweringsFunction aten_log1p(
      {"aten::log1p(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_log1p",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              return log1p(promoteIntegerToDefaultType(a));
            });
      });

  RegisterNNCLoweringsFunction aten_log2(
      {"aten::log2(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_log2",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              return log2(promoteIntegerToDefaultType(a));
            });
      });

  RegisterNNCLoweringsFunction aten_exp(
      {"aten::exp(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_exp",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              return exp(promoteIntegerToDefaultType(a));
            });
      });

  RegisterNNCLoweringsFunction aten_expm1(
      {"aten::expm1(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_expm1",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              return expm1(promoteIntegerToDefaultType(a));
            });
      });

  RegisterNNCLoweringsFunction aten_erf(
      {"aten::erf(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_erf",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              return erf(promoteIntegerToDefaultType(a));
            });
      });

  RegisterNNCLoweringsFunction aten_erfc(
      {"aten::erfc(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_erfc",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              return erfc(promoteIntegerToDefaultType(a));
            });
      });

  RegisterNNCLoweringsFunction aten_cos(
      {"aten::cos(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_cos",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              return cos(promoteIntegerToDefaultType(a));
            });
      });

  RegisterNNCLoweringsFunction aten_sin(
      {"aten::sin(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_sin",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              return sin(promoteIntegerToDefaultType(a));
            });
      });

  RegisterNNCLoweringsFunction aten_tan(
      {"aten::tan(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_tan",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              return tan(promoteIntegerToDefaultType(a));
            });
      });
982 
983   RegisterNNCLoweringsFunction aten_type_as(
984       {"aten::type_as(Tensor self, Tensor other) -> (Tensor)"},
985       [](const std::vector<ArgValue>& inputs,
986          const std::vector<ExprHandle>& outputShape,
987          const std::vector<ExprHandle>& outputStrides,
988          const std::optional<ScalarType>& outputType,
989          at::Device device) {
990         const BufHandle& rhs = std::get<BufHandle>(inputs[1]);
991         auto dtype = rhs.dtype();
992         return computeOneOperand(
993             "aten_type_as",
994             inputs,
995             outputShape,
996             outputStrides,
997             outputType,
998             [dtype](const ExprHandle& lhs) { return Cast::make(dtype, lhs); });
999       });
1000 
  // aten::pow — general power. When the exponent is a compile-time constant,
  // strength-reduce common values (0, ±0.5, ±1, ±2, 3, 4) to cheaper
  // multiply/sqrt/divide forms instead of emitting a pow intrinsic.
  RegisterNNCLoweringsFunction aten_pow(
      {"aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> (Tensor)",
       "aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> (Tensor)",
       "aten::pow.Scalar(Scalar self, Tensor exponent) -> Tensor"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeTwoOperand(
            "aten_pow",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& lhs, const ExprHandle& rhs) {
              // Non-constant exponent: no simplification possible.
              if (!rhs.node()->isConstant()) {
                return pow(lhs, rhs);
              }
              // Simplify first so folded constant expressions also qualify.
              double val =
                  immediateAs<double>(IRSimplifier::simplify(rhs.node()));

              if (val == 1.0f) {
                return lhs;
              } else if (val == 2.0f) { // NOLINT
                return lhs * lhs;
              } else if (val == 3.0f) { // NOLINT
                return (lhs * lhs) * lhs;
              } else if (val == 4.0f) { // NOLINT
                // x^4 as (x^2)^2 — two multiplies instead of three.
                ExprHandle tmp = lhs * lhs;
                return tmp * tmp;
              } else if (val == 0.5f) { // NOLINT
                return sqrt(lhs);
              } else if (val == 0.0f) {
                return ExprHandle(1.0f);
              } else if (val == -0.5f) { // NOLINT
                return rsqrt(lhs);
              } else if (val == -1.0f) {
                return ExprHandle(1.0f) / lhs;
              } else if (val == -2.0f) { // NOLINT
                return ExprHandle(1.0f) / (lhs * lhs);
              }
              // Any other constant exponent falls back to the intrinsic.
              return pow(lhs, rhs);
            });
      });
1046 
  // aten::fmod — C-style remainder (sign follows the dividend). Half-
  // precision operands are promoted to float before the fmod intrinsic.
  RegisterNNCLoweringsFunction aten_fmod(
      {"aten::fmod.Scalar(Tensor self, Scalar other) -> (Tensor)",
       "aten::fmod.Tensor(Tensor self, Tensor other) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeTwoOperand(
            "aten_fmod",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& lhs, const ExprHandle& rhs) {
              return fmod(promoteHalfToFloat(lhs), promoteHalfToFloat(rhs));
            });
      });
1065 
1066   RegisterNNCLoweringsFunction aten_lerp(
1067       {"aten::lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> (Tensor)",
1068        "aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> (Tensor)"},
1069       [](const std::vector<ArgValue>& inputs,
1070          const std::vector<ExprHandle>& outputShape,
1071          const std::vector<ExprHandle>& outputStrides,
1072          const std::optional<ScalarType>& outputType,
1073          at::Device device) {
1074         return computeThreeOperand(
1075             "aten_lerp",
1076             inputs,
1077             outputShape,
1078             outputStrides,
1079             outputType,
1080             [](const ExprHandle& a,
1081                const ExprHandle& end,
1082                const ExprHandle& weight) { return a + weight * (end - a); });
1083       });
1084 
  // aten::remainder — Python-style modulo where the result takes the sign of
  // the divisor (unlike aten::fmod). Integer operands use Mod directly; any
  // floating-point operand routes through a double-fmod construction.
  RegisterNNCLoweringsFunction aten_remainder(
      {"aten::remainder.Scalar(Tensor self, Scalar other) -> (Tensor)",
       "aten::remainder.Scalar_Tensor(Scalar self, Tensor other) -> (Tensor)",
       "aten::remainder.Tensor(Tensor self, Tensor other) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        // Integer path: plain IR Mod node.
        auto imodImpl = [](const ExprHandle& lhs, const ExprHandle& rhs) {
          return Mod::make(lhs, rhs);
        };
        // Float path: fmod(rhs + fmod(lhs, rhs), rhs) shifts a negative
        // C-style remainder into [0, |rhs|) with the divisor's sign.
        auto fmodImpl = [](const ExprHandle& lhs, const ExprHandle& rhs) {
          auto lhs_t = promoteHalfToFloat(lhs);
          auto rhs_t = promoteHalfToFloat(rhs);
          return fmod((rhs_t + fmod(lhs_t, rhs_t)), rhs_t);
        };
        {
          auto const& shape =
              broadcastShapes(valueShape(inputs[0]), valueShape(inputs[1]));
          return Compute(
              "aten_remainder", shape, [&](const std::vector<VarHandle>& axes) {
                std::vector<ExprHandle> indices(axes.begin(), axes.end());
                std::vector<ExprHandle> exprInputs = {
                    tensorOrConstant(inputs[0], indices),
                    tensorOrConstant(inputs[1], indices),
                };

                promoteInputs(exprInputs);
                // Pick the integer path only when every (promoted) operand is
                // non-floating-point.
                bool allInt = true;
                for (auto& e : exprInputs) {
                  if (e.dtype().is_floating_point()) {
                    allInt = false;
                    break;
                  }
                }
                if (allInt) {
                  return demoteOutput(
                      imodImpl(exprInputs[0], exprInputs[1]), outputType);
                } else {
                  return demoteOutput(
                      fmodImpl(exprInputs[0], exprInputs[1]), outputType);
                }
              });
        }
      });
1131 
  // prim::ConstantChunk — variadic chunking, lowered by computeChunk.
  RegisterNNCLoweringsFunction prim_ConstantChunk(
      {"prim::ConstantChunk(...) -> (...)"}, computeChunk);
1134 
  // Inverse/hyperbolic trig lowerings. As with the other transcendental ops,
  // integral/bool inputs are promoted to the default scalar type first.
  RegisterNNCLoweringsFunction aten_acos(
      {"aten::acos(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_acos",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              return acos(promoteIntegerToDefaultType(a));
            });
      });

  RegisterNNCLoweringsFunction aten_asin(
      {"aten::asin(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_asin",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              return asin(promoteIntegerToDefaultType(a));
            });
      });

  RegisterNNCLoweringsFunction aten_cosh(
      {"aten::cosh(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_cosh",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              return cosh(promoteIntegerToDefaultType(a));
            });
      });

  RegisterNNCLoweringsFunction aten_sinh(
      {"aten::sinh(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_sinh",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              return sinh(promoteIntegerToDefaultType(a));
            });
      });

  RegisterNNCLoweringsFunction aten_atan(
      {"aten::atan(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_atan",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              return atan(promoteIntegerToDefaultType(a));
            });
      });

  // aten::atan2 — two-operand arctangent; both operands are promoted.
  RegisterNNCLoweringsFunction aten_atan2(
      {"aten::atan2(Tensor self, Tensor other) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeTwoOperand(
            "aten_atan2",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& lhs, const ExprHandle& rhs) {
              return atan2(
                  promoteIntegerToDefaultType(lhs),
                  promoteIntegerToDefaultType(rhs));
            });
      });

  RegisterNNCLoweringsFunction aten_tanh(
      {"aten::tanh(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_tanh",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              return tanh(promoteIntegerToDefaultType(a));
            });
      });
1262 
  // aten::hardtanh — clamp into [min_val, max_val] via two CompareSelects.
  RegisterNNCLoweringsFunction aten_hardtanh(
      {"aten::hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeThreeOperand(
            "aten_hardtanh",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a,
               const ExprHandle& min_val,
               const ExprHandle& max_val) {
              // First clamp from below, then from above.
              auto mm = CompareSelect::make(a, min_val, min_val, a, kLT);
              return CompareSelect::make(mm, max_val, max_val, mm, kGT);
            });
      });
1283 
  // aten::softplus — log(1 + exp(beta * x)) / beta, with a linear
  // passthrough (returns x) when beta * x exceeds `threshold` so the exp
  // does not overflow for large inputs.
  RegisterNNCLoweringsFunction aten_softplus(
      {"aten::softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeThreeOperand(
            "aten_softplus",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a,
               const ExprHandle& beta,
               const ExprHandle& threshold) {
              auto beta_promoted = Cast::make(a.dtype(), beta);
              auto threshold_promoted = Cast::make(a.dtype(), threshold);
              auto beta_a = beta_promoted * a;
              // beta*a > threshold ? a : log1p(exp(beta*a)) / beta
              return CompareSelect::make(
                  beta_a,
                  threshold_promoted,
                  a,
                  log1p(exp(beta_a)) / beta_promoted,
                  kGT);
            });
      });
1311 
  // aten::mish — x * tanh(softplus(x)), with integral inputs promoted to the
  // default scalar type first.
  RegisterNNCLoweringsFunction aten_mish(
      {"aten::mish(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_mish",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              auto default_type_a = promoteIntegerToDefaultType(a);
              // log1p(exp(x)) is softplus with beta == 1.
              return default_type_a * tanh(log1p(exp(default_type_a)));
            });
      });
1330 
  // aten::elu — scale * x for x > 0, and scale * alpha * (exp(input_scale *
  // x) - 1) otherwise.
  RegisterNNCLoweringsFunction aten_elu(
      {"aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeFourOperand(
            "aten_elu",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a,
               const ExprHandle& alpha,
               const ExprHandle& scale,
               const ExprHandle& input_scale) {
              auto zero = Cast::make(a.dtype(), 0);
              auto one = Cast::make(a.dtype(), 1);

              // Fold scale into both branches: positive coefficient and the
              // (alpha * scale) coefficient for the negative branch.
              auto poscoef = Cast::make(a.dtype(), scale);
              auto negiptcoef = Cast::make(a.dtype(), input_scale);
              auto negcoef = Cast::make(a.dtype(), alpha) * poscoef;

              return CompareSelect::make(
                  a,
                  zero,
                  a * poscoef,
                  (exp(a * negiptcoef) - one) * negcoef,
                  kGT);
            });
      });
1363 
1364   RegisterNNCLoweringsFunction aten_hardsigmoid(
1365       {"aten::hardsigmoid(Tensor self) -> (Tensor)"},
1366       [](const std::vector<ArgValue>& inputs,
1367          const std::vector<ExprHandle>& outputShape,
1368          const std::vector<ExprHandle>& outputStrides,
1369          const std::optional<ScalarType>& outputType,
1370          at::Device device) {
1371         return computeOneOperand(
1372             "aten_hardsigmoid",
1373             inputs,
1374             outputShape,
1375             outputStrides,
1376             outputType,
1377             [](const ExprHandle& a) {
1378               auto zero = Cast::make(a.dtype(), 0.0);
1379               auto three = Cast::make(a.dtype(), 3.0);
1380               auto six = Cast::make(a.dtype(), 6.0);
1381               return clamp(zero, six, a + three) / six;
1382             });
1383       });
1384 
  // aten::hardswish — x * hardsigmoid(x).
  RegisterNNCLoweringsFunction aten_hardswish(
      {"aten::hardswish(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_hardswish",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              //  x * torch.clamp(x + 3.0, 0.0, 6.0) / 6.0
              auto zero = Cast::make(a.dtype(), 0.);
              auto three = Cast::make(a.dtype(), 3.);
              auto six = Cast::make(a.dtype(), 6.);

              return a * clamp(zero, six, a + three) / six;
            });
      });
1407 
  // aten::hardshrink — zero out values inside the dead zone [-lambd, lambd],
  // pass everything else through unchanged.
  RegisterNNCLoweringsFunction aten_hardshrink(
      {"aten::hardshrink(Tensor self, Scalar lambd=0.5) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeTwoOperand(
            "aten_hardshrink",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a, const ExprHandle& lambd) {
              auto pos_clambd = Cast::make(a.dtype(), lambd);
              // NOTE(review): ExprHandle(-0) is the integer literal 0, so
              // this computes 0 - lambd, i.e. -lambd in the input dtype.
              auto neg_clambd =
                  Cast::make(a.dtype(), ExprHandle(-0)) - pos_clambd;
              auto zero = Cast::make(a.dtype(), 0);
              // mm = (a < -lambd) ? a : 0; result = (a > lambd) ? a : mm.
              auto mm = CompareSelect::make(a, neg_clambd, a, zero, kLT);
              return CompareSelect::make(a, pos_clambd, a, mm, kGT);
            });
      });
1430 
  // aten::sqrt / aten::rsqrt — square root and reciprocal square root, both
  // promoting integral inputs to the default scalar type.
  RegisterNNCLoweringsFunction aten_sqrt(
      {"aten::sqrt(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_sqrt",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              // Qualified to pick the tensorexpr sqrt over ::sqrt.
              return tensorexpr::sqrt(promoteIntegerToDefaultType(a));
            });
      });

  RegisterNNCLoweringsFunction aten_rsqrt(
      {"aten::rsqrt(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_rsqrt",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              return rsqrt(promoteIntegerToDefaultType(a));
            });
      });
1466 
  // aten::abs — absolute value. Half inputs are promoted to float; the
  // explicit dtype mask additionally admits bool inputs.
  RegisterNNCLoweringsFunction aten_abs(
      {"aten::abs(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_abs",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              return tensorexpr::abs(promoteHalfToFloat(a));
            },
            kIntegralTypes | kFloatingPointTypes | kBoolType);
      });
1485 
1486   RegisterNNCLoweringsFunction aten_sign(
1487       {"aten::sign(Tensor self) -> (Tensor)"},
1488       [](const std::vector<ArgValue>& inputs,
1489          const std::vector<ExprHandle>& outputShape,
1490          const std::vector<ExprHandle>& outputStrides,
1491          const std::optional<ScalarType>& outputType,
1492          at::Device device) { return computeSign(inputs, outputShape); });
1493 
1494   RegisterNNCLoweringsFunction aten_ceil(
1495       {"aten::ceil(Tensor self) -> (Tensor)"},
1496       [](const std::vector<ArgValue>& inputs,
1497          const std::vector<ExprHandle>& outputShape,
1498          const std::vector<ExprHandle>& outputStrides,
1499          const std::optional<ScalarType>& outputType,
1500          at::Device device) {
1501         return computeOneOperand(
1502             "aten_ceil",
1503             inputs,
1504             outputShape,
1505             outputStrides,
1506             outputType,
1507             [](const ExprHandle& a) { return ceil(a); });
1508       });
1509 
1510   RegisterNNCLoweringsFunction aten_floor(
1511       {"aten::floor(Tensor self) -> (Tensor)"},
1512       [](const std::vector<ArgValue>& inputs,
1513          const std::vector<ExprHandle>& outputShape,
1514          const std::vector<ExprHandle>& outputStrides,
1515          const std::optional<ScalarType>& outputType,
1516          at::Device device) {
1517         return computeOneOperand(
1518             "aten_floor",
1519             inputs,
1520             outputShape,
1521             outputStrides,
1522             outputType,
1523             [](const ExprHandle& a) { return floor(a); });
1524       });
1525 
1526   RegisterNNCLoweringsFunction aten_round(
1527       {"aten::round(Tensor self) -> (Tensor)"},
1528       [](const std::vector<ArgValue>& inputs,
1529          const std::vector<ExprHandle>& outputShape,
1530          const std::vector<ExprHandle>& outputStrides,
1531          const std::optional<ScalarType>& outputType,
1532          at::Device device) {
1533         return computeOneOperand(
1534             "aten_round",
1535             inputs,
1536             outputShape,
1537             outputStrides,
1538             outputType,
1539             [](const ExprHandle& a) { return round(a); });
1540       });
1541 
1542   RegisterNNCLoweringsFunction aten_trunc(
1543       {"aten::trunc(Tensor self) -> (Tensor)"},
1544       [](const std::vector<ArgValue>& inputs,
1545          const std::vector<ExprHandle>& outputShape,
1546          const std::vector<ExprHandle>& outputStrides,
1547          const std::optional<ScalarType>& outputType,
1548          at::Device device) {
1549         return computeOneOperand(
1550             "aten_trunc",
1551             inputs,
1552             outputShape,
1553             outputStrides,
1554             outputType,
1555             [](const ExprHandle& a) { return trunc(a); });
1556       });
1557 
1558   RegisterNNCLoweringsFunction aten__cast_Float(
1559       {"aten::_cast_Float(Tensor self, bool non_blocking=False) -> (Tensor)"},
1560       [](const std::vector<ArgValue>& inputs,
1561          const std::vector<ExprHandle>& outputShape,
1562          const std::vector<ExprHandle>& outputStrides,
1563          const std::optional<ScalarType>& outputType,
1564          at::Device device) {
1565         return computeOneOperand(
1566             "aten_cast_float",
1567             inputs,
1568             outputShape,
1569             outputStrides,
1570             outputType,
1571             [](const ExprHandle& a) { return cast<float>(a); });
1572       });
1573 
  // aten::to (all overloads) and the autocast precision-change ops all
  // lower to a single elementwise cast to the inferred output dtype.
  RegisterNNCLoweringsFunction aten_to(
      {"aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))",
       "aten::to.dtype_layout(Tensor(a) self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))",
       "aten::to.device(Tensor(a) self, Device device, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))",
       "aten::to.prim_Device(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)",
       "aten::to.prim_dtype(Tensor(a) self, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)",
       "aten::_autocast_to_reduced_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) -> Tensor(a)",
       "aten::_autocast_to_full_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled) -> Tensor(a)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        // see handling of aten::to in tensorexpr_fuser.cpp for why we only
        // need to handle the first input
        return computeOneOperand(
            "aten_to",
            {inputs[0]},
            outputShape,
            outputStrides,
            outputType,
            [outputType](const ExprHandle& a) {
              // The fuser is expected to have resolved the target dtype;
              // a null outputType here is an internal invariant violation.
              TORCH_INTERNAL_ASSERT(
                  outputType, buildErrorMessage("Output type is null."));
              return Cast::make(ToDtype(*outputType), a);
            });
      });
1601 
  // aten::threshold: elementwise select — where a <= threshold the result
  // is `value`, otherwise the input is passed through unchanged.
  RegisterNNCLoweringsFunction aten_threshold(
      {"aten::threshold(Tensor self, Scalar threshold, Scalar value) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeThreeOperand(
            "aten_threshold",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a,
               const ExprHandle& threshold,
               const ExprHandle& value) {
              return ifThenElse(
                  CompareSelect::make(a, threshold, kLE), value, a);
            });
      });

  // aten::where (all overloads): elementwise conditional select. The
  // condition is the first operand; scalar self/other variants are handled
  // uniformly by computeConditionWithTwoOperand.
  RegisterNNCLoweringsFunction aten_where(
      {"aten::where.ScalarOther(Tensor condition, Tensor self, Scalar other) -> (Tensor)",
       "aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> (Tensor)",
       "aten::where.self(Tensor condition, Tensor self, Tensor other) -> (Tensor)",
       "aten::where.Scalar(Tensor condition, Scalar self, Scalar other) -> Tensor"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeConditionWithTwoOperand(
            "aten_where",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a0,
               const ExprHandle& a1,
               const ExprHandle& a2) { return ifThenElse(a0, a1, a2); });
      });
1643 
  // aten::frac: fractional part computed as x - floor(x), after promoting
  // half to float. Restricted to floating-point inputs via the dtype mask.
  RegisterNNCLoweringsFunction aten_frac(
      {"aten::frac(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_frac",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              auto aa = promoteHalfToFloat(a);
              return aa - floor(aa);
            },
            // Only floating-point inputs are supported for this lowering.
            kFloatingPointTypes);
      });

  // aten::lgamma: log-gamma intrinsic; integer inputs are promoted to the
  // default scalar type first.
  RegisterNNCLoweringsFunction aten_lgamma(
      {"aten::lgamma(Tensor self) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeOneOperand(
            "aten_lgamma",
            inputs,
            outputShape,
            outputStrides,
            outputType,
            [](const ExprHandle& a) {
              return lgamma(promoteIntegerToDefaultType(a));
            });
      });
1681 
1682   // TODO: convert to schema, add a test
1683   // RegisterNNCLoweringsFunction aten_rand_like(
1684   //     {"aten::rand_like"},
1685   //     [](const std::vector<ArgValue>& inputs,
1686   //        const std::vector<ExprHandle>& outputShape,
1687   //        const std::optional<ScalarType>& outputType,
1688   //        at::Device device) {
1689   //       return computeOneOperand(
1690   //           "aten_rand_like",
1691   //           inputs,
1692   //           outputShape,
1693   //           outputType,
1694   //           [](const ExprHandle& a) {
1695   //             return Intrinsics::make(IntrinsicsOp::kRand, a.dtype());
1696   //           });
1697   //     });
1698 
1699   // TODO: convert to schema, add a test
1700   // RegisterNNCLoweringsFunction aten_slice(
1701   //     {"aten::slice"},
1702   //     [](const std::vector<ArgValue>& inputs,
1703   //        const std::vector<ExprHandle>& outputShape,
1704   //        const std::optional<ScalarType>& outputType,
1705   //        at::Device device) {
1706   //       return Compute(
1707   //           "aten_slice",
1708   //           outputShape,
1709   //           [&](const std::vector<VarHandle>& axes) {
1710   //             int64_t dim =
1711   //                 at::maybe_wrap_dim(std::get<int64_t>(inputs[1]),
1712   //                 axes.size());
1713   //             ExprHandle start = constant(inputs[2]);
1714   //             ExprHandle stride = constant(inputs[4]);
1715 
1716   //             std::vector<ExprHandle> newAxes(axes.begin(), axes.end());
1717   //             newAxes[dim] = stride * newAxes[dim] + start;
1718   //             return tensorOrConstant(inputs[0], newAxes);
1719   //           });
1720   //     });
  // aten::unsqueeze: the output has one more axis than the input, so the
  // input load is built by dropping the unsqueezed axis from the output's
  // index list (that axis always has extent 1).
  RegisterNNCLoweringsFunction aten_unsqueeze(
      {"aten::unsqueeze(Tensor(a) self, int dim) -> (Tensor(a))"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return Compute(
            "aten_unsqueeze",
            outputShape,
            outputStrides,
            [&](const std::vector<VarHandle>& axes) {
              // inputs[1] is the (possibly negative) dim to insert.
              int64_t dim = std::get<int64_t>(inputs[1]);
              if (dim < 0) {
                if (axes.empty()) {
                  throw malformed_input("axes are zero handling unsqueeze");
                }
                // Normalize negative dims against the OUTPUT rank.
                dim += axes.size();
              }
              // To construct an expression for an 'unsqueezed' tensor we need
              // to drop the DIM-th axis, i.e.
              //    unsqueezed_v[i,j,k,l] = v[i,j,l] # dim = 2 - drop index 'k'
              //                 0 1 2 3
              std::vector<ExprHandle> indices;
              int64_t i = 0;
              for (const auto& a : axes) {
                if (i++ != dim) {
                  indices.emplace_back(a.node());
                }
              }

              return broadcast(std::get<BufHandle>(inputs[0]), indices);
            });
      });
  // aten::t: 2-D transpose, implemented by forwarding to the generic
  // transpose lowering with dims (1, 0) hard-wired in.
  RegisterNNCLoweringsFunction aten_t(
      {"aten::t(Tensor(a) self) -> (Tensor(a))"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeTranspose(
            // Synthesize the (self, dim0=1, dim1=0) argument list that
            // computeTranspose expects for aten::transpose.int.
            {inputs[0], (int64_t)1, (int64_t)0},
            outputShape,
            outputStrides,
            outputType,
            device);
      });
  // aten::transpose.int: general dim-pair transpose.
  RegisterNNCLoweringsFunction aten_transpose(
      {"aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> (Tensor(a))"},
      computeTranspose);
  // aten::permute: reorders axes. A load at output axes `axes` maps back to
  // the input by scattering each output axis into its permuted position.
  // Quantization parameters (qscale/qzero) are propagated to the output buf.
  RegisterNNCLoweringsFunction aten_permute(
      {"aten::permute(Tensor(a) self, int[] dims) -> (Tensor(a))"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        auto A = std::get<BufHandle>(inputs[0]);
        // Trivial case of 0-dim tensors: just a copy of the input
        if (A.ndim() == 0) {
          auto tensor = Compute(
              "aten_permute",
              outputShape,
              outputStrides,
              [&](const std::vector<VarHandle>& axes) {
                std::vector<ExprHandle> empty_indices;
                return A.load(empty_indices);
              });
          if (A.node()->qscale()) {
            tensor.buf()->set_qscale(A.node()->qscale());
            tensor.buf()->set_qzero(A.node()->qzero());
          }
          return tensor;
        }
        auto permute_dims = std::get<IntList>(inputs[1]);
        // NOTE(review): unlike the 0-dim branch above, this Compute call
        // does not forward outputStrides — presumably the default layout is
        // intended here; confirm this asymmetry is deliberate.
        auto tensor = Compute(
            "aten_permute",
            outputShape,
            [&](const std::vector<VarHandle>& axes) {
              std::vector<VarHandle> new_axes;
              new_axes.resize(axes.size());
              assert(permute_dims.size() == axes.size());
              for (unsigned i = 0; i < axes.size(); i++) {
                // Wrap negative dims, then route output axis i to the input
                // axis it came from.
                auto new_dim = at::maybe_wrap_dim(permute_dims[i], A.ndim());
                new_axes[new_dim] = axes[i];
              }
              return A.load(new_axes);
            });
        if (A.node()->qscale()) {
          tensor.buf()->set_qscale(A.node()->qscale());
          tensor.buf()->set_qzero(A.node()->qzero());
        }
        return tensor;
      });
  // Shape-manipulation ops that delegate to shared lowering helpers.
  RegisterNNCLoweringsFunction aten_expand(
      {"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> (Tensor(a))",
       "aten::expand_as(Tensor(a) self, Tensor other) -> (Tensor(a))"},
      computeExpand);

  // TODO: add a test
  RegisterNNCLoweringsFunction aten_flatten(
      {"aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> (Tensor(a))"},
      computeFlatten);
  // reshape/view and their *_as variants all share one reshape lowering.
  RegisterNNCLoweringsFunction aten_view(
      {"aten::reshape(Tensor(a) self, int[] shape) -> (Tensor(a))",
       "aten::reshape_as(Tensor(a) self, Tensor other) -> (Tensor(a))",
       "aten::view(Tensor(a) self, int[] size) -> (Tensor(a))",
       "aten::view_as(Tensor(a) self, Tensor other) -> (Tensor(a))"},
      computeReshape);

  // aten::mm is a subset of aten::matmul where both inputs are rank 2
  RegisterNNCLoweringsFunction aten_matmul(
      {"aten::mm(Tensor self, Tensor mat2) -> (Tensor)",
       "aten::matmul(Tensor self, Tensor other) -> (Tensor)"},
      computeMatmul);

  RegisterNNCLoweringsFunction aten_cat(
      {"aten::cat(Tensor[] tensors, int dim=0) -> (Tensor)"}, computeCat);

  // Full and per-dim reductions share the computeSum lowering.
  RegisterNNCLoweringsFunction aten_sum(
      {"aten::sum(Tensor self, *, int? dtype=None) -> (Tensor)",
       "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)"},
      computeSum);
1845 
  // aten::softmax and aten::log_softmax share one implementation; the final
  // boolean argument to computeSoftmax selects the log variant.
  RegisterNNCLoweringsFunction aten_softmax(
      {"aten::softmax.int(Tensor self, int dim, int? dtype=None) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeSoftmax(inputs, outputShape, outputStrides, false);
      });

  RegisterNNCLoweringsFunction aten_log_softmax(
      {"aten::log_softmax.int(Tensor self, int dim, int? dtype=None) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        return computeSoftmax(inputs, outputShape, outputStrides, true);
      });
1865 
  // Convolution, matmul-with-bias, reductions, and pooling: each delegates
  // to a dedicated compute* helper defined in operators/.
  RegisterNNCLoweringsFunction aten_conv1d(
      {"aten::conv1d(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, int[1] padding=0, int[1] dilation=1, int groups=1) -> (Tensor)"},
      computeConv1d);
  RegisterNNCLoweringsFunction aten_conv2d(
      {"aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=[1, 1], int[2] padding=[0, 0], int[2] dilation=[1, 1], int groups=1) -> (Tensor)"},
      computeConv2d);

  RegisterNNCLoweringsFunction aten_addmm(
      {"aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> (Tensor)"},
      computeAddMM);

  RegisterNNCLoweringsFunction aten_mean(
      {"aten::mean(Tensor self, *, int? dtype=None) -> (Tensor)",
       "aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)"},
      computeMean);
  // max.dim returns both values and indices; computeMax handles the pair.
  RegisterNNCLoweringsFunction aten_max_reduction(
      {"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)"},
      computeMax);

  RegisterNNCLoweringsFunction aten_adaptive_avg_pool2d(
      {"aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> (Tensor)"},
      computeAdaptiveAvgPool2d);
1888 
  // aten::add (Tensor/Scalar variants): elementwise sum with bool operands
  // promoted to integer. When a third input (alpha) is present the
  // alpha-scaling helper is used; otherwise the plain two-operand form.
  RegisterNNCLoweringsFunction aten_add(
      {"aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)",
       "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> (Tensor)"},
      [](const std::vector<ArgValue>& inputs,
         const std::vector<ExprHandle>& outputShape,
         const std::vector<ExprHandle>& outputStrides,
         const std::optional<ScalarType>& outputType,
         at::Device device) {
        auto add_lambda = [](const ExprHandle& lhs, const ExprHandle& rhs) {
          return boolToInteger(lhs) + boolToInteger(rhs);
        };
        // Either (self, other) or (self, other, alpha).
        TORCH_INTERNAL_ASSERT(
            inputs.size() == 2 || inputs.size() == 3,
            buildErrorMessage("Invalid number of input operands"));
        return (inputs.size() > 2) ? computeTwoOperandWithAlpha(
                                         "aten_add",
                                         inputs,
                                         outputShape,
                                         outputStrides,
                                         outputType,
                                         add_lambda)
                                   : computeTwoOperand(
                                         "aten_add",
                                         inputs,
                                         outputShape,
                                         outputStrides,
                                         outputType,
                                         add_lambda);
      });
  RegisterNNCLoweringsFunction aten_embedding(
      {"aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor"},
      computeEmbedding);
1921 
// Compile-time switches selecting between a native tensorexpr lowering (1)
// and an external-call fallback (0) for quantize/dequantize.
#define NNC_QUANTIZATION_EXPR_QUANT 1
#define NNC_QUANTIZATION_EXPR_DEQUANT 1

  RegisterNNCLoweringsFunction aten_quantize_per_tensor(
      {"aten::quantize_per_tensor(Tensor self, float scale, int zero_point, int dtype) -> (Tensor)",
       "aten::quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, int dtype) -> (Tensor)",
       "aten::quantize_per_tensor.tensors(Tensor[] tensors, Tensor scales, Tensor zero_points, int dtype) -> (Tensor[])"},
#if NNC_QUANTIZATION_EXPR_QUANT == 1
      computeQuantizePerTensor
#else
      computeQuantizePerTensorExternalCall
#endif
  );

  RegisterNNCLoweringsFunction aten_dequantize(
      {"aten::dequantize.self(Tensor self) -> (Tensor)"},
#if NNC_QUANTIZATION_EXPR_DEQUANT == 1
      computeDequantize
#else
      computeDequantizeExternalCall
#endif
  );
  // Quantized operator lowerings: each maps a quantized:: schema onto its
  // dedicated compute* helper (prepacked weights arrive as custom classes).
  RegisterNNCLoweringsFunction quantized_conv1d(
      {"quantized::conv1d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> (Tensor)"},
      computeQuantizedConv1d);

  RegisterNNCLoweringsFunction quantized_conv2d(
      {"quantized::conv2d.new(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> (Tensor)"},
      computeQuantizedConv2d);

  RegisterNNCLoweringsFunction quantized_conv2d_relu(
      {"quantized::conv2d_relu.new(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> (Tensor)"},
      computeQuantizedConv2dRelu);

  RegisterNNCLoweringsFunction quantized_linear(
      {"quantized::linear(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> (Tensor Y)"},
      computeQuantizedLinear);

  // NOTE(review): linear_relu reuses computeQuantizedLinear (there is no
  // *Relu variant here, unlike conv2d_relu) — confirm the relu is applied
  // inside that helper or fused elsewhere.
  RegisterNNCLoweringsFunction quantized_linear_relu(
      {"quantized::linear_relu(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> (Tensor Y)"},
      computeQuantizedLinear);

  RegisterNNCLoweringsFunction quantized_add(
      {"quantized::add(Tensor qa, Tensor qb, float scale, int zero_point) -> (Tensor qc)"},
      computeQuantizedAdd);

  RegisterNNCLoweringsFunction quantized_mul(
      {"quantized::mul(Tensor qa, Tensor qb, float scale, int zero_point) -> (Tensor qc)"},
      computeQuantizedMul);

  RegisterNNCLoweringsFunction quantized_mul_scalar(
      {"quantized::mul.Scalar(Tensor qa, Scalar b) -> (Tensor qc)"},
      computeQuantizedMulScalar);

  RegisterNNCLoweringsFunction quantized_conv2d_prepack(
      {"quantized::conv2d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> (__torch__.torch.classes.quantized.Conv2dPackedParamsBase)"},
      computeQuantizedConv2dPrepack);

  RegisterNNCLoweringsFunction quantized_cat(
      {"quantized::cat(Tensor[] qx, int dim, float? scale, int? zero_point) -> (Tensor)"},
      computeQuantizedCat);

  RegisterNNCLoweringsFunction aten_upsample_nearest2d(
      {"aten::upsample_nearest2d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)"},
      computeUpsampleNearest2dExternalCall);
1987 
1988   return 0;
1989 }
1990 } // namespace
1991 
getStandardLoweringFor(const std::string & schema_str)1992 NNCLoweringFunction getStandardLoweringFor(const std::string& schema_str) {
1993   C10_UNUSED static const int once = nnc_lowerings_lazy_registration();
1994   const auto& lowerings = getNNCLoweringRegistry();
1995   if (auto l = lowerings.find(parseSchema(schema_str))) {
1996     return *l;
1997   }
1998   return nullptr;
1999 }
2000 
2001 } // namespace torch::jit::tensorexpr
2002