// Source: /aosp_15_r20/external/pytorch/test/cpp/tensorexpr/test_expr.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <gtest/gtest.h>

#include <test/cpp/tensorexpr/test_base.h>

#include <c10/util/irange.h>
#include <test/cpp/tensorexpr/padded_buffer.h>
#include <test/cpp/tensorexpr/test_utils.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/ir_verifier.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

#include <cmath>
#include <functional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace torch {
namespace jit {
using namespace torch::jit::tensorexpr;

using SimpleIRExprEval = ExprEval<SimpleIREvaluator>;

// Evaluates the integer sum 2 + 3 built directly from IR immediates.
TEST(Expr, BasicValueTest) {
  const ExprHandle lhs = IntImm::make(2);
  const ExprHandle rhs = IntImm::make(3);
  SimpleIRExprEval evaluator(Add::make(lhs, rhs));
  ASSERT_EQ(evaluator.value<int>(), 5);
}

// Float arithmetic through operator overloads: (2 + 3) - (4 + 5) == -4.
TEST(Expr, BasicValueTest02) {
  const ExprHandle two(2.0f);
  const ExprHandle three(3.0f);
  const ExprHandle four(4.0f);
  const ExprHandle five(5.0f);
  const ExprHandle left_sum = two + three;
  const ExprHandle right_sum = four + five;
  SimpleIRExprEval evaluator(left_sum - right_sum);
  ASSERT_EQ(evaluator.value<float>(), -4.0f);
}

TEST(Expr,IsChannelsLastContiguous)45 TEST(Expr, IsChannelsLastContiguous) {
46   std::vector<VarHandle> vars = {
47       VarHandle("var1", kLong),
48       VarHandle("var2", kLong),
49       VarHandle("var3", kLong),
50       VarHandle("var4", kLong),
51       VarHandle("var5", kLong)};
52 
53   // {
54   //   key: ndims,
55   //   value: [
56   //     ...
57   //     [dim_2, dim_1, ..., dim_n]
58   //   ]
59   // }
60   using shapGenInfo = std::unordered_map<int, std::vector<std::vector<int>>>;
61 
62   // {
63   //   size: [ExprHandle_1, ExprHandle_2, ..., ExprHandle_n],
64   //   strides: [
65   //     ...
66   //     [ExprHandle_x, ExprHandle_y, ..., ExprHandle_z]
67   //   ]
68   // }
69   using shapeInfo =
70       std::pair<std::vector<ExprHandle>, std::vector<std::vector<ExprHandle>>>;
71 
72   std::vector<int> dims = {3, 4, 5};
73 
74   std::unordered_map<int, std::vector<ExprHandle>> dims_expr_vec_conf = {
75       {3, std::vector<ExprHandle>(vars.begin(), vars.begin() + 2)},
76       {4, std::vector<ExprHandle>(vars.begin(), vars.begin() + 3)},
77       {5, std::vector<ExprHandle>(vars.begin(), vars.begin() + 4)},
78   };
79 
80   shapGenInfo channels_last_cont_shape_conf = {
81       {3, {{1, 2, 0}}}, {4, {{1, 3, 2, 0}}}, {5, {{1, 4, 3, 2, 0}}}};
82   shapGenInfo channels_last_non_cont_shape_conf = {
83       {3, {{2, 1, 0}, {1, 0, 2}}},
84       {4, {{3, 1, 2, 0}, {1, 2, 3, 0}, {1, 0, 2, 3}}},
85       {5, {{4, 3, 2, 1, 0}, {1, 3, 2, 4, 0}, {1, 4, 3, 2, 0}}}};
86 
87   shapGenInfo cont_shape_conf = {
88       {3, {{0, 1, 2}}}, {4, {{0, 1, 2, 3}}}, {5, {{0, 1, 2, 3, 4}}}};
89 
90   auto shape_gen_fn = [dims_expr_vec_conf](
91                           int ndims, shapGenInfo shape_gen_info) -> shapeInfo {
92     auto dims_expr_vec = dims_expr_vec_conf.at(ndims);
93     std::vector<std::vector<ExprHandle>> strides_expr_vec;
94     for (size_t i = 0; i < strides_expr_vec.size(); i++) {
95       strides_expr_vec[i].resize(ndims);
96     }
97 
98     auto stride_gen_fn = [](int indicator, ExprHandle a, ExprHandle b) {
99       if (indicator % 2 == 0) {
100         return a * b;
101       } else {
102         return b * a;
103       }
104     };
105 
106     auto stride_order_vec = shape_gen_info.at(ndims);
107     for (size_t i = 0; i < strides_expr_vec.size(); i++) {
108       auto stride_order = stride_order_vec[i];
109 
110       strides_expr_vec[i][stride_order[0]] = 1;
111       for (size_t j = 1; j < stride_order.size(); j++) {
112         auto cur_dim_idx = stride_order[j];
113         auto adjacent_dim_idx = stride_order[j - 1];
114 
115         strides_expr_vec[i][cur_dim_idx] = stride_gen_fn(
116             i,
117             dims_expr_vec[adjacent_dim_idx],
118             strides_expr_vec[i][adjacent_dim_idx]);
119       }
120     }
121 
122     return {dims_expr_vec, strides_expr_vec};
123   };
124 
125   auto check_channels_last_fn = [](int ndims, BufHandle buf_handle) -> bool {
126     if (ndims == 3) {
127       return buf_handle.is_channels_last_1d_contiguous();
128     } else if (ndims == 4) {
129       return buf_handle.is_contiguous(at::MemoryFormat::ChannelsLast);
130     } else {
131       return buf_handle.is_contiguous(at::MemoryFormat::ChannelsLast3d);
132     }
133   };
134 
135   // channels-last contiguous
136   for (size_t i = 0; i < dims.size(); i++) {
137     auto shape_info = shape_gen_fn(dims[i], channels_last_cont_shape_conf);
138     for (size_t j = 0; j < shape_info.second.size(); j++) {
139       BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat);
140       ASSERT_EQ(check_channels_last_fn(dims[i], buf_handle), true);
141     }
142   }
143 
144   // channels-last non-contiguous
145   for (size_t i = 0; i < dims.size(); i++) {
146     auto shape_info = shape_gen_fn(dims[i], channels_last_non_cont_shape_conf);
147     for (size_t j = 0; j < shape_info.second.size(); j++) {
148       BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat);
149       ASSERT_EQ(check_channels_last_fn(dims[i], buf_handle), false);
150     }
151   }
152 
153   // contiguous
154   for (size_t i = 0; i < dims.size(); i++) {
155     auto shape_info = shape_gen_fn(dims[i], cont_shape_conf);
156     for (size_t j = 0; j < shape_info.second.size(); j++) {
157       BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat);
158       ASSERT_EQ(buf_handle.is_contiguous(), true);
159     }
160   }
161 
162   // non-contiguous
163   for (size_t i = 0; i < dims.size(); i++) {
164     auto shape_info = shape_gen_fn(dims[i], channels_last_cont_shape_conf);
165     for (size_t j = 0; j < shape_info.second.size(); j++) {
166       BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat);
167       ASSERT_EQ(buf_handle.is_contiguous(), false);
168     }
169   }
170 }
171 
// Binds the float variable x to 3 and evaluates 2 + (x * 3 + 4).
TEST(Expr, LetTest01) {
  VarHandle x("x", kFloat);
  const ExprHandle scaled_plus_four = x * ExprHandle(3.f) + ExprHandle(4.f);
  SimpleIRExprEval evaluator(ExprHandle(2.f) + scaled_plus_four);
  evaluator.bindVar(x, ExprHandle(3.f));
  ASSERT_EQ(evaluator.value<float>(), 2 + (3 * 3 + 4));
}

