#include <gtest/gtest.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/operators/conv2d.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/torch.h>

namespace torch {
namespace jit {

namespace te = torch::jit::tensorexpr;
namespace F = torch::nn::functional;

#ifdef TORCH_ENABLE_LLVM

// Generate test data with few bits of precision, to minimize error
// accumulation from floating-point reordering.
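// Truncating to a 1/256 grid keeps many low-order mantissa bits zero, so the
// partial sums are less sensitive to the order in which terms are added.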
static at::Tensor genTestData(c10::IntArrayRef args) {
  return at::trunc(at::randn(args) * 256.0f) / 256.0f;
}

TEST(Conv, DepthwiseConv2D) {
  constexpr int N = 1, C = 72, H = 56, W = 56;
  constexpr int K = 72, R = 3, S = 3;
  constexpr int kPad = 1, kStride = 2, kGroups = C;
  constexpr int CperG = C / kGroups;
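  // With kGroups == C this is a fully depthwise convolution (one input channel
  // per group); the output spatial size is (56 + 2*1 - 3) / 2 + 1 = 28.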

  te::BufHandle input("input", {N, C, H, W}, te::kFloat);
  te::BufHandle weight("weight", {K, CperG, R, S}, te::kFloat);
  te::BufHandle bias("bias", {K}, te::kFloat);
  te::Tensor output =
      te::conv2d_depthwise(input, weight, bias, kStride, kPad, kGroups);

  te::LoopNest loop({output});
  loop.simplify();
  loop.prepareForCodegen();
  te::LLVMCodeGen cg(loop.root_stmt(), {input, weight, bias, output});

  auto it = genTestData({N, C, H, W});
  auto wt = genTestData({K, CperG, R, S});
  auto bt = genTestData({K});
  auto ref = at::conv2d(it, wt, bt, kStride, kPad, /*dilation=*/1, kGroups);
  auto ot = at::zeros_like(ref);
  cg.call(
      {it.data_ptr<float>(),
       wt.data_ptr<float>(),
       bt.data_ptr<float>(),
       ot.data_ptr<float>()});

  ASSERT_TRUE(at::allclose(ref, ot));
}

TEST(Conv, DepthwiseConv2DNoBias) {
  constexpr int N = 1, C = 72, H = 56, W = 56;
  constexpr int K = 72, R = 3, S = 3;
  constexpr int kPad = 1, kStride = 2, kGroups = C;
  constexpr int CperG = C / kGroups;

  te::BufHandle input("input", {N, C, H, W}, te::kFloat);
  te::BufHandle weight("weight", {K, CperG, R, S}, te::kFloat);
  te::Tensor output =
      te::conv2d_depthwise(input, weight, kStride, kPad, kGroups);

  te::LoopNest loop({output});
  loop.simplify();
  loop.prepareForCodegen();
  te::LLVMCodeGen cg(loop.root_stmt(), {input, weight, output});

  auto it = genTestData({N, C, H, W});
  auto wt = genTestData({K, CperG, R, S});
  auto ref =
      at::conv2d(it, wt, at::Tensor(), kStride, kPad, /*dilation=*/1, kGroups);
  auto ot = at::zeros_like(ref);
  cg.call({it.data_ptr<float>(), wt.data_ptr<float>(), ot.data_ptr<float>()});

  ASSERT_TRUE(at::allclose(ref, ot));
}

TEST(Conv, DepthwiseConv2DDynamicShapes) {
  te::VarHandle N_var("N", te::kInt);
  te::VarHandle C_var("C", te::kInt);
  te::VarHandle H_var("H", te::kInt);
  te::VarHandle W_var("W", te::kInt);
  te::VarHandle K_var("K", te::kInt);
  te::VarHandle CperG_var("CperG", te::kInt);
  te::VarHandle R_var("R", te::kInt);
  te::VarHandle S_var("S", te::kInt);
  te::VarHandle kPad_var("kPad", te::kInt);
  te::VarHandle kStride_var("kStride", te::kInt);
  te::VarHandle kGroups_var("kGroups", te::kInt);
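  // Every dimension and convolution parameter is symbolic here, so a single
  // compiled kernel can be reused for any concrete shape; the values are bound
  // at call time through the CallArgs below.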

  te::BufHandle input("input", {N_var, C_var, H_var, W_var}, te::kFloat);
  te::BufHandle weight("weight", {K_var, CperG_var, R_var, S_var}, te::kFloat);
  te::Tensor output = te::conv2d_depthwise(
      input,
      weight,
      N_var,
      C_var,
      H_var,
      W_var,
      K_var,
      CperG_var,
      R_var,
      S_var,
      kStride_var,
      kPad_var,
      kGroups_var);

  te::LoopNest loop({output});
  loop.simplify();
  loop.prepareForCodegen();
  std::vector<te::CodeGen::BufferArg> buffer_args = {
      input,
      weight,
      N_var,
      C_var,
      H_var,
      W_var,
      K_var,
      CperG_var,
      R_var,
      S_var,
      kPad_var,
      kStride_var,
      kGroups_var,
      output};
  te::LLVMCodeGen cg(loop.root_stmt(), buffer_args);

  constexpr int N = 1, C = 72, H = 56, W = 56;
  constexpr int K = 72, R = 3, S = 3;
  constexpr int kPad = 1, kStride = 2, kGroups = C;
  constexpr int CperG = C / kGroups;

  auto it = genTestData({N, C, H, W});
  auto wt = genTestData({K, CperG, R, S});
  auto ref =
      at::conv2d(it, wt, at::Tensor(), kStride, kPad, /*dilation=*/1, kGroups);
  auto ot = at::zeros_like(ref);
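  // The call arguments must follow the same order as buffer_args above: the
  // two input buffers, the eleven integer shape/parameter values, then the
  // output buffer.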
  std::vector<te::CodeGen::CallArg> call_args = {
      it.data_ptr<float>(),
      wt.data_ptr<float>(),
      N,
      C,
      H,
      W,
      K,
      CperG,
      R,
      S,
      kPad,
      kStride,
      kGroups,
      ot.data_ptr<float>()};
  cg.call(call_args);

  ASSERT_TRUE(at::allclose(ref, ot));
}

#endif

TEST(Conv, Conv2D) {
  // Input dimensions.
  constexpr int N = 1;
  constexpr int C = 3;
  constexpr int H = 11;
  constexpr int W = 11;

  // Filter dimensions.
  constexpr int K = 8;
  constexpr int R = 3;
  constexpr int S = 3;

  // Output dims.
  constexpr int OH = H - R + 1;
  constexpr int OW = W - S + 1;
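  // (No padding and unit stride, so this is a "valid" convolution: 9x9 here.)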

  // Compute reference result.
  at::Tensor input = torch::randn({N, C, H, W});
  at::Tensor filter = torch::randn({K, C, R, S});
  at::Tensor ref = F::conv2d(input, filter);

  // Double check the output size is as expected.
  ASSERT_EQ(ref.size(0), N);
  ASSERT_EQ(ref.size(1), K);
  ASSERT_EQ(ref.size(2), OH);
  ASSERT_EQ(ref.size(3), OW);

  te::BufHandle inputB("input", {N, C, H, W}, te::kFloat);
  te::BufHandle filterB("filter", {K, C, R, S}, te::kFloat);

  te::Tensor conv = te::Reduce(
      "conv",
      {N, K, OH, OW},
      te::Sum(),
      // FIXME: We have to use a `std::vector` parameter here and then unpack
      // it, because we don't have an overload allowing for an arbitrary number
      // of ExprHandle/VarHandle parameters.
      [&](const std::vector<te::VarHandle>& v) {
        auto const& n = v[0];
        auto const& k = v[1];
        auto const& oh = v[2];
        auto const& ow = v[3];
        auto const& c = v[4];
        auto const& r = v[5];
        auto const& s = v[6];
        // FIXME: We have to use `call` and construct a `std::vector` here
        // because the `operator()` overload is only specialized for a small
        // number of arguments.
        return inputB.load(n, c, oh + r, ow + s) * filterB.load(k, c, r, s);
      },
      // FIXME: If you forget one of the reduction dims, you get a segfault.
      // Could that be caught by a verifier?
      {C, R, S});
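  // The reduction above computes, for each output element,
  //   conv[n, k, oh, ow] = sum over (c, r, s) of
  //       input[n, c, oh + r, ow + s] * filter[k, c, r, s],
  // i.e. a stride-1, unpadded 2D convolution matching F::conv2d's defaults.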

  // FIXME: It'd be nice to have a single header that pulls in things like
  // LoopNest, IRSimplifier, etc.
  te::LoopNest loop({conv});
  loop.prepareForCodegen();
  te::StmtPtr s = loop.root_stmt();
  s = te::IRSimplifier::simplify(s);

  at::Tensor result = at::empty_like(ref);
  te::SimpleIREvaluator cg(s, {inputB, filterB, conv});
  cg.call(
      {input.data_ptr<float>(),
       filter.data_ptr<float>(),
       result.data_ptr<float>()});

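  // Inputs here come from randn rather than genTestData, and the reference and
  // generated kernels may accumulate in different orders, hence the looser
  // tolerances.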
  ASSERT_TRUE(at::allclose(ref, result, 1e-3, 1e-3));
}

} // namespace jit
} // namespace torch