xref: /aosp_15_r20/external/pytorch/test/cpp/tensorexpr/test_cpp_codegen.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include "test/cpp/tensorexpr/test_base.h"
4 
5 #include <c10/util/irange.h>
6 #include <torch/csrc/jit/tensorexpr/cpp_codegen.h>
7 #include <torch/csrc/jit/tensorexpr/fwd_decls.h>
8 #include <torch/csrc/jit/tensorexpr/stmt.h>
9 #include <torch/csrc/jit/tensorexpr/tensor.h>
10 #include <torch/csrc/jit/testing/file_check.h>
11 
12 namespace torch {
13 namespace jit {
14 
15 using namespace torch::jit::tensorexpr;
16 
// STR_CHECK(node, expected): prints `node` with a fresh CppPrinter and asserts
// that the generated C++ text equals `expected` exactly.
//
// The body is wrapped in `do { ... } while (0)` so its locals (`ss`, `printer`)
// live in their own scope. The macro can therefore be used more than once in
// the same test body, and it expands to a single statement (safe after an
// unbraced `if`/`else`). ASSERT_EQ returns from the enclosing test on mismatch.
#define STR_CHECK(node, expected)  \
  do {                             \
    std::stringstream ss;          \
    CppPrinter printer(&ss);       \
    printer.visit(node);           \
    ASSERT_EQ(ss.str(), expected); \
  } while (0)

// FILE_CHECK(node, pattern): prints `node` with a fresh CppPrinter and runs the
// FileCheck directives in `pattern` (`# CHECK: ...` lines) against the
// generated text. Scoped with do/while(0) for the same reasons as STR_CHECK.
#define FILE_CHECK(node, pattern)                            \
  do {                                                       \
    std::stringstream ss;                                    \
    CppPrinter printer(&ss);                                 \
    printer.visit(node);                                     \
    torch::jit::testing::FileCheck().run(pattern, ss.str()); \
  } while (0)
28 
// An int immediate is printed as a bare decimal literal.
TEST(CppPrinter, IntImm) {
  STR_CHECK(alloc<IntImm>(10), "10");
}
33 
// A float immediate with an integral value is printed with a ".f" suffix,
// which keeps the emitted literal a C++ float.
TEST(CppPrinter, FloatImm) {
  auto ten = alloc<FloatImm>(10);
  STR_CHECK(ten, "10.f");
}
38 
// A float immediate with a fractional value is expected to print with a plain
// "f" suffix (no extra '.'), complementing the integral case in FloatImm.
// A non-integral value is used so this test does not duplicate FloatImm.
// 2.5 is exactly representable as a float, so its decimal form is stable.
// NOTE(review): the fractional-float formatting is inferred from the printer's
// float-suffix logic, not shown in this file — confirm if this fails.
TEST(CppPrinter, FloatImm1) {
  auto f = alloc<FloatImm>(2.5f);
  STR_CHECK(f, "2.5f");
}
43 
// An integral double immediate prints with a ".0" suffix and no 'f'.
TEST(CppPrinter, DoubleImm) {
  auto ten = alloc<DoubleImm>(10);
  STR_CHECK(ten, "10.0");
}
48 
// A non-integral double immediate prints with no suffix at all.
TEST(CppPrinter, DoubleImm1) {
  STR_CHECK(alloc<DoubleImm>(10.1), "10.1");
}
53 
// Half immediates print without any floating-point suffix.
TEST(CppPrinter, HalfImm) {
  auto half_ten = alloc<HalfImm>(10);
  STR_CHECK(half_ten, "10");
}
58 
// An addition of two leaves prints without any parentheses.
TEST(CppPrinter, Add) {
  auto lhs = alloc<IntImm>(1);
  auto rhs = alloc<IntImm>(2);
  auto sum = alloc<Add>(lhs, rhs);
  STR_CHECK(sum, "1 + 2");
}
63 
// Additive sub-expressions (+ and -) used as operands of an addition are
// wrapped in parentheses.
TEST(CppPrinter, AddExpr1) {
  auto left = alloc<Add>(alloc<IntImm>(0), alloc<IntImm>(1));
  auto right = alloc<Sub>(alloc<IntImm>(2), alloc<IntImm>(3));
  STR_CHECK(alloc<Add>(left, right), "(0 + 1) + (2 - 3)");
}
70 
// A multiplication operand of an addition is left bare, while a subtraction
// operand is still parenthesized.
TEST(CppPrinter, AddExpr2) {
  auto product = alloc<Mul>(alloc<IntImm>(0), alloc<IntImm>(1));
  auto difference = alloc<Sub>(alloc<IntImm>(2), alloc<IntImm>(3));
  auto expr = alloc<Add>(product, difference);
  STR_CHECK(expr, "0 * 1 + (2 - 3)");
}
77 
// A division operand of an addition is left bare; a nested addition operand
// is parenthesized.
TEST(CppPrinter, AddExpr3) {
  auto inner_sum = alloc<Add>(alloc<IntImm>(0), alloc<IntImm>(1));
  auto quotient = alloc<Div>(alloc<IntImm>(2), alloc<IntImm>(3));
  STR_CHECK(alloc<Add>(inner_sum, quotient), "(0 + 1) + 2 / 3");
}
84 
// Integer modulus is emitted with the native `%` operator.
TEST(CppPrinter, Mod) {
  auto dividend = alloc<IntImm>(1);
  auto divisor = alloc<IntImm>(2);
  STR_CHECK(alloc<Mod>(dividend, divisor), "1 % 2");
}
89 
// Floating-point modulus cannot use `%`, so it is emitted as std::fmod.
TEST(CppPrinter, ModFloat) {
  auto remainder = alloc<Mod>(alloc<FloatImm>(1), alloc<FloatImm>(2));
  STR_CHECK(remainder, "std::fmod(1.f, 2.f)");
}
94 
// Integer Max lowers to std::max.
TEST(CppPrinter, Max) {
  auto a = alloc<IntImm>(1);
  auto b = alloc<IntImm>(2);
  STR_CHECK(alloc<Max>(a, b, false), "std::max(1, 2)");
}
99 
// Float Max also lowers to std::max, with float-suffixed literal operands.
TEST(CppPrinter, MaxFloat) {
  auto larger = alloc<Max>(alloc<FloatImm>(1), alloc<FloatImm>(2), false);
  STR_CHECK(larger, "std::max(1.f, 2.f)");
}
104 
// Half Max is not emitted via std::max; it is spelled out as a conditional.
TEST(CppPrinter, MaxHalf) {
  auto a = alloc<HalfImm>(1);
  auto b = alloc<HalfImm>(2);
  auto larger = alloc<Max>(a, b, false);
  STR_CHECK(larger, "(1 < 2) ? 2 : 1");
}
109 
// The And node prints as the bitwise `&` operator.
TEST(CppPrinter, And) {
  STR_CHECK(alloc<And>(alloc<IntImm>(1), alloc<IntImm>(2)), "1 & 2");
}
114 
// CompareSelect prints as a fully parenthesized ternary: the comparison of
// the first two operands picks between the last two.
TEST(CppPrinter, CompareSelect) {
  auto lhs = alloc<IntImm>(1);
  auto rhs = alloc<IntImm>(2);
  auto if_true = alloc<FloatImm>(1);
  auto if_false = alloc<FloatImm>(2);
  auto select = alloc<CompareSelect>(
      lhs, rhs, if_true, if_false, CompareSelectOperation::kLE);
  STR_CHECK(select, "((1 <= 2) ? 1.f : 2.f)");
}
124 
// IfThenElse prints as an outer-parenthesized ternary. A compound condition is
// parenthesized, while the branch expressions are printed bare.
TEST(CppPrinter, IfThenElse) {
  auto predicate = alloc<Add>(alloc<IntImm>(1), alloc<IntImm>(2));
  auto then_expr = alloc<Sub>(alloc<IntImm>(0), alloc<IntImm>(1));
  auto else_expr = alloc<Mul>(alloc<IntImm>(2), alloc<IntImm>(3));
  STR_CHECK(
      alloc<IfThenElse>(predicate, then_expr, else_expr),
      "((1 + 2) ? 0 - 1 : 2 * 3)");
}
132 
// An Allocate/Free pair inside a Block becomes a typed malloc'd pointer that
// is later released with free(). A 2x3 int buffer requests 24 bytes
// (assumes a 4-byte int on the printer's side).
TEST(CppPrinter, AllocateFree) {
  BufHandle x_buf("x", {2, 3}, kInt);
  // Named so they do not shadow the `alloc<>` helper or ::free.
  AllocatePtr allocate_x = Allocate::make(x_buf);
  FreePtr free_x = Free::make(x_buf);
  BlockPtr scope = Block::make({allocate_x, free_x});

  const std::string expected_ir = R"(
   # CHECK: {
   # CHECK:   int* x = static_cast<int*>(malloc(24));
   # CHECK:   free(x);
   # CHECK: }
  )";
  FILE_CHECK(scope, expected_ir);
}
147 
// Multi-dimensional loads/stores are flattened to a row-major linear index,
// and the printer emits that index expression unsimplified.
TEST(CppPrinter, LoadStore) {
  BufHandle src("A", {2, 3}, kInt);
  BufHandle dst("B", {3, 4}, kInt);
  auto value = src.load(1, 1);
  auto assignment = dst.store({2, 2}, value);
  STR_CHECK(
      assignment,
      "B[(0 + 2 * (1 * 4)) + 2 * 1] = A[(0 + 1 * (1 * 3)) + 1 * 1];\n");
}
155 
// A variable reference prints as its name hint.
TEST(CppPrinter, Var) {
  STR_CHECK(alloc<Var>("x", kInt), "x");
}
160 
// A value-converting Cast is emitted as a C++ static_cast to the target type.
TEST(CppPrinter, Cast) {
  auto operand = alloc<IntImm>(1);
  auto to_float = alloc<Cast>(kFloat, operand);
  STR_CHECK(to_float, "static_cast<float>(1)");
}
165 
// A BitCast reinterprets the operand's bits. The printer lists the source type
// first and the destination type second in the template arguments.
// NOTE(review): `std::bitcast` is not a standard-library function (C++20 has
// `std::bit_cast<To>(from)` with one template argument), so the output pinned
// here would not compile as standard C++ — confirm what the printer should
// emit.
TEST(CppPrinter, BitCast) {
  auto bits = alloc<FloatImm>(20);
  auto reinterpret = alloc<BitCast>(kInt, bits);
  STR_CHECK(reinterpret, "std::bitcast<float, int>(20.f)");
}
170 
// A Let statement becomes a typed local declaration with an initializer.
TEST(CppPrinter, Let) {
  auto x = alloc<Var>("x", kFloat);
  STR_CHECK(alloc<Let>(x, alloc<FloatImm>(2)), "float x = 2.f;\n");
}
177 
// A For loop over an element-wise add prints as a canonical counted C++ loop.
// Load operands of the addition are parenthesized.
TEST(CppPrinter, For) {
  constexpr int kExtent = 1024;
  BufHandle lhs("A", {kExtent}, kInt);
  BufHandle rhs("B", {kExtent}, kInt);
  BufHandle out("C", {kExtent}, kInt);
  VarHandle idx("i", kInt);

  auto body = out.store({idx}, Add::make(lhs.load(idx), rhs.load(idx)));
  auto loop = For::make(idx, 0, kExtent, body);

  const std::string expected_ir = R"(
   # CHECK: for (int i = 0; i < 1024; i++) {
   # CHECK:   C[i] = (A[i]) + (B[i]);
   # CHECK: }
  )";
  FILE_CHECK(loop, expected_ir);
}
192 
// A Cond statement becomes an if/else. The CompareSelect condition is printed
// as a 1/0 ternary inside the `if`.
TEST(CppPrinter, Cond) {
  BufHandle x("X", {1}, kInt);
  auto is_below_ten =
      CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT);
  auto increment = x.store({0}, x.load(0) + 1);
  auto decrement = x.store({0}, x.load(0) - 1);
  auto branch = Cond::make(is_below_ten, increment, decrement);

  const std::string expected_ir = R"(
    # CHECK: if (((X[0] < 10) ? 1 : 0)) {
    # CHECK:   X[0] = (X[0]) + 1;
    # CHECK: } else {
    # CHECK:   X[0] = (X[0]) - 1;
    # CHECK: }
  )";
  FILE_CHECK(branch, expected_ir);
}
207 
// Checks that each intrinsic (except the skipped ones) prints as
// `std::<func_name>(...)` applied to float-literal arguments.
TEST(CppPrinter, Intrinsics) {
  // Excluded from the std:: round-trip check (presumably because the printer
  // has no direct std:: spelling for them — confirm in cpp_codegen).
  const std::unordered_set<IntrinsicsOp, std::hash<int>> skipped_ops{
      kRand, kSigmoid};
  const auto op_count = static_cast<uint32_t>(kMaxIntrinsicsOp);
  for (const auto raw_op : c10::irange(op_count)) {
    const auto op = static_cast<IntrinsicsOp>(raw_op);
    if (skipped_ops.count(op) != 0) {
      continue;
    }

    // Ops reporting one argument get a single operand; all others get two.
    const bool unary = Intrinsics::OpArgCount(op) == 1;
    auto call = unary
        ? alloc<Intrinsics>(op, alloc<FloatImm>(2.0f))
        : alloc<Intrinsics>(op, alloc<FloatImm>(1.0f), alloc<FloatImm>(2.0f));
    const std::string printed_args = unary ? "(2.f)" : "(1.f, 2.f)";
    STR_CHECK(call, "std::" + call->func_name() + printed_args);
  }
}
227 
// An ExternalCall is printed as a block that gathers the output and input
// buffers' pointers, ranks, flattened dims and dtypes into local arrays, plus
// the scalar arguments into `extra_args`. It then calls the named function
// with the buffer count and scalar-arg count.
// The dtype code 6 is presumably the scalar-type id for float — confirm.
TEST(CppPrinter, ExternalCall) {
  std::vector<ExprPtr> shape{alloc<IntImm>(2), alloc<IntImm>(2)};
  auto out_buf = alloc<Buf>("out", shape, kFloat);
  auto lhs_buf = alloc<Buf>("a", shape, kFloat);
  auto rhs_buf = alloc<Buf>("b", shape, kFloat);
  std::vector<BufPtr> inputs{lhs_buf, rhs_buf};
  std::vector<ExprPtr> scalars{
      alloc<Add>(alloc<IntImm>(1), alloc<IntImm>(2))};
  auto matmul_call =
      alloc<ExternalCall>(out_buf, "nnc_aten_matmul", inputs, scalars);

  const std::string expected_ir = R"(
   # CHECK: {
   # CHECK:   void* buf_ptrs[]{out, a, b};
   # CHECK:   int64_t buf_ranks[]{2, 2, 2};
   # CHECK:   int64_t buf_dims[]{2, 2, 2, 2, 2, 2};
   # CHECK:   int8_t buf_dtypes[]{6, 6, 6};
   # CHECK:   int64_t extra_args[]{1 + 2};
   # CHECK:   nnc_aten_matmul(
   # CHECK:       3,
   # CHECK:       buf_ptrs,
   # CHECK:       buf_ranks,
   # CHECK:       buf_dims,
   # CHECK:       buf_dtypes,
   # CHECK:       1,
   # CHECK:       extra_args);
   # CHECK: }
  )";
  FILE_CHECK(matmul_call, expected_ir);
}
257 
258 } // namespace jit
259 } // namespace torch
260