// Binds x = 3 and y = 6, then evaluates 2 + (x * 3 + 4 * y) in float.
TEST(Expr, LetTest02) {
  VarHandle x("x", kFloat);
  VarHandle y("y", kFloat);
  const ExprHandle x_term = x * ExprHandle(3.f);
  const ExprHandle y_term = ExprHandle(4.f) * y;
  SimpleIRExprEval evaluator(ExprHandle(2.f) + (x_term + y_term));
  evaluator.bindVar(x, ExprHandle(3.f));
  evaluator.bindVar(y, ExprHandle(6.f));
  ASSERT_EQ(evaluator.value<float>(), 2 + (3 * 3 + 4 * 6));
}

// Loads a[0] into a Let-bound variable, then stores that variable to b[0];
// b must end up holding a's value.
TEST(Expr, LetStmtTest01) {
  BufHandle a_buf("a", {1}, kFloat);
  BufHandle b_buf("b", {1}, kFloat);

  VarHandle tmp = VarHandle("v", kFloat);
  StmtPtr let_stmt = Let::make(tmp, a_buf.load(0));
  StmtPtr copy_stmt = b_buf.store({0}, tmp);
  BlockPtr program = Block::make({let_stmt, copy_stmt});

  SimpleIREvaluator evaluator(program, {a_buf, b_buf});

  PaddedBuffer<float> a_v(1);
  PaddedBuffer<float> b_v(1);
  PaddedBuffer<float> b_ref(1);
  a_v(0) = 23;
  b_ref(0) = a_v(0);

  evaluator(a_v, b_v);
  ExpectAllNear(b_v, b_ref, 1e-5);
}

