1 #include <gtest/gtest.h>
2
3 #include <stdexcept>
4 #include "test/cpp/tensorexpr/test_base.h"
5
6 #include <torch/csrc/jit/tensorexpr/expr.h>
7 #include <torch/csrc/jit/tensorexpr/ir.h>
8 #include <torch/csrc/jit/tensorexpr/ir_printer.h>
9 #include <torch/csrc/jit/tensorexpr/loopnest.h>
10 #include <torch/csrc/jit/tensorexpr/tensor.h>
11 #include <torch/csrc/jit/testing/file_check.h>
12
13 #include <sstream>
14 namespace torch {
15 namespace jit {
16
17 using namespace torch::jit::tensorexpr;
18
// Printing an integer Add node renders the two immediates around " + ".
TEST(IRPrinter, BasicValueTest) {
  const ExprHandle lhs = IntImm::make(2);
  const ExprHandle rhs = IntImm::make(3);
  const ExprHandle sum = Add::make(lhs, rhs);

  std::ostringstream printed;
  printed << sum;
  const std::string text = printed.str();
  ASSERT_EQ(text, "2 + 3");
}
27
// A float Sub of two Add nodes: the expected output shows each Add operand
// wrapped in parentheses and float immediates printed with an "f" suffix.
TEST(IRPrinter, BasicValueTest02) {
  ExprHandle two(2.0f);
  ExprHandle three(3.0f);
  ExprHandle four(4.0f);
  ExprHandle five(5.0f);

  ExprHandle left_sum = two + three;
  ExprHandle right_sum = four + five;
  ExprHandle difference = left_sum - right_sum;

  std::ostringstream out;
  out << difference;
  ASSERT_EQ(out.str(), "(2.f + 3.f) - (4.f + 5.f)");
}
39
// A kHalf variable cast to kFloat prints as "float(x)"; the expected string
// also pins how the printer parenthesizes the nested sum on the right.
TEST(IRPrinter, CastTest) {
  VarHandle half_var("x", kHalf);
  VarHandle float_var("y", kFloat);

  ExprHandle scaled_cast = Cast::make(kFloat, half_var) * ExprHandle(3.f);
  ExprHandle scaled_var = ExprHandle(4.f) * float_var;
  ExprHandle body = ExprHandle(2.f) + (scaled_cast + scaled_var);

  std::ostringstream out;
  out << body;
  ASSERT_EQ(out.str(), "2.f + (float(x) * 3.f + 4.f * y)");
}
50
// Checks the names that LoopNest::sanitizeNames assigns to loop variables:
// after sanitizing a nest built from chunk_0, chunk_1 and consumer, the
// printed consumer loops are expected to use the indices i_2 / j_2.
TEST(IRPrinter, FunctionName) {
  // Fixed problem sizes; constexpr since they never change.
  constexpr int M = 4;
  constexpr int N = 20;

  // producer(m, n) = m * n over an M x N domain.
  Tensor producer = Compute(
      "producer", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
        return m * n;
      });

  // chunk_0 reads the first N / 2 columns of producer.
  Tensor chunk_0 = Compute(
      "chunk_0", {M, N / 2}, [&](const ExprHandle& m, const ExprHandle& n) {
        return producer.load(m, n);
      });

  // chunk_1 reads the remaining columns (second index offset by N / 2).
  Tensor chunk_1 = Compute(
      "chunk_1", {M, N / 2}, [&](const ExprHandle& m, const ExprHandle& n) {
        return producer.load(m, n + ExprHandle(N / 2));
      });

  // consumer multiplies each chunk_1 element by its first index.
  Tensor consumer = Compute(
      "consumer", {M, N / 2}, [&](const ExprHandle& i, const ExprHandle& j) {
        return i * chunk_1.load(i, j);
      });

  LoopNest l({chunk_0, chunk_1, consumer});
  auto body = LoopNest::sanitizeNames(l.root_stmt());

  std::stringstream ss;
  ss << *body;

  // Owned string rather than a const reference bound to a temporary.
  // NOTE(review): the "_2" suffix presumably reflects consumer being the third
  // loop nest whose index names sanitizeNames disambiguated — confirm.
  const std::string verification_pattern =
      R"IR(
# CHECK: for (int i_2
# CHECK: for (int j_2
# CHECK: consumer[i_2, j_2] = i_2 * (chunk_1[i_2, j_2])IR";

  torch::jit::testing::FileCheck().run(verification_pattern, ss.str());
}
89 } // namespace jit
90 } // namespace torch
91