xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/expr.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/tensorexpr/expr.h>
2 
3 #include <torch/csrc/jit/tensorexpr/ir.h>
4 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
5 
6 namespace torch::jit::tensorexpr {
7 
// Arithmetic operators. Each builds the matching binary IR node
// (Add / Sub / Mul / Div / Mod) with `*this` on the left and `other` on the
// right.
ExprHandle ExprHandle::operator+(const ExprHandle& other) const {
  const ExprHandle& lhs = *this;
  return Add::make(lhs, other);
}

ExprHandle ExprHandle::operator-(const ExprHandle& other) const {
  const ExprHandle& lhs = *this;
  return Sub::make(lhs, other);
}

ExprHandle ExprHandle::operator*(const ExprHandle& other) const {
  const ExprHandle& lhs = *this;
  return Mul::make(lhs, other);
}

ExprHandle ExprHandle::operator/(const ExprHandle& other) const {
  const ExprHandle& lhs = *this;
  return Div::make(lhs, other);
}

ExprHandle ExprHandle::operator%(const ExprHandle& other) const {
  const ExprHandle& lhs = *this;
  return Mod::make(lhs, other);
}
27 
// Relational operators. Each wraps `*this` and `other` in a CompareSelect
// node tagged with the corresponding comparison kind.
ExprHandle ExprHandle::operator==(const ExprHandle& other) const {
  const ExprHandle& lhs = *this;
  return CompareSelect::make(lhs, other, kEQ);
}

ExprHandle ExprHandle::operator!=(const ExprHandle& other) const {
  const ExprHandle& lhs = *this;
  return CompareSelect::make(lhs, other, kNE);
}

ExprHandle ExprHandle::operator>(const ExprHandle& other) const {
  const ExprHandle& lhs = *this;
  return CompareSelect::make(lhs, other, kGT);
}

ExprHandle ExprHandle::operator>=(const ExprHandle& other) const {
  const ExprHandle& lhs = *this;
  return CompareSelect::make(lhs, other, kGE);
}

ExprHandle ExprHandle::operator<(const ExprHandle& other) const {
  const ExprHandle& lhs = *this;
  return CompareSelect::make(lhs, other, kLT);
}

ExprHandle ExprHandle::operator<=(const ExprHandle& other) const {
  const ExprHandle& lhs = *this;
  return CompareSelect::make(lhs, other, kLE);
}
51 
// Logical AND, lowered to a select: `*this ? other : 0`, where the zero is an
// immediate of `other`'s dtype. Throws unsupported_dtype if `*this` is not
// integral. Only the left operand's dtype is checked.
ExprHandle ExprHandle::operator&&(const ExprHandle& other) const {
  if (!this->dtype().is_integral()) {
    throw unsupported_dtype();
  }
  const ExprHandle falsy(getImmediateByType(other.dtype(), 0));
  return IfThenElse::make(*this, other, falsy);
}

// Logical OR, lowered to a select: `*this ? 1 : other`, where the one is an
// immediate of `other`'s dtype. Throws unsupported_dtype if `*this` is not
// integral. Only the left operand's dtype is checked.
ExprHandle ExprHandle::operator||(const ExprHandle& other) const {
  if (!this->dtype().is_integral()) {
    throw unsupported_dtype();
  }
  const ExprHandle truthy(getImmediateByType(other.dtype(), 1));
  return IfThenElse::make(*this, truthy, other);
}
67 
// Bitwise and shift operators. Each builds the matching binary IR node
// (And / Or / Xor / Lshift / Rshift) with `*this` as the left operand.
ExprHandle ExprHandle::operator&(const ExprHandle& other) const {
  const ExprHandle& lhs = *this;
  return And::make(lhs, other);
}

ExprHandle ExprHandle::operator|(const ExprHandle& other) const {
  const ExprHandle& lhs = *this;
  return Or::make(lhs, other);
}

ExprHandle ExprHandle::operator^(const ExprHandle& other) const {
  const ExprHandle& lhs = *this;
  return Xor::make(lhs, other);
}

ExprHandle ExprHandle::operator<<(const ExprHandle& other) const {
  const ExprHandle& lhs = *this;
  return Lshift::make(lhs, other);
}

ExprHandle ExprHandle::operator>>(const ExprHandle& other) const {
  const ExprHandle& lhs = *this;
  return Rshift::make(lhs, other);
}
87 
88 #define IMM_EXPR_DECLARE(Type, Name) \
89   ExprHandle::ExprHandle(Type v) : ExprHandle(Name##Imm::make(v)) {}
90 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_EXPR_DECLARE);
91 #undef IMM_EXPR_DECLARE
92 
// Element-wise unary math functions. Each wraps its operand in an Intrinsics
// node tagged with the corresponding op.
ExprHandle sin(const ExprHandle& v) { return Intrinsics::make(kSin, v); }
ExprHandle cos(const ExprHandle& v) { return Intrinsics::make(kCos, v); }
ExprHandle tan(const ExprHandle& v) { return Intrinsics::make(kTan, v); }

ExprHandle asin(const ExprHandle& v) { return Intrinsics::make(kAsin, v); }
ExprHandle acos(const ExprHandle& v) { return Intrinsics::make(kAcos, v); }
ExprHandle atan(const ExprHandle& v) { return Intrinsics::make(kAtan, v); }

ExprHandle sinh(const ExprHandle& v) { return Intrinsics::make(kSinh, v); }
ExprHandle cosh(const ExprHandle& v) { return Intrinsics::make(kCosh, v); }
ExprHandle tanh(const ExprHandle& v) { return Intrinsics::make(kTanh, v); }

ExprHandle sigmoid(const ExprHandle& v) {
  return Intrinsics::make(kSigmoid, v);
}

ExprHandle exp(const ExprHandle& v) { return Intrinsics::make(kExp, v); }
ExprHandle expm1(const ExprHandle& v) { return Intrinsics::make(kExpm1, v); }
ExprHandle abs(const ExprHandle& v) { return Intrinsics::make(kAbs, v); }
144 
145 // The default tanh is quite slow, use the Eigen version from here:
146 // https://bitbucket.org/eigen/eigen/src/94875feeeeb9abe5509b314197da1991ba2070f5/Eigen/src/Core/MathFunctionsImpl.h#lines-26
fast_tanh(const ExprHandle & v)147 ExprHandle fast_tanh(const ExprHandle& v) {
148   // TODO: use a dedicated bind-var to make sure v is not evaluated multiple
149   // times. Clamp the input expression to [-9, 9]
150   ExprHandle plus_9 = FloatImm::make(9.0f);
151   ExprHandle minus_9 = FloatImm::make(-9.0f);
152   ExprHandle v1 = Min::make(v, plus_9, false);
153   v1 = Max::make(v1, minus_9, false);
154 
155   // The coefficients for the numerator
156   ExprHandle alpha_1 = FloatImm::make(4.89352455891786e-03f);
157   ExprHandle alpha_3 = FloatImm::make(6.37261928875436e-04f);
158   ExprHandle alpha_5 = FloatImm::make(1.48572235717979e-05f);
159   ExprHandle alpha_7 = FloatImm::make(5.12229709037114e-08f);
160   ExprHandle alpha_9 = FloatImm::make(-8.60467152213735e-11f);
161   ExprHandle alpha_11 = FloatImm::make(2.00018790482477e-13f);
162   ExprHandle alpha_13 = FloatImm::make(-2.76076847742355e-16f);
163 
164   // The coefficients for the denominator
165   ExprHandle beta_0 = FloatImm::make(4.89352518554385e-03f);
166   ExprHandle beta_2 = FloatImm::make(2.26843463243900e-03f);
167   ExprHandle beta_4 = FloatImm::make(1.18534705686654e-04f);
168   ExprHandle beta_6 = FloatImm::make(1.19825839466702e-06f);
169 
170   // numerator
171   ExprHandle v2 = v1 * v1;
172   ExprHandle p = v2 * alpha_13 + alpha_11;
173   p = v2 * p + alpha_9;
174   p = v2 * p + alpha_7;
175   p = v2 * p + alpha_5;
176   p = v2 * p + alpha_3;
177   p = v2 * p + alpha_1;
178   p = v1 * p;
179 
180   // denominator
181   ExprHandle q = v2 * beta_6 + beta_4;
182   q = v2 * q + beta_2;
183   q = v2 * q + beta_0;
184 
185   ExprHandle result = p / q;
186   return result;
187 }
188 
// sigmoid(x) = (tanh(x / 2) + 1) / 2, with tanh computed by fast_tanh.
// fast_tanh is not precise, but clients rely on sigmoid returning
// probability-like values, so the result is clamped into [0, 1].
ExprHandle fast_sigmoid(const ExprHandle& x) {
  ExprHandle one = FloatImm::make(1.f);
  ExprHandle half = FloatImm::make(0.5f);
  ExprHandle zero = FloatImm::make(0.0f);

  ExprHandle tanh_half_x{fast_tanh(x * half)};
  ExprHandle unclamped = (tanh_half_x + one) * half;

  ExprHandle at_least_zero =
      Max::make(zero, unclamped, /* propagate_nans= */ false);
  return Min::make(one, at_least_zero, /* propagate_nans= */ false);
}
205 
// Fast natural logarithm for float inputs. This implementation is taken from
// sleef:
//   https://github.com/shibatch/sleef/blob/master/src/libm/sleefsp.c#L1131
// To generate coefficients, this tool is provided:
//   https://github.com/shibatch/sleef/blob/master/src/gencoef/gencoef.txt
//
// Special values follow std::log:
//   v < 0 or NaN -> NaN,  v == +-0 -> -inf,  v == +inf -> +inf.
// Without the explicit +inf / NaN selects, the bit-level range reduction
// below maps +inf to 128 * ln 2 (~88.72) and NaN to a finite value.
// NOTE(review): denormal inputs are not pre-scaled before the exponent is
// extracted, so results for 0 < v < FLT_MIN are inaccurate.
ExprHandle fast_log(const ExprHandle& v) {
  // Unbiased binary exponent of x, read straight from the IEEE-754 bits.
  auto ilogb2kf = [](const ExprHandle& x) {
    auto y = (bitcast<int32_t>(x) >> IntImm::make(23)) & IntImm::make(0xff);
    return y - IntImm::make(0x7f);
  };

  // x * 2^e, implemented by adding e to the exponent field.
  auto ldexp3kf = [](const ExprHandle& x, const ExprHandle& e) {
    return bitcast<float>(bitcast<int32_t>(x) + (e << IntImm::make(23)));
  };

  // Range reduction: v = m * 2^e, with m in [0.75, 1.5) for normal inputs.
  auto e = ilogb2kf(v * FloatImm::make(1.0 / 0.75));
  auto m = ldexp3kf(v, IntImm::make(-1) * e);
  auto one = FloatImm::make(1.0f);
  // log(m) = 2 * atanh(x) with x = (m - 1) / (m + 1); the series is in x^2.
  auto x = (m - one) / (m + one);
  auto x2 = x * x;

  auto mlaf = [](const ExprHandle& x, const ExprHandle& y, float z) {
    return x * y + FloatImm::make(z);
  };

  auto t = FloatImm::make(0.2392828464508056640625);
  t = mlaf(t, x2, 0.28518211841583251953125);
  t = mlaf(t, x2, 0.400005877017974853515625);
  t = mlaf(t, x2, 0.666666686534881591796875);
  t = mlaf(t, x2, 2.0);
  // log(v) = x * t(x^2) + e * ln 2
  x = x * t + FloatImm::make(0.693147180559945286226764) * e;

  auto zero = FloatImm::make(0);
  auto nan = FloatImm::make(std::numeric_limits<float>::quiet_NaN());
  auto pos_inf = FloatImm::make(std::numeric_limits<float>::infinity());
  auto neg_inf = FloatImm::make(-std::numeric_limits<float>::infinity());
  // log(+inf) = +inf.
  x = CompareSelect::make(v, pos_inf, pos_inf, x, kEQ);
  // log(v) = NaN for negative v.
  x = CompareSelect::make(v, zero, nan, x, kLT);
  // log(NaN) = NaN: `v == v` is false only when v is NaN.
  x = CompareSelect::make(v, v, x, nan, kEQ);
  // log(+-0) = -inf.
  x = CompareSelect::make(v, zero, neg_inf, x, kEQ);
  return x;
}
243 
// Natural logarithm via float-bit range reduction plus a degree-10
// polynomial. Inputs the reduction cannot handle fall back to the exact
// `log` intrinsic (see the final select).
ExprHandle log_vml(const ExprHandle& v) {
  // Subtract the bit pattern of ~2/3 (0x3f2aaaab) from the raw float bits.
  // The arithmetic shift of the difference gives the exponent e. Re-adding
  // 0x3f2aaaab to its low 23 bits rebuilds a mantissa in [2/3, 4/3), so
  // v = 2^e * (1 + r).
  auto bits = bitcast<int32_t>(v);
  auto reduced = bits - IntImm::make(0x3f2aaaab);
  auto exponent = cast<float>(reduced >> IntImm::make(23));

  auto mantissa_bits =
      (reduced & IntImm::make(0x7fffff)) + IntImm::make(0x3f2aaaab);
  auto r = bitcast<float>(mantissa_bits) - 1.0f;

  // Horner evaluation of P(r), highest-order coefficient first; each step
  // forms r * poly + c.
  static constexpr float kPolyTail[] = {
      0.139844373f,
      -0.121842608f,
      0.140058696f,
      -0.16680488f,
      0.200104058f,
      -0.249997973f,
      0.333332151f,
      -0.5f};
  ExprHandle poly = FloatImm::make(-0.12891686f);
  for (const float c : kPolyTail) {
    poly = r * poly + FloatImm::make(c);
  }
  poly = r * poly;
  poly = r * poly + r; // log(1 + r) ~= r + r^2 * P(r)

  // Add e * ln 2, with ln 2 split into a low correction (1.42860677e-06f)
  // and a high part (0.693145752f).
  auto z = exponent * FloatImm::make(1.42860677e-06f) + poly;
  z = exponent * FloatImm::make(0.693145752f) + z;

  // Use the exact intrinsic when bits + 0x800000 < 0x1000000. That covers bit
  // patterns below FLT_MIN's (zero, denormals) and negative inputs. Assuming
  // int32 wraparound of the add, it also covers inf/NaN. kUnlikely marks this
  // as the rare path.
  return CompareSelect::make(
      IntImm::make(0x1000000),
      bits + IntImm::make(0x800000),
      log(v),
      z,
      kGT,
      kUnlikely);
}
279 
// Logarithms, error functions, roots, rounding and lgamma. Each wraps its
// operand in an Intrinsics node tagged with the corresponding op.
ExprHandle log(const ExprHandle& v) { return Intrinsics::make(kLog, v); }
ExprHandle log2(const ExprHandle& v) { return Intrinsics::make(kLog2, v); }
ExprHandle log10(const ExprHandle& v) { return Intrinsics::make(kLog10, v); }
ExprHandle log1p(const ExprHandle& v) { return Intrinsics::make(kLog1p, v); }

ExprHandle erf(const ExprHandle& v) { return Intrinsics::make(kErf, v); }
ExprHandle erfc(const ExprHandle& v) { return Intrinsics::make(kErfc, v); }

ExprHandle sqrt(const ExprHandle& v) { return Intrinsics::make(kSqrt, v); }
ExprHandle rsqrt(const ExprHandle& v) { return Intrinsics::make(kRsqrt, v); }

ExprHandle ceil(const ExprHandle& v) { return Intrinsics::make(kCeil, v); }
ExprHandle floor(const ExprHandle& v) { return Intrinsics::make(kFloor, v); }
ExprHandle round(const ExprHandle& v) { return Intrinsics::make(kRound, v); }
ExprHandle trunc(const ExprHandle& v) { return Intrinsics::make(kTrunc, v); }
ExprHandle frac(const ExprHandle& v) { return Intrinsics::make(kFrac, v); }

ExprHandle lgamma(const ExprHandle& v) {
  return Intrinsics::make(kLgamma, v);
}
335 
// Two-operand element-wise math functions. Each builds an Intrinsics node
// over (v1, v2) in that order.
ExprHandle atan2(const ExprHandle& v1, const ExprHandle& v2) {
  return Intrinsics::make(kAtan2, v1, v2);
}
ExprHandle pow(const ExprHandle& v1, const ExprHandle& v2) {
  return Intrinsics::make(kPow, v1, v2);
}
ExprHandle fmod(const ExprHandle& v1, const ExprHandle& v2) {
  return Intrinsics::make(kFmod, v1, v2);
}
ExprHandle remainder(const ExprHandle& v1, const ExprHandle& v2) {
  return Intrinsics::make(kRemainder, v1, v2);
}
351 
// NaN test, expressed as the kIsNan intrinsic on a single operand.
ExprHandle isnan(const ExprHandle& v1) { return Intrinsics::make(kIsNan, v1); }

// Expression-level conditional: builds an IfThenElse node that yields `t`
// when `c` holds and `f` otherwise.
ExprHandle ifThenElse(
    const ExprHandle& c,
    const ExprHandle& t,
    const ExprHandle& f) {
  return IfThenElse::make(c, t, f);
}
362 
make_contiguous_strides(const std::vector<ExprHandle> & dims)363 std::vector<ExprPtr> make_contiguous_strides(
364     const std::vector<ExprHandle>& dims) {
365   std::vector<ExprPtr> strides;
366 
367   if (!dims.empty()) {
368     strides.resize(dims.size());
369     auto si = immLike(dims[0], 1);
370     for (int64_t i = dims.size() - 1; i >= 0; --i) {
371       strides[i] = si;
372       si = alloc<Mul>(si, dims[i].node());
373     }
374   }
375   return strides;
376 }
377 
make_channels_last_strides(const std::vector<ExprHandle> & dims)378 std::vector<ExprPtr> make_channels_last_strides(
379     const std::vector<ExprHandle>& dims) {
380   std::vector<ExprPtr> strides;
381   TORCH_INTERNAL_ASSERT(
382       dims.size() == 4 || dims.size() == 3, "got size:", dims.size());
383   if (dims.size() == 4) {
384     strides.resize(dims.size());
385     ExprHandle handle = ExprHandle(immLike(dims[0], 1));
386     // dims:               n   c    h  w
387     // strides(nhwc):  w*c*h   1  w*c  c
388     strides[1] = handle.node();
389     handle = handle * dims[1];
390     strides[3] = handle.node();
391     handle = handle * dims[3];
392     strides[2] = handle.node();
393     handle = handle * dims[2];
394     strides[0] = handle.node();
395   }
396   if (dims.size() == 3) {
397     strides.resize(dims.size());
398     ExprHandle handle = ExprHandle(immLike(dims[0], 1));
399     // dims:              n   c    l
400     // strides(nlc):    c*l   1    c
401     strides[1] = handle.node();
402     handle = handle * dims[1];
403     strides[2] = handle.node();
404     handle = handle * dims[2];
405     strides[0] = handle.node();
406   }
407   return strides;
408 }
409 
Buf(const VarPtr & var,std::vector<ExprPtr> dims,Dtype dtype,ExprPtr initializer,std::optional<std::vector<ExprPtr>> strides,ExprPtr qscale,ExprPtr qzero)410 Buf::Buf(
411     const VarPtr& var,
412     std::vector<ExprPtr> dims,
413     Dtype dtype,
414     ExprPtr initializer,
415     std::optional<std::vector<ExprPtr>> strides,
416     ExprPtr qscale,
417     ExprPtr qzero)
418     : ExprNodeBase(dtype, kPrimitive),
419       base_handle_(var),
420       dims_(std::move(dims)),
421       strides_(
422           strides
423               ? *strides
424               : make_contiguous_strides(ExprVectorToExprHandleVector(dims_))),
425       initializer_(std::move(initializer)),
426       qscale_(std::move(qscale)),
427       qzero_(std::move(qzero)) {
428   TORCH_CHECK(var);
429 }
430 
make(const std::vector<ExprHandle> & dims,Dtype dtype)431 BufHandle Buf::make(const std::vector<ExprHandle>& dims, Dtype dtype) {
432   return Buf::make("", dims, dtype);
433 }
434 
make(const std::string & name_hint,const std::vector<ExprHandle> & dims,const std::vector<ExprHandle> & strides,Dtype dtype)435 BufHandle Buf::make(
436     const std::string& name_hint,
437     const std::vector<ExprHandle>& dims,
438     const std::vector<ExprHandle>& strides,
439     Dtype dtype) {
440   return BufHandle(alloc<Buf>(
441       name_hint,
442       ExprHandleVectorToExprVector(dims),
443       dtype,
444       nullptr,
445       ExprHandleVectorToExprVector(strides)));
446 }
447 
make(const std::string & name_hint,const std::vector<ExprHandle> & dims,Dtype dtype,std::optional<ExprHandle> initializer,const std::optional<std::vector<ExprHandle>> & strides,std::optional<ExprHandle> qscale,std::optional<ExprHandle> qzero)448 BufHandle Buf::make(
449     const std::string& name_hint,
450     const std::vector<ExprHandle>& dims,
451     Dtype dtype,
452     std::optional<ExprHandle> initializer,
453     const std::optional<std::vector<ExprHandle>>& strides,
454     std::optional<ExprHandle> qscale,
455     std::optional<ExprHandle> qzero) {
456   std::optional<std::vector<ExprPtr>> opt_strides;
457   if (strides) {
458     opt_strides = ExprHandleVectorToExprVector(*strides);
459   }
460   return BufHandle(alloc<Buf>(
461       name_hint,
462       ExprHandleVectorToExprVector(dims),
463       dtype,
464       initializer ? initializer->node() : nullptr,
465       opt_strides,
466       qscale ? qscale->node() : nullptr,
467       qzero ? qzero->node() : nullptr));
468 }
469 
is_contiguous(at::MemoryFormat memory_format) const470 bool Buf::is_contiguous(at::MemoryFormat memory_format) const {
471   auto ndims = dims_.size();
472   std::vector<int64_t> dim_order(ndims);
473   if (memory_format == at::MemoryFormat::ChannelsLast) {
474     if (dims_.size() != 4)
475       return false;
476     dim_order = {1, 3, 2, 0};
477   } else if (memory_format == at::MemoryFormat::ChannelsLast3d) {
478     if (dims_.size() != 5)
479       return false;
480     dim_order = {1, 4, 3, 2, 0};
481   } else {
482     if (dims_.empty()) {
483       // Scalar tensor
484       TORCH_CHECK(strides_.empty());
485       return true; // Align with the isContiguous logic in the kernel.cpp
486     }
487     for (size_t i = 0; i < ndims; i++) {
488       dim_order[i] = ndims - i - 1; // Reverse
489     }
490   }
491 
492   bool res = is_stride_one(dim_order[0]);
493   if (!res)
494     return false;
495 
496   for (size_t i = 1; i < ndims; i++) {
497     auto cur_dim = dim_order[i];
498     auto pre_dim = dim_order[i - 1];
499     res &= is_cont_with(cur_dim, pre_dim);
500     if (!res)
501       return false;
502   }
503 
504   return true;
505 }
506 
dims() const507 std::vector<ExprHandle> BufHandle::dims() const {
508   return ExprVectorToExprHandleVector(node()->dims());
509 }
510 
is_cont_with(int cur_dim,int adjacent_dim) const511 bool Buf::is_cont_with(int cur_dim, int adjacent_dim) const {
512   auto is_cont_fn = [](const ExprPtr& adjacent_dim,
513                        const ExprPtr& adjacent_stride,
514                        const ExprPtr& cur_stride) {
515     // For static shape
516     bool res = exprEquals(
517         cur_stride,
518         (ExprHandle(adjacent_dim) * ExprHandle(adjacent_stride)).node());
519     if (res)
520       return res;
521 
522     // For symbolic shape
523     auto mul_node = to<Mul>(cur_stride);
524     if (!mul_node) {
525       return false;
526     }
527 
528     // lhs and rhs could be other dim or stride
529     auto lhs_ = mul_node->lhs();
530     auto rhs_ = mul_node->rhs();
531 
532     bool same_stride = false;
533     auto same_dim = exprEquals(lhs_, adjacent_dim) || (adjacent_dim == lhs_);
534     if (same_dim) {
535       // lhs_ is dim while rhs_ is stride
536       same_stride =
537           exprEquals(rhs_, adjacent_stride) || (adjacent_stride == rhs_);
538     } else {
539       // lhs_ is stride while rhs_ is dim
540       same_dim = exprEquals(rhs_, adjacent_dim) || (adjacent_dim == rhs_);
541       same_stride =
542           exprEquals(lhs_, adjacent_stride) || (adjacent_stride == lhs_);
543     }
544 
545     return same_dim && same_stride;
546   };
547   return is_cont_fn(
548       dims_[adjacent_dim], strides_[adjacent_dim], strides_[cur_dim]);
549 }
550 
is_stride_one(int cur_dim) const551 bool Buf::is_stride_one(int cur_dim) const {
552   return exprEquals(strides_[cur_dim], alloc<LongImm>(1));
553 }
554 
expr_to_vec(ExprHandle v,int lanes)555 ExprHandle expr_to_vec(ExprHandle v, int lanes) {
556   if (lanes == 1) {
557     return v;
558   } else {
559     return Broadcast::make(v, lanes);
560   }
561 }
562 
563 } // namespace torch::jit::tensorexpr
564