// kInt variant: binds x = 3 and evaluates 2 + (x * 3 + 4).
TEST(Expr, IntTest) {
  VarHandle x("x", kInt);
  const ExprHandle inner = x * ExprHandle(3) + ExprHandle(4);
  SimpleIRExprEval evaluator(ExprHandle(2) + inner);
  evaluator.bindVar(x, ExprHandle(3));
  ASSERT_EQ(evaluator.value<int>(), 2 + (3 * 3 + 4));
}

// kFloat variant: binds x = 3 and evaluates 2 + (x * 3 + 4).
TEST(Expr, FloatTest) {
  VarHandle x("x", kFloat);
  const ExprHandle inner = x * ExprHandle(3.f) + ExprHandle(4.f);
  SimpleIRExprEval evaluator(ExprHandle(2.f) + inner);
  evaluator.bindVar(x, ExprHandle(3.f));
  ASSERT_EQ(evaluator.value<float>(), 2 + (3 * 3 + 4));
}

// kByte (uint8_t) variant: binds x = 3 and evaluates 2 + (x * 3 + 4).
TEST(Expr, ByteTest) {
  auto byte_imm = [](uint8_t v) { return ExprHandle(v); };
  VarHandle x("x", kByte);
  const ExprHandle body = byte_imm(2) + (x * byte_imm(3) + byte_imm(4));
  SimpleIRExprEval evaluator(body);
  evaluator.bindVar(x, byte_imm(3));
  ASSERT_EQ(evaluator.value<uint8_t>(), 2 + (3 * 3 + 4));
}

// kChar (int8_t) variant: binds x = 3 and evaluates 2 + (x * 3 + 4).
TEST(Expr, CharTest) {
  auto char_imm = [](int8_t v) { return ExprHandle(v); };
  VarHandle x("x", kChar);
  const ExprHandle body = char_imm(2) + (x * char_imm(3) + char_imm(4));
  SimpleIRExprEval evaluator(body);
  evaluator.bindVar(x, char_imm(3));
  ASSERT_EQ(evaluator.value<int8_t>(), 2 + (3 * 3 + 4));
}

// kShort (int16_t) variant: binds x = 3 and evaluates 2 + (x * 3 + 4).
TEST(Expr, ShortTest) {
  auto short_imm = [](int16_t v) { return ExprHandle(v); };
  VarHandle x("x", kShort);
  const ExprHandle body = short_imm(2) + (x * short_imm(3) + short_imm(4));
  SimpleIRExprEval evaluator(body);
  evaluator.bindVar(x, short_imm(3));
  ASSERT_EQ(evaluator.value<int16_t>(), 2 + (3 * 3 + 4));
}

// kLong (int64_t) variant: binds x = 3 and evaluates 2 + (x * 3 + 4).
TEST(Expr, LongTest) {
  auto long_imm = [](int64_t v) { return ExprHandle(v); };
  VarHandle x("x", kLong);
  const ExprHandle body = long_imm(2) + (x * long_imm(3) + long_imm(4));
  SimpleIRExprEval evaluator(body);
  evaluator.bindVar(x, long_imm(3));
  ASSERT_EQ(evaluator.value<int64_t>(), 2 + (3 * 3 + 4));
}

// kHalf (at::Half) variant: binds x = 3 and evaluates 2 + (x * 3 + 4).
TEST(Expr, HalfTest) {
  auto half_imm = [](float v) { return ExprHandle(static_cast<at::Half>(v)); };
  VarHandle x("x", kHalf);
  const ExprHandle body = half_imm(2) + (x * half_imm(3) + half_imm(4));
  SimpleIRExprEval evaluator(body);
  evaluator.bindVar(x, half_imm(3));
  ASSERT_EQ(evaluator.value<at::Half>(), 2 + (3 * 3 + 4));
}

// kDouble variant: binds x = 3 and evaluates 2 + (x * 3 + 4).
TEST(Expr, DoubleTest) {
  auto double_imm = [](double v) { return ExprHandle(v); };
  VarHandle x("x", kDouble);
  const ExprHandle body = double_imm(2) + (x * double_imm(3) + double_imm(4));
  SimpleIRExprEval evaluator(body);
  evaluator.bindVar(x, double_imm(3));
  ASSERT_EQ(evaluator.value<double>(), 2 + (3 * 3 + 4));
}

