1 #include <torch/csrc/jit/tensorexpr/expr.h>
2
3 #include <torch/csrc/jit/tensorexpr/ir.h>
4 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
5
6 namespace torch::jit::tensorexpr {
7
// Builds an Add node with *this as the left operand and `other` as the right.
ExprHandle ExprHandle::operator+(const ExprHandle& other) const {
  const ExprHandle& lhs = *this;
  return Add::make(lhs, other);
}
11
// Builds a Sub node with *this as the left operand and `other` as the right.
ExprHandle ExprHandle::operator-(const ExprHandle& other) const {
  const ExprHandle& lhs = *this;
  return Sub::make(lhs, other);
}
15
// Builds a Mul node with *this as the left operand and `other` as the right.
ExprHandle ExprHandle::operator*(const ExprHandle& other) const {
  const ExprHandle& lhs = *this;
  return Mul::make(lhs, other);
}
19
// Builds a Div node with *this as the left operand and `other` as the right.
ExprHandle ExprHandle::operator/(const ExprHandle& other) const {
  const ExprHandle& lhs = *this;
  return Div::make(lhs, other);
}
23
// Builds a Mod node with *this as the left operand and `other` as the right.
ExprHandle ExprHandle::operator%(const ExprHandle& other) const {
  const ExprHandle& lhs = *this;
  return Mod::make(lhs, other);
}
27
// Builds a CompareSelect node comparing *this with `other` using kEQ.
ExprHandle ExprHandle::operator==(const ExprHandle& other) const {
  const auto cmp_op = CompareSelectOperation::kEQ;
  return CompareSelect::make(*this, other, cmp_op);
}
31
// Builds a CompareSelect node comparing *this with `other` using kNE.
ExprHandle ExprHandle::operator!=(const ExprHandle& other) const {
  const auto cmp_op = CompareSelectOperation::kNE;
  return CompareSelect::make(*this, other, cmp_op);
}
35
// Builds a CompareSelect node comparing *this with `other` using kGT.
ExprHandle ExprHandle::operator>(const ExprHandle& other) const {
  const auto cmp_op = CompareSelectOperation::kGT;
  return CompareSelect::make(*this, other, cmp_op);
}
39
// Builds a CompareSelect node comparing *this with `other` using kGE.
ExprHandle ExprHandle::operator>=(const ExprHandle& other) const {
  const auto cmp_op = CompareSelectOperation::kGE;
  return CompareSelect::make(*this, other, cmp_op);
}
43
// Builds a CompareSelect node comparing *this with `other` using kLT.
ExprHandle ExprHandle::operator<(const ExprHandle& other) const {
  const auto cmp_op = CompareSelectOperation::kLT;
  return CompareSelect::make(*this, other, cmp_op);
}
47
// Builds a CompareSelect node comparing *this with `other` using kLE.
ExprHandle ExprHandle::operator<=(const ExprHandle& other) const {
  const auto cmp_op = CompareSelectOperation::kLE;
  return CompareSelect::make(*this, other, cmp_op);
}
51
// Logical AND expressed as IfThenElse(*this, other, 0).
//
// Throws unsupported_dtype when the dtype of *this is not integral. The zero
// immediate used for the "false" branch is created with `other`'s dtype.
ExprHandle ExprHandle::operator&&(const ExprHandle& other) const {
  const bool lhs_is_integral = this->node()->dtype().is_integral();
  if (!lhs_is_integral) {
    throw unsupported_dtype();
  }
  ExprHandle false_value(getImmediateByType(other.dtype(), 0));
  return IfThenElse::make(*this, other, false_value);
}
59
// Logical OR expressed as IfThenElse(*this, 1, other).
//
// Throws unsupported_dtype when the dtype of *this is not integral. The one
// immediate used for the "true" branch is created with `other`'s dtype.
ExprHandle ExprHandle::operator||(const ExprHandle& other) const {
  const bool lhs_is_integral = this->node()->dtype().is_integral();
  if (!lhs_is_integral) {
    throw unsupported_dtype();
  }
  ExprHandle true_value(getImmediateByType(other.dtype(), 1));
  return IfThenElse::make(*this, true_value, other);
}
67
// Builds an And node with *this as the left operand and `other` as the right.
ExprHandle ExprHandle::operator&(const ExprHandle& other) const {
  const ExprHandle& lhs = *this;
  return And::make(lhs, other);
}
71
// Builds an Or node with *this as the left operand and `other` as the right.
ExprHandle ExprHandle::operator|(const ExprHandle& other) const {
  const ExprHandle& lhs = *this;
  return Or::make(lhs, other);
}
75
// Builds a Xor node with *this as the left operand and `other` as the right.
ExprHandle ExprHandle::operator^(const ExprHandle& other) const {
  const ExprHandle& lhs = *this;
  return Xor::make(lhs, other);
}
79
// Builds an Lshift node shifting *this left by `other`.
ExprHandle ExprHandle::operator<<(const ExprHandle& other) const {
  const ExprHandle& lhs = *this;
  return Lshift::make(lhs, other);
}
83
// Builds an Rshift node shifting *this right by `other`.
ExprHandle ExprHandle::operator>>(const ExprHandle& other) const {
  const ExprHandle& lhs = *this;
  return Rshift::make(lhs, other);
}
87
88 #define IMM_EXPR_DECLARE(Type, Name) \
89 ExprHandle::ExprHandle(Type v) : ExprHandle(Name##Imm::make(v)) {}
90 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_EXPR_DECLARE);
91 #undef IMM_EXPR_DECLARE
92
// Builds the kSin intrinsic applied to `v`.
ExprHandle sin(const ExprHandle& v) {
  ExprHandle call = Intrinsics::make(kSin, v);
  return call;
}
96
// Builds the kCos intrinsic applied to `v`.
ExprHandle cos(const ExprHandle& v) {
  ExprHandle call = Intrinsics::make(kCos, v);
  return call;
}
100
// Builds the kTan intrinsic applied to `v`.
ExprHandle tan(const ExprHandle& v) {
  ExprHandle call = Intrinsics::make(kTan, v);
  return call;
}
104
// Builds the kAsin intrinsic applied to `v`.
ExprHandle asin(const ExprHandle& v) {
  ExprHandle call = Intrinsics::make(kAsin, v);
  return call;
}
108
// Builds the kAcos intrinsic applied to `v`.
ExprHandle acos(const ExprHandle& v) {
  ExprHandle call = Intrinsics::make(kAcos, v);
  return call;
}
112
// Builds the kAtan intrinsic applied to `v`.
ExprHandle atan(const ExprHandle& v) {
  ExprHandle call = Intrinsics::make(kAtan, v);
  return call;
}
116
// Builds the kSinh intrinsic applied to `v`.
ExprHandle sinh(const ExprHandle& v) {
  ExprHandle call = Intrinsics::make(kSinh, v);
  return call;
}
120
// Builds the kCosh intrinsic applied to `v`.
ExprHandle cosh(const ExprHandle& v) {
  ExprHandle call = Intrinsics::make(kCosh, v);
  return call;
}
124
// Builds the kTanh intrinsic applied to `v`.
ExprHandle tanh(const ExprHandle& v) {
  ExprHandle call = Intrinsics::make(kTanh, v);
  return call;
}
128
// Builds the kSigmoid intrinsic applied to `v`.
ExprHandle sigmoid(const ExprHandle& v) {
  ExprHandle call = Intrinsics::make(kSigmoid, v);
  return call;
}
132
// Builds the kExp intrinsic applied to `v`.
ExprHandle exp(const ExprHandle& v) {
  ExprHandle call = Intrinsics::make(kExp, v);
  return call;
}
136
// Builds the kExpm1 intrinsic applied to `v`.
ExprHandle expm1(const ExprHandle& v) {
  ExprHandle call = Intrinsics::make(kExpm1, v);
  return call;
}
140
// Builds the kAbs intrinsic applied to `v`.
ExprHandle abs(const ExprHandle& v) {
  ExprHandle call = Intrinsics::make(kAbs, v);
  return call;
}
144
145 // The default tanh is quite slow, use the Eigen version from here:
146 // https://bitbucket.org/eigen/eigen/src/94875feeeeb9abe5509b314197da1991ba2070f5/Eigen/src/Core/MathFunctionsImpl.h#lines-26
fast_tanh(const ExprHandle & v)147 ExprHandle fast_tanh(const ExprHandle& v) {
148 // TODO: use a dedicated bind-var to make sure v is not evaluated multiple
149 // times. Clamp the input expression to [-9, 9]
150 ExprHandle plus_9 = FloatImm::make(9.0f);
151 ExprHandle minus_9 = FloatImm::make(-9.0f);
152 ExprHandle v1 = Min::make(v, plus_9, false);
153 v1 = Max::make(v1, minus_9, false);
154
155 // The coefficients for the numerator
156 ExprHandle alpha_1 = FloatImm::make(4.89352455891786e-03f);
157 ExprHandle alpha_3 = FloatImm::make(6.37261928875436e-04f);
158 ExprHandle alpha_5 = FloatImm::make(1.48572235717979e-05f);
159 ExprHandle alpha_7 = FloatImm::make(5.12229709037114e-08f);
160 ExprHandle alpha_9 = FloatImm::make(-8.60467152213735e-11f);
161 ExprHandle alpha_11 = FloatImm::make(2.00018790482477e-13f);
162 ExprHandle alpha_13 = FloatImm::make(-2.76076847742355e-16f);
163
164 // The coefficients for the denominator
165 ExprHandle beta_0 = FloatImm::make(4.89352518554385e-03f);
166 ExprHandle beta_2 = FloatImm::make(2.26843463243900e-03f);
167 ExprHandle beta_4 = FloatImm::make(1.18534705686654e-04f);
168 ExprHandle beta_6 = FloatImm::make(1.19825839466702e-06f);
169
170 // numerator
171 ExprHandle v2 = v1 * v1;
172 ExprHandle p = v2 * alpha_13 + alpha_11;
173 p = v2 * p + alpha_9;
174 p = v2 * p + alpha_7;
175 p = v2 * p + alpha_5;
176 p = v2 * p + alpha_3;
177 p = v2 * p + alpha_1;
178 p = v1 * p;
179
180 // denominator
181 ExprHandle q = v2 * beta_6 + beta_4;
182 q = v2 * q + beta_2;
183 q = v2 * q + beta_0;
184
185 ExprHandle result = p / q;
186 return result;
187 }
188
// Sigmoid approximation built on fast_tanh, using
//   sigmoid(x) = (tanh(x / 2) + 1) / 2
// The result is clamped between 0 and 1 with Max/Min (NaN propagation off).
ExprHandle fast_sigmoid(const ExprHandle& x) {
  ExprHandle zero = FloatImm::make(0.0f);
  ExprHandle half = FloatImm::make(0.5f);
  ExprHandle one = FloatImm::make(1.f);

  ExprHandle tanh_of_half_x{fast_tanh(x * half)};
  ExprHandle unclamped = (tanh_of_half_x + one) * half;

  // fast_tanh is not precise, but clients rely on the sigmoid return values
  // being probability-like, so clamp them to the range between 0 and 1.
  ExprHandle bounded_below =
      Max::make(zero, unclamped, /* propagate_nans= */ false);
  return Min::make(one, bounded_below, /* propagate_nans= */ false);
}
205
// Approximate logarithm assembled from integer bit manipulation and a
// polynomial.
//
// this implementation is taken from sleef:
// https://github.com/shibatch/sleef/blob/master/src/libm/sleefsp.c#L1131
// to generate coefficients, this tool is provided
// https://github.com/shibatch/sleef/blob/master/src/gencoef/gencoef.txt
//
// Special cases are applied last: when `v` compares less than zero the result
// is NaN, and when `v` compares equal to zero the result is -infinity.
ExprHandle fast_log(const ExprHandle& v) {
  // ((bits(x) >> 23) & 0xff) - 0x7f
  auto ilogb2kf = [](const ExprHandle& x) {
    auto shifted = bitcast<int32_t>(x) >> IntImm::make(23);
    auto masked = shifted & IntImm::make(0xff);
    return masked - IntImm::make(0x7f);
  };

  // bits(x) + (e << 23), reinterpreted as float
  auto ldexp3kf = [](const ExprHandle& x, const ExprHandle& e) {
    auto shifted_e = e << IntImm::make(23);
    return bitcast<float>(bitcast<int32_t>(x) + shifted_e);
  };

  // x * y + z
  auto mlaf = [](const ExprHandle& x, const ExprHandle& y, float z) {
    return x * y + FloatImm::make(z);
  };

  auto exp_part = ilogb2kf(v * FloatImm::make(1.0 / 0.75));
  auto scaled = ldexp3kf(v, IntImm::make(-1) * exp_part);
  auto one = FloatImm::make(1.0f);
  auto ratio = (scaled - one) / (scaled + one);
  auto ratio_sq = ratio * ratio;

  // Polynomial in ratio^2, evaluated with repeated multiply-adds.
  auto poly = FloatImm::make(0.2392828464508056640625);
  poly = mlaf(poly, ratio_sq, 0.28518211841583251953125);
  poly = mlaf(poly, ratio_sq, 0.400005877017974853515625);
  poly = mlaf(poly, ratio_sq, 0.666666686534881591796875);
  poly = mlaf(poly, ratio_sq, 2.0);
  auto approx =
      ratio * poly + FloatImm::make(0.693147180559945286226764) * exp_part;

  auto zero = FloatImm::make(0);
  auto nan = FloatImm::make(std::numeric_limits<float>::quiet_NaN());
  auto neg_inf = FloatImm::make(-std::numeric_limits<float>::infinity());
  // v < 0 -> NaN, otherwise keep the approximation.
  auto nan_guarded = CompareSelect::make(v, zero, nan, approx, kLT);
  // v == 0 -> -inf, otherwise keep the previous selection.
  return CompareSelect::make(v, zero, neg_inf, nan_guarded, kEQ);
}
243
// Polynomial logarithm approximation operating on the int32 bit pattern of
// `v`.
//
// The final CompareSelect returns log(v) (the kLog intrinsic) when
// 0x1000000 > bits(v) + 0x800000, and the polynomial result otherwise; that
// comparison is tagged kUnlikely.
// NOTE(review): presumably the intrinsic fallback covers inputs outside the
// range the polynomial handles - confirm against the original derivation.
ExprHandle log_vml(const ExprHandle& v) {
  // x * y + z
  auto mlaf = [](const ExprHandle& x, const ExprHandle& y, float z) {
    return x * y + FloatImm::make(z);
  };

  auto bits = bitcast<int32_t>(v);
  auto rebased = bits - IntImm::make(0x3f2aaaab);
  auto e = cast<float>(rebased >> IntImm::make(23));

  // Keep the low 23 bits of the rebased pattern, add the offset back,
  // reinterpret as float and subtract one.
  auto x_bits = (rebased & IntImm::make(0x7fffff)) + IntImm::make(0x3f2aaaab);
  auto x = bitcast<float>(x_bits) - 1.0f;

  auto poly = FloatImm::make(-0.12891686f);
  poly = mlaf(x, poly, 0.139844373f);
  poly = mlaf(x, poly, -0.121842608f);
  poly = mlaf(x, poly, 0.140058696f);
  poly = mlaf(x, poly, -0.16680488f);
  poly = mlaf(x, poly, 0.200104058f);
  poly = mlaf(x, poly, -0.249997973f);
  poly = mlaf(x, poly, 0.333332151f);
  poly = mlaf(x, poly, -0.5f);
  poly = x * poly;
  poly = x * poly + x;

  // Add the contribution of `e` in two steps, using two separate constants.
  auto approx = e * FloatImm::make(1.42860677e-06f) + poly;
  approx = e * FloatImm::make(0.693145752f) + approx;

  return CompareSelect::make(
      IntImm::make(0x1000000),
      bits + IntImm::make(0x800000),
      log(v),
      approx,
      kGT,
      kUnlikely);
}
279
// Builds the kLog intrinsic applied to `v`.
ExprHandle log(const ExprHandle& v) {
  ExprHandle call = Intrinsics::make(kLog, v);
  return call;
}
283
// Builds the kLog2 intrinsic applied to `v`.
ExprHandle log2(const ExprHandle& v) {
  ExprHandle call = Intrinsics::make(kLog2, v);
  return call;
}
287
// Builds the kLog10 intrinsic applied to `v`.
ExprHandle log10(const ExprHandle& v) {
  ExprHandle call = Intrinsics::make(kLog10, v);
  return call;
}
291
// Builds the kLog1p intrinsic applied to `v`.
ExprHandle log1p(const ExprHandle& v) {
  ExprHandle call = Intrinsics::make(kLog1p, v);
  return call;
}
295
// Builds the kErf intrinsic applied to `v`.
ExprHandle erf(const ExprHandle& v) {
  ExprHandle call = Intrinsics::make(kErf, v);
  return call;
}
299
// Builds the kErfc intrinsic applied to `v`.
ExprHandle erfc(const ExprHandle& v) {
  ExprHandle call = Intrinsics::make(kErfc, v);
  return call;
}
303
// Builds the kSqrt intrinsic applied to `v`.
ExprHandle sqrt(const ExprHandle& v) {
  ExprHandle call = Intrinsics::make(kSqrt, v);
  return call;
}
307
// Builds the kRsqrt intrinsic applied to `v`.
ExprHandle rsqrt(const ExprHandle& v) {
  ExprHandle call = Intrinsics::make(kRsqrt, v);
  return call;
}
311
// Builds the kCeil intrinsic applied to `v`.
ExprHandle ceil(const ExprHandle& v) {
  ExprHandle call = Intrinsics::make(kCeil, v);
  return call;
}
315
// Builds the kFloor intrinsic applied to `v`.
ExprHandle floor(const ExprHandle& v) {
  ExprHandle call = Intrinsics::make(kFloor, v);
  return call;
}
319
// Builds the kRound intrinsic applied to `v`.
ExprHandle round(const ExprHandle& v) {
  ExprHandle call = Intrinsics::make(kRound, v);
  return call;
}
323
// Builds the kTrunc intrinsic applied to `v`.
ExprHandle trunc(const ExprHandle& v) {
  ExprHandle call = Intrinsics::make(kTrunc, v);
  return call;
}
327
// Builds the kFrac intrinsic applied to `v`.
ExprHandle frac(const ExprHandle& v) {
  ExprHandle call = Intrinsics::make(kFrac, v);
  return call;
}
331
// Builds the kLgamma intrinsic applied to `v`.
ExprHandle lgamma(const ExprHandle& v) {
  ExprHandle call = Intrinsics::make(kLgamma, v);
  return call;
}
335
// Builds the two-argument kAtan2 intrinsic with operands (v1, v2).
ExprHandle atan2(const ExprHandle& v1, const ExprHandle& v2) {
  ExprHandle call = Intrinsics::make(kAtan2, v1, v2);
  return call;
}
339
// Builds the two-argument kPow intrinsic with operands (v1, v2).
ExprHandle pow(const ExprHandle& v1, const ExprHandle& v2) {
  ExprHandle call = Intrinsics::make(kPow, v1, v2);
  return call;
}
343
// Builds the two-argument kFmod intrinsic with operands (v1, v2).
ExprHandle fmod(const ExprHandle& v1, const ExprHandle& v2) {
  ExprHandle call = Intrinsics::make(kFmod, v1, v2);
  return call;
}
347
// Builds the two-argument kRemainder intrinsic with operands (v1, v2).
ExprHandle remainder(const ExprHandle& v1, const ExprHandle& v2) {
  ExprHandle call = Intrinsics::make(kRemainder, v1, v2);
  return call;
}
351
// Builds the kIsNan intrinsic applied to `v1`.
ExprHandle isnan(const ExprHandle& v1) {
  ExprHandle call = Intrinsics::make(kIsNan, v1);
  return call;
}
355
// Free-function form of IfThenElse::make: `c` is the condition, `t` the value
// passed for the true branch and `f` the value passed for the false branch.
ExprHandle ifThenElse(
    const ExprHandle& c,
    const ExprHandle& t,
    const ExprHandle& f) {
  ExprHandle selection = IfThenElse::make(c, t, f);
  return selection;
}
362
make_contiguous_strides(const std::vector<ExprHandle> & dims)363 std::vector<ExprPtr> make_contiguous_strides(
364 const std::vector<ExprHandle>& dims) {
365 std::vector<ExprPtr> strides;
366
367 if (!dims.empty()) {
368 strides.resize(dims.size());
369 auto si = immLike(dims[0], 1);
370 for (int64_t i = dims.size() - 1; i >= 0; --i) {
371 strides[i] = si;
372 si = alloc<Mul>(si, dims[i].node());
373 }
374 }
375 return strides;
376 }
377
make_channels_last_strides(const std::vector<ExprHandle> & dims)378 std::vector<ExprPtr> make_channels_last_strides(
379 const std::vector<ExprHandle>& dims) {
380 std::vector<ExprPtr> strides;
381 TORCH_INTERNAL_ASSERT(
382 dims.size() == 4 || dims.size() == 3, "got size:", dims.size());
383 if (dims.size() == 4) {
384 strides.resize(dims.size());
385 ExprHandle handle = ExprHandle(immLike(dims[0], 1));
386 // dims: n c h w
387 // strides(nhwc): w*c*h 1 w*c c
388 strides[1] = handle.node();
389 handle = handle * dims[1];
390 strides[3] = handle.node();
391 handle = handle * dims[3];
392 strides[2] = handle.node();
393 handle = handle * dims[2];
394 strides[0] = handle.node();
395 }
396 if (dims.size() == 3) {
397 strides.resize(dims.size());
398 ExprHandle handle = ExprHandle(immLike(dims[0], 1));
399 // dims: n c l
400 // strides(nlc): c*l 1 c
401 strides[1] = handle.node();
402 handle = handle * dims[1];
403 strides[2] = handle.node();
404 handle = handle * dims[2];
405 strides[0] = handle.node();
406 }
407 return strides;
408 }
409
Buf(const VarPtr & var,std::vector<ExprPtr> dims,Dtype dtype,ExprPtr initializer,std::optional<std::vector<ExprPtr>> strides,ExprPtr qscale,ExprPtr qzero)410 Buf::Buf(
411 const VarPtr& var,
412 std::vector<ExprPtr> dims,
413 Dtype dtype,
414 ExprPtr initializer,
415 std::optional<std::vector<ExprPtr>> strides,
416 ExprPtr qscale,
417 ExprPtr qzero)
418 : ExprNodeBase(dtype, kPrimitive),
419 base_handle_(var),
420 dims_(std::move(dims)),
421 strides_(
422 strides
423 ? *strides
424 : make_contiguous_strides(ExprVectorToExprHandleVector(dims_))),
425 initializer_(std::move(initializer)),
426 qscale_(std::move(qscale)),
427 qzero_(std::move(qzero)) {
428 TORCH_CHECK(var);
429 }
430
make(const std::vector<ExprHandle> & dims,Dtype dtype)431 BufHandle Buf::make(const std::vector<ExprHandle>& dims, Dtype dtype) {
432 return Buf::make("", dims, dtype);
433 }
434
make(const std::string & name_hint,const std::vector<ExprHandle> & dims,const std::vector<ExprHandle> & strides,Dtype dtype)435 BufHandle Buf::make(
436 const std::string& name_hint,
437 const std::vector<ExprHandle>& dims,
438 const std::vector<ExprHandle>& strides,
439 Dtype dtype) {
440 return BufHandle(alloc<Buf>(
441 name_hint,
442 ExprHandleVectorToExprVector(dims),
443 dtype,
444 nullptr,
445 ExprHandleVectorToExprVector(strides)));
446 }
447
make(const std::string & name_hint,const std::vector<ExprHandle> & dims,Dtype dtype,std::optional<ExprHandle> initializer,const std::optional<std::vector<ExprHandle>> & strides,std::optional<ExprHandle> qscale,std::optional<ExprHandle> qzero)448 BufHandle Buf::make(
449 const std::string& name_hint,
450 const std::vector<ExprHandle>& dims,
451 Dtype dtype,
452 std::optional<ExprHandle> initializer,
453 const std::optional<std::vector<ExprHandle>>& strides,
454 std::optional<ExprHandle> qscale,
455 std::optional<ExprHandle> qzero) {
456 std::optional<std::vector<ExprPtr>> opt_strides;
457 if (strides) {
458 opt_strides = ExprHandleVectorToExprVector(*strides);
459 }
460 return BufHandle(alloc<Buf>(
461 name_hint,
462 ExprHandleVectorToExprVector(dims),
463 dtype,
464 initializer ? initializer->node() : nullptr,
465 opt_strides,
466 qscale ? qscale->node() : nullptr,
467 qzero ? qzero->node() : nullptr));
468 }
469
is_contiguous(at::MemoryFormat memory_format) const470 bool Buf::is_contiguous(at::MemoryFormat memory_format) const {
471 auto ndims = dims_.size();
472 std::vector<int64_t> dim_order(ndims);
473 if (memory_format == at::MemoryFormat::ChannelsLast) {
474 if (dims_.size() != 4)
475 return false;
476 dim_order = {1, 3, 2, 0};
477 } else if (memory_format == at::MemoryFormat::ChannelsLast3d) {
478 if (dims_.size() != 5)
479 return false;
480 dim_order = {1, 4, 3, 2, 0};
481 } else {
482 if (dims_.empty()) {
483 // Scalar tensor
484 TORCH_CHECK(strides_.empty());
485 return true; // Align with the isContiguous logic in the kernel.cpp
486 }
487 for (size_t i = 0; i < ndims; i++) {
488 dim_order[i] = ndims - i - 1; // Reverse
489 }
490 }
491
492 bool res = is_stride_one(dim_order[0]);
493 if (!res)
494 return false;
495
496 for (size_t i = 1; i < ndims; i++) {
497 auto cur_dim = dim_order[i];
498 auto pre_dim = dim_order[i - 1];
499 res &= is_cont_with(cur_dim, pre_dim);
500 if (!res)
501 return false;
502 }
503
504 return true;
505 }
506
dims() const507 std::vector<ExprHandle> BufHandle::dims() const {
508 return ExprVectorToExprHandleVector(node()->dims());
509 }
510
is_cont_with(int cur_dim,int adjacent_dim) const511 bool Buf::is_cont_with(int cur_dim, int adjacent_dim) const {
512 auto is_cont_fn = [](const ExprPtr& adjacent_dim,
513 const ExprPtr& adjacent_stride,
514 const ExprPtr& cur_stride) {
515 // For static shape
516 bool res = exprEquals(
517 cur_stride,
518 (ExprHandle(adjacent_dim) * ExprHandle(adjacent_stride)).node());
519 if (res)
520 return res;
521
522 // For symbolic shape
523 auto mul_node = to<Mul>(cur_stride);
524 if (!mul_node) {
525 return false;
526 }
527
528 // lhs and rhs could be other dim or stride
529 auto lhs_ = mul_node->lhs();
530 auto rhs_ = mul_node->rhs();
531
532 bool same_stride = false;
533 auto same_dim = exprEquals(lhs_, adjacent_dim) || (adjacent_dim == lhs_);
534 if (same_dim) {
535 // lhs_ is dim while rhs_ is stride
536 same_stride =
537 exprEquals(rhs_, adjacent_stride) || (adjacent_stride == rhs_);
538 } else {
539 // lhs_ is stride while rhs_ is dim
540 same_dim = exprEquals(rhs_, adjacent_dim) || (adjacent_dim == rhs_);
541 same_stride =
542 exprEquals(lhs_, adjacent_stride) || (adjacent_stride == lhs_);
543 }
544
545 return same_dim && same_stride;
546 };
547 return is_cont_fn(
548 dims_[adjacent_dim], strides_[adjacent_dim], strides_[cur_dim]);
549 }
550
is_stride_one(int cur_dim) const551 bool Buf::is_stride_one(int cur_dim) const {
552 return exprEquals(strides_[cur_dim], alloc<LongImm>(1));
553 }
554
expr_to_vec(ExprHandle v,int lanes)555 ExprHandle expr_to_vec(ExprHandle v, int lanes) {
556 if (lanes == 1) {
557 return v;
558 } else {
559 return Broadcast::make(v, lanes);
560 }
561 }
562
563 } // namespace torch::jit::tensorexpr
564