// Vectorized add over float buffers, kVectorSize lanes per loop iteration:
//
//   for index in [0, kVectorCount):
//     C[ramp(index * 8, 1, 8)] =
//         A[ramp(index * 8, 1, 8)] + B[ramp(index * 8, 1, 8)]
//
// Also checks that loads/adds through a Ramp index have an 8-lane float dtype.
TEST(Expr, VectorAdd01) {
  const int kVectorSize = 8;
  const int kVectorCount = 128;
  const int kTotalSize = kVectorSize * kVectorCount;

  BufHandle a_buf("A", {kTotalSize}, kFloat);
  BufHandle b_buf("B", {kTotalSize}, kFloat);
  BufHandle c_buf("C", {kTotalSize}, kFloat);

  VarHandle index = VarHandle("index", kInt);
  // Builds a fresh Ramp node covering [index * 8, index * 8 + 8).
  auto lane_ramp = [&]() {
    return Ramp::make(index * kVectorSize, 1, kVectorSize);
  };

  ExprHandle load_a = a_buf.load({lane_ramp()});
  ExprHandle load_b = b_buf.load({lane_ramp()});
  ExprHandle value = load_a + load_b;
  StmtPtr stmt =
      For::make(index, 0, kVectorCount, c_buf.store({lane_ramp()}, value));

  const Dtype vector_dtype(kFloat, kVectorSize);
  ASSERT_EQ(load_a.dtype(), vector_dtype);
  ASSERT_EQ(load_b.dtype(), vector_dtype);
  ASSERT_EQ(value.dtype(), vector_dtype);

  PaddedBuffer<float> a_v(kTotalSize);
  PaddedBuffer<float> b_v(kTotalSize);
  PaddedBuffer<float> c_v(kTotalSize);
  PaddedBuffer<float> c_ref(kTotalSize);
  for (const auto i : c10::irange(kTotalSize)) {
    const int square = i * i;
    a_v(i) = square;
    b_v(i) = square * 4;
    c_ref(i) = a_v(i) + b_v(i);
  }

  SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf});
  ir_eval(a_v, b_v, c_v);
  ExpectAllNear(c_v, c_ref, 1e-5);
}

// C[i] = CompareSelect(A[i], B[i], kEQ) with the default return values, i.e.
// 1 where A[i] == B[i] and 0 elsewhere. Half of the elements differ so the
// test can tell kEQ apart from a select that always yields 1.
TEST(Expr, CompareSelectEQ) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kInt);
  std::vector<int> a_buffer(N, 1);
  std::vector<int> b_buffer(N, 1);
  std::vector<int> c_buffer(N, 0);
  std::vector<int> c_ref(N, 0);

  // Every even position gets a mismatching B value.
  for (int idx = 0; idx < N; idx += 2) {
    b_buffer[idx] = 2;
  }
  for (int idx = 0; idx < N; ++idx) {
    c_ref[idx] = (a_buffer[idx] == b_buffer[idx]) ? 1 : 0;
  }
  // Snapshot of B so we can verify the evaluator leaves inputs untouched.
  const std::vector<int> b_expected = b_buffer;

  VarHandle i("i", kInt);
  auto select_loop = For::make(
      i,
      0,
      N,
      c.store(
          {i},
          CompareSelect::make(
              a.load(i), b.load(i), CompareSelectOperation::kEQ)));

  SimpleIREvaluator ir_eval(select_loop, {a, b, c});
  ir_eval(a_buffer, b_buffer, c_buffer);

  ASSERT_EQ(a_buffer.size(), static_cast<size_t>(N));
  ASSERT_EQ(b_buffer.size(), static_cast<size_t>(N));
  ASSERT_EQ(c_buffer.size(), static_cast<size_t>(N));

  assertAllEqual(a_buffer, 1);
  ASSERT_EQ(b_buffer, b_expected);
  ASSERT_EQ(c_buffer, c_ref);
}

// The compared operands (int here) may have a different dtype from the
// selected values (float here), as long as both selected values share one:
//   C[i] = ((int)A[i] == (int)B[i]) ? 3.14f : 2.78f
// All inputs are equal, so every output element should be 3.14f.
TEST(Expr, CompareSelectDtypes) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kFloat);

  std::vector<int> a_buffer(N, 1);
  std::vector<int> b_buffer(N, 1);
  std::vector<float> c_buffer(N, 0.0f);
  std::vector<float> c_ref(N, 3.14f);

  VarHandle i("i", kInt);
  ExprHandle selected = CompareSelect::make(
      a.load(i),
      b.load(i),
      FloatImm::make(3.14f),
      FloatImm::make(2.78f),
      CompareSelectOperation::kEQ);
  StmtPtr loop = For::make(i, 0, N, c.store({i}, selected));

  SimpleIREvaluator evaluator(loop, {a, b, c});
  evaluator(a_buffer, b_buffer, c_buffer);

  ASSERT_EQ(a_buffer.size(), N);
  ASSERT_EQ(b_buffer.size(), N);
  ASSERT_EQ(c_buffer.size(), N);

  assertAllEqual(a_buffer, 1);
  assertAllEqual(b_buffer, 1);
  ExpectAllNear(c_buffer, c_ref, 1e-7);
}

// Applies abs() to double-typed loads: B[i] = |A[i]| with A filled by -10.0.
TEST(Expr, IntrinsicsDtypes) {
  constexpr int N = 256;
  BufHandle a("A", {N}, kDouble);
  BufHandle b("B", {N}, kDouble);

  std::vector<double> a_buffer(N, -10.0);
  std::vector<double> b_buffer(N, 0.0);
  std::vector<double> b_ref(N, 10.0);

  VarHandle i("i", kInt);
  ExprHandle magnitude = tensorexpr::abs(a.load(i));
  StmtPtr loop = For::make(i, 0, N, b.store({i}, magnitude));

  SimpleIREvaluator evaluator(loop, {a, b});
  evaluator(a_buffer, b_buffer);

  ASSERT_EQ(a_buffer.size(), N);
  ASSERT_EQ(b_buffer.size(), N);

  assertAllEqual(a_buffer, -10.0);
  ExpectAllNear(b_buffer, b_ref, 1e-7);
}

// Substituting x := z + 5 into (x - 1) * (x + y) must print identically to
// the hand-built tree (z + 5 - 1) * (z + 5 + y).
TEST(Expr, Substitute01) {
  VarPtr x = alloc<Var>("x", kFloat);
  VarPtr y = alloc<Var>("y", kFloat);
  ExprPtr e =
      alloc<Mul>(alloc<Sub>(x, alloc<FloatImm>(1.0f)), alloc<Add>(x, y));

  VarPtr z = alloc<Var>("z", kFloat);
  // Returns a freshly allocated `z + 5` node on every call.
  auto z_plus_5 = [&z]() -> ExprPtr {
    return alloc<Add>(z, alloc<FloatImm>(5.0f));
  };

  ExprPtr e2 = Substitute(e, {{x, z_plus_5()}});
  ExprPtr e2_ref = alloc<Mul>(
      alloc<Sub>(z_plus_5(), alloc<FloatImm>(1.0f)),
      alloc<Add>(z_plus_5(), y));

  auto to_str = [](const ExprPtr& expr) {
    std::ostringstream out;
    out << *expr;
    return out.str();
  };
  ASSERT_EQ(to_str(e2), to_str(e2_ref));
}

// sin() on a float immediate: checks the printed IR and the evaluated value
// against std::sin.
TEST(Expr, Math01) {
  ExprHandle v = sin(ExprHandle(1.0f));

  std::ostringstream printed;
  printed << v;
  ASSERT_EQ(printed.str(), "sin(1.f)");

  SimpleIRExprEval evaluator(v);
  const float expected = std::sin(1.0f);
  ASSERT_NEAR(evaluator.value<float>(), expected, 1e-6);
}

TEST(Expr,UnaryMath01)460 TEST(Expr, UnaryMath01) {
461   struct TestConfig {
462     std::function<ExprHandle(const ExprHandle&)> func;
463     std::function<float(float)> ref_func;
464   };
465 
466   std::vector<TestConfig> test_configs = {
467       {[](const ExprHandle& v) { return sin(v); },
468        [](float v) { return std::sin(v); }},
469       {[](const ExprHandle& v) { return sin(v); },
470        [](float v) { return std::sin(v); }},
471       {[](const ExprHandle& v) { return tan(v); },
472        [](float v) { return std::tan(v); }},
473       {[](const ExprHandle& v) { return asin(v); },
474        [](float v) { return std::asin(v); }},
475       {[](const ExprHandle& v) { return acos(v); },
476        [](float v) { return std::acos(v); }},
477       {[](const ExprHandle& v) { return atan(v); },
478        [](float v) { return std::atan(v); }},
479       {[](const ExprHandle& v) { return sinh(v); },
480        [](float v) { return std::sinh(v); }},
481       {[](const ExprHandle& v) { return cosh(v); },
482        [](float v) { return std::cosh(v); }},
483       {[](const ExprHandle& v) { return tanh(v); },
484        [](float v) { return std::tanh(v); }},
485       {[](const ExprHandle& v) { return exp(v); },
486        [](float v) { return std::exp(v); }},
487       {[](const ExprHandle& v) { return tensorexpr::abs(v); },
488        [](float v) { return std::fabs(v); }},
489       {[](const ExprHandle& v) { return log(v); },
490        [](float v) { return std::log(v); }},
491       {[](const ExprHandle& v) { return log2(v); },
492        [](float v) { return std::log2(v); }},
493       {[](const ExprHandle& v) { return log10(v); },
494        [](float v) { return std::log10(v); }},
495       {[](const ExprHandle& v) { return erf(v); },
496        [](float v) { return std::erf(v); }},
497       {[](const ExprHandle& v) { return sqrt(v); },
498        [](float v) { return std::sqrt(v); }},
499       {[](const ExprHandle& v) { return rsqrt(v); },
500        [](float v) { return 1.0f / std::sqrt(v); }},
501       {[](const ExprHandle& v) { return ceil(v); },
502        [](float v) { return std::ceil(v); }},
503       {[](const ExprHandle& v) { return floor(v); },
504        [](float v) { return std::floor(v); }},
505       {[](const ExprHandle& v) { return round(v); },
506        [](float v) { return std::round(v); }},
507       {[](const ExprHandle& v) { return trunc(v); },
508        [](float v) { return std::trunc(v); }},
509   };
510 
511   for (const TestConfig& test_config : test_configs) {
512     const float input_v = 0.8765f;
513     ExprHandle v = test_config.func(ExprHandle(input_v));
514     float v_ref = test_config.ref_func(input_v);
515     SimpleIRExprEval eval(v);
516     ASSERT_NEAR(eval.value<float>(), v_ref, 1e-6);
517   }
518 
519   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
520   for (float input_v : {std::nan("1"), 0., .5}) {
521     ExprHandle v = FloatImm::make(input_v);
522     SimpleIRExprEval eval(Intrinsics::make(kIsNan, v));
523     ASSERT_NEAR(eval.value<int>(), std::isnan(input_v), 0);
524   }
525 }
526 
// Evaluates binary math intrinsics (pow, fmod) on float immediates and
// compares them with the matching <cmath> functions.
TEST(Expr, BinaryMath01) {
  struct TestConfig {
    std::function<ExprHandle(const ExprHandle&, const ExprHandle&)> func;
    std::function<float(float, float)> ref_func;
  };

  const std::vector<TestConfig> test_configs = {
      {[](const ExprHandle& lhs, const ExprHandle& rhs) {
         return pow(lhs, rhs);
       },
       [](float lhs, float rhs) { return std::pow(lhs, rhs); }},
      {[](const ExprHandle& lhs, const ExprHandle& rhs) {
         return fmod(lhs, rhs);
       },
       [](float lhs, float rhs) { return std::fmod(lhs, rhs); }},
  };

  const float lhs_value = 0.8765f;
  const float rhs_value = 1.2345f;
  for (const auto& config : test_configs) {
    ExprHandle expr =
        config.func(ExprHandle(lhs_value), ExprHandle(rhs_value));
    const float expected = config.ref_func(lhs_value, rhs_value);
    SimpleIRExprEval evaluator(expr);
    ASSERT_NEAR(evaluator.value<float>(), expected, 1e-6);
  }
}

// Checks && and || over int and float comparisons, read back as int.
// Here a > b and c > d hold; the reversed comparisons are false.
TEST(Expr, LogicalOps01) {
  ExprHandle a(23);
  ExprHandle b(11);
  ExprHandle c(0.72f);
  ExprHandle d(0.69f);

  const std::vector<std::pair<ExprHandle, int>> cases = {
      {(a > b) && (c > d), 1},
      {(a > b) && (c < d), 0},
      {(a < b) && (c > d), 0},
      {(a < b) && (c < d), 0},
      {(a < b) || (c > d), 1},
      {(a < b) || (c < d), 0},
      {(a > b) || (c < d), 1},
      {(a > b) || (c > d), 1},
  };

  for (const auto& test_case : cases) {
    SimpleIRExprEval evaluator(test_case.first);
    ASSERT_EQ(evaluator.value<int>(), test_case.second);
  }
}

// Nests && / || results inside further && / ||. With a > b true and c == d,
// both composite expressions evaluate to 1.
TEST(Expr, LogicalOps02) {
  ExprHandle a(23);
  ExprHandle b(11);
  ExprHandle c(0.72f);
  ExprHandle d(0.72f);

  // Fresh `a > b` node for each use.
  auto a_gt_b = [&]() { return a > b; };

  ExprHandle either = a_gt_b() || (c > d);
  ExprHandle both_le = a_gt_b() && (c <= d);
  ExprHandle both_gt = a_gt_b() && (c > d);

  SimpleIRExprEval and_eval(either && both_le);
  SimpleIRExprEval or_eval(both_le || both_gt);
  ASSERT_EQ(and_eval.value<int>(), 1);
  ASSERT_EQ(or_eval.value<int>(), 1);
}

// Combines a comparison with an immediate of each integral dtype via && / ||.
// (a > b) is true and (c <= d) is false, so every expression evaluates to 1.
TEST(Expr, LogicalOps03) {
  ExprHandle a(23);
  ExprHandle b(11);
  ExprHandle c(0.72f);
  ExprHandle d(0.69f);

  // Evaluates `e` and returns the result as the type of `tag`.
  auto eval_as = [](const ExprHandle& e, auto tag) {
    SimpleIRExprEval evaluator(e);
    return evaluator.value<decltype(tag)>();
  };

  // Bool types
  ASSERT_EQ(eval_as((a > b) && BoolImm::make(true), bool{}), true);
  ASSERT_EQ(eval_as((c <= d) || BoolImm::make(true), bool{}), true);

  // Int types
  ASSERT_EQ(eval_as((a > b) && IntImm::make(1), int{}), 1);
  ASSERT_EQ(eval_as((c <= d) || IntImm::make(1), int{}), 1);

  // Short types
  ASSERT_EQ(eval_as((a > b) && ShortImm::make(1), int16_t{}), 1);
  ASSERT_EQ(eval_as((c <= d) || ShortImm::make(1), int16_t{}), 1);

  // Long types
  ASSERT_EQ(eval_as((a > b) && LongImm::make(1), int64_t{}), 1);
  ASSERT_EQ(eval_as((c <= d) || LongImm::make(1), int64_t{}), 1);

  // Char types
  ASSERT_EQ(eval_as((a > b) && CharImm::make(1), int8_t{}), 1);
  ASSERT_EQ(eval_as((c <= d) || CharImm::make(1), int8_t{}), 1);

  // Byte types
  ASSERT_EQ(eval_as((a > b) && ByteImm::make(1), uint8_t{}), 1);
  ASSERT_EQ(eval_as((c <= d) || ByteImm::make(1), uint8_t{}), 1);
}

// ((59 ^ (11 << 1)) & 101) >> 2 | 2
//   = ((59 ^ 22) & 101) >> 2 | 2 = (45 & 101) >> 2 | 2 = 37 >> 2 | 2 = 11
TEST(Expr, BitwiseOps) {
  ExprHandle a(59);
  ExprHandle b(11);
  ExprHandle c(101);
  ExprHandle d(2);

  ExprHandle xored = a ^ (b << 1);
  ExprHandle masked = xored & c;
  ExprHandle shifted = masked >> 2;

  SimpleIRExprEval evaluator(shifted | d);
  ASSERT_EQ(evaluator.value<int>(), 11);
}

// c[i] = a[i] + b[i] where the length n is a runtime argument of the kernel;
// the same IR is built and run for several lengths.
TEST(Expr, DynamicShapeAdd) {
  auto run_with_length = [](int32_t size) {
    VarHandle n("n", kInt);
    BufHandle a("a", {n}, kFloat);
    BufHandle b("b", {n}, kFloat);
    BufHandle c("c", {n}, kFloat);
    VarHandle i("i", kInt);
    StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i)));

    std::vector<float> aData(size, 1.0f);
    std::vector<float> bData(size, 2.0f);
    std::vector<float> cData(size, 0.0f);
    SimpleIREvaluator evaluator(s, {a, b, c, n});
    evaluator(aData, bData, cData, size);
    ExpectAllNear(cData, std::vector<float>(size, 3.0f), 1e-7);
  };

  for (int32_t size : {1, 16, 37}) {
    run_with_length(size);
  }
}

// Stores X[i] = i for i in [0, 15) into a 10-element buffer; the evaluator
// must throw on the out-of-bounds stores.
TEST(Expr, OutOfBounds) {
  ExprHandle N(10);
  ExprHandle start(0);
  ExprHandle stop(15);
  VarHandle i("i", kInt);

  BufHandle X("X", {N}, kInt);
  auto stmt = For::make(i, start, stop, Store::make(X, {i}, i));

  PaddedBuffer<int> data(20);
  auto run = [&]() { SimpleIREvaluator(stmt, {X})(data); };
  EXPECT_ANY_THROW(run());
}

// A 15 x 15 loop nest storing into an N x M buffer where one of N / M is only
// 10; the evaluator must throw for both orientations.
TEST(Expr, OutOfBounds2d) {
  const std::vector<std::pair<int, int>> size_options = {{10, 15}, {15, 10}};
  for (const auto& sizes : size_options) {
    ExprHandle N(sizes.first);
    ExprHandle M(sizes.second);
    ExprHandle start(0);
    ExprHandle stopInner(15);
    ExprHandle stopOuter(15);
    VarHandle i("i", kInt);
    VarHandle j("j", kInt);

    BufHandle X("X", {N, M}, kInt);

    auto store = Store::make(X, {i, j}, i);
    auto inner_loop = For::make(j, start, stopInner, store);
    auto outer_loop = For::make(i, start, stopOuter, inner_loop);

    PaddedBuffer<int> data(400);
    auto run = [&]() { SimpleIREvaluator(outer_loop, {X})(data); };
    EXPECT_ANY_THROW(run());
  }
}

// A 10 x 15 loop nest storing through the flattened index i * 15 + j into a
// 149-element buffer. The last index is 9 * 15 + 14 = 149, one past the end,
// so the evaluator must throw.
TEST(Expr, OutOfBounds2dFlattenedIndex) {
  ExprHandle buf_size(149);
  ExprHandle start(0);
  ExprHandle stopInner(15);
  ExprHandle stopOuter(10);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);

  BufHandle X("X", {buf_size}, kInt);

  auto flat_index = Add::make(Mul::make(i, stopInner), j);
  auto store = Store::make(X, {flat_index}, i);
  auto inner_loop = For::make(j, start, stopInner, store);
  auto outer_loop = For::make(i, start, stopOuter, inner_loop);

  PaddedBuffer<int> data(400);
  auto run = [&]() { SimpleIREvaluator(outer_loop, {X})(data); };
  EXPECT_ANY_THROW(run());
}

testCond01()746 void testCond01() {
747   const int N = 16;
748   PaddedBuffer<float> a_v(N);
749   BufHandle a_buf("a", {N}, kFloat);
750   VarHandle index = VarHandle("index", kInt);
751   StmtPtr assign_x2 = a_buf.store({index}, cast<float>(index) * 2);
752   StmtPtr assign_x3 = a_buf.store({index}, cast<float>(index) * 3);
753   ExprHandle even_cond = CompareSelect::make(Mod::make(index, 2), 0, kEQ);
754   StmtPtr assign = Cond::make(even_cond, assign_x2, assign_x3);
755   StmtPtr for_stmt = For::make(index, 0, N, assign);
756   SimpleIREvaluator(for_stmt, {a_buf})(a_v);
757 
758   PaddedBuffer<float> a_ref(N);
759   for (const auto i : c10::irange(N)) {
760     if (i % 2 == 0) {
761       a_ref(i) = i * 2;
762     } else {
763       a_ref(i) = i * 3;
764     }
765   }
766   ExpectAllNear(a_v, a_ref, 1e-5);
767 }
768 
testIfThenElse01()769 void testIfThenElse01() {
770   ExprHandle v = ifThenElse(ExprHandle(1), ExprHandle(1.0f), ExprHandle(2.0f));
771 
772   std::ostringstream oss;
773   oss << v;
774   ASSERT_EQ(oss.str(), "IfThenElse(1, 1.f, 2.f)");
775 
776   SimpleIRExprEval eval(v);
777   ASSERT_EQ(eval.value<float>(), 1.0f);
778 }
779 
testIfThenElse02()780 void testIfThenElse02() {
781   ExprHandle v = ifThenElse(ExprHandle(0), ExprHandle(1.0f), ExprHandle(2.0f));
782 
783   std::ostringstream oss;
784   oss << v;
785   ASSERT_EQ(oss.str(), "IfThenElse(0, 1.f, 2.f)");
786 
787   SimpleIRExprEval eval(v);
788   ASSERT_EQ(eval.value<float>(), 2.0f);
789 }
790 
testIfThenElse03()791 void testIfThenElse03() {
792   ExprHandle v =
793       ifThenElse(BoolImm::make(false), ExprHandle(1.0f), ExprHandle(2.0f));
794 
795   std::ostringstream oss;
796   oss << v;
797   ASSERT_EQ(oss.str(), "IfThenElse(0, 1.f, 2.f)");
798 
799   SimpleIRExprEval eval(v);
800   ASSERT_EQ(eval.value<float>(), 2.0f);
801 }
802 
testStmtClone()803 void testStmtClone() {
804   const int N = 16;
805 
806   BufHandle a_buf("a", {N}, kInt);
807   VarHandle index = VarHandle("index", kInt);
808   StmtPtr body = a_buf.store({index}, 5);
809   StmtPtr loop = For::make(index, 0, N, body);
810 
811   StmtPtr cloned_loop = Stmt::clone(loop);
812   std::vector<int> orig_loop_results(N);
813   std::vector<int> cloned_loop_results(N);
814   SimpleIREvaluator(loop, {a_buf})(orig_loop_results);
815   SimpleIREvaluator(cloned_loop, {a_buf})(cloned_loop_results);
816 
817   assertAllEqual(orig_loop_results, 5);
818   assertAllEqual(cloned_loop_results, 5);
819 
820   // Let's add another assign to the body in the cloned loop and verify that the
821   // original statement hasn't changed while the cloned one has.
822   StmtPtr body_addition = a_buf.store({index}, 33);
823   BlockPtr cloned_body = static_to<Block>(static_to<For>(cloned_loop)->body());
824   cloned_body->append_stmt(body_addition);
825 
826   std::vector<int> orig_loop_results_after_mutation(N);
827   std::vector<int> cloned_loop_results_after_mutation(N);
828   SimpleIREvaluator(loop, {a_buf})(orig_loop_results_after_mutation);
829   SimpleIREvaluator(cloned_loop, {a_buf})(cloned_loop_results_after_mutation);
830 
831   assertAllEqual(orig_loop_results_after_mutation, 5);
832   assertAllEqual(cloned_loop_results_after_mutation, 33);
833 }
834 
835 } // namespace jit
836 } // namespace torch
837