// xref: /aosp_15_r20/external/pytorch/benchmarks/static_runtime/test_cpu_fusion.cc (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <gtest/gtest.h>

#include <thread>

#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/torch.h>

#include "test_utils.h"

using namespace torch;
using namespace torch::jit;
using namespace torch::jit::test;
// Verifies that a simple pointwise chain (add -> relu -> tanh) produces
// correct results when TensorExpr fusion is enabled in the static runtime,
// both for the sample inputs used at module-build time and for inputs of a
// different shape (which exercises shape re-specialization in the fused
// kernel path).
TEST(CpuFusion, Simple) {
  const auto simple_script = R"JIT(
    def forward(self, a, b):
        return (a + b).relu().tanh()
  )JIT";

  Module m("module");
  m.define(simple_script);

  StaticModuleOptions opts; // start with the defaults.
  opts.enable_tensorexpr_fusion = true;

  auto input1 = at::randn({2, 3});
  auto input2 = at::ones({2, 3});

  // Sample inputs guide fusion-group construction inside StaticModule.
  auto smodule = StaticModule(m, /* is_frozen */ false, opts, {input1, input2});
  StaticRuntime runtime(smodule);

  // Test with the same sample inputs used to build the module.
  {
    auto actual = runtime({input1, input2}, {});
    // Reference result computed eagerly, without fusion.
    auto expect = at::tanh(at::relu(input1 + input2));
    EXPECT_TRUE(at::allclose(expect, actual.toTensor()));
  }

  // Test with inputs of a different shape than the samples.
  {
    auto new_input1 = at::randn({5, 14});
    auto new_input2 = at::randn({5, 14});
    auto actual = runtime({new_input1, new_input2}, {});
    auto expect = at::tanh(at::relu(new_input1 + new_input2));
    EXPECT_TRUE(at::allclose(expect, actual.toTensor()));
  }
}
46 
// Verifies that the static runtime falls back to the unfused graph when the
// actual inputs do not match the assumptions baked into the fused kernel.
// The module is built with contiguous sample inputs; the test then feeds
// strided (non-contiguous) tensors, which must be handled by the fallback
// path while still producing correct results.
TEST(CpuFusion, FallbackGraph) {
  const auto simple_script = R"JIT(
    def forward(self, a, b):
        return (a + b).relu().tanh()
  )JIT";

  Module m("module");
  m.define(simple_script);

  StaticModuleOptions opts; // start with the defaults.
  opts.enable_tensorexpr_fusion = true;

  // Contiguous sample inputs: the fused kernel is specialized for these.
  auto sample_input1 = at::randn({2, 3});
  auto sample_input2 = at::ones({2, 3});
  auto smodule = StaticModule(
      m, /* is_frozen */ false, opts, {sample_input1, sample_input2});

  StaticRuntime runtime(smodule);

  // The sample inputs above were contiguous. Now, use a strided input
  // (narrow produces a non-contiguous view) to trigger running the
  // fallback graph.
  {
    auto input1 = at::narrow(at::randn({2, 6}), 1, 0, 3);
    auto input2 = at::ones({2, 3});
    auto expect = at::tanh(at::relu(input1 + input2));
    auto actual = runtime({input1, input2}, {});
    EXPECT_TRUE(at::allclose(expect, actual.toTensor()));
  }

  // Test with strided inputs of a different size than the samples.
  {
    auto input1 = at::narrow(at::randn({10, 30}), 1, 0, 25);
    auto input2 = at::randn({10, 25});
    auto expect = at::tanh(at::relu(input1 + input2));
    auto actual = runtime({input1, input2}, {});
    EXPECT_TRUE(at::allclose(expect, actual.toTensor()));
  }
}
85 
// Verifies that multiple StaticRuntime instances created from one shared
// StaticModule can execute concurrently on separate threads, each with its
// own stream of differently-shaped inputs, and all produce correct results.
TEST(CpuFusion, ParallelRuntimes) {
  const auto simple_script = R"JIT(
    def forward(self, a, b):
        return (a + b).relu().tanh()
  )JIT";

  Module m("module");
  m.define(simple_script);

  StaticModuleOptions opts; // start with the defaults.
  opts.enable_tensorexpr_fusion = true;

  auto sample_input1 = at::randn({2, 3});
  auto sample_input2 = at::ones({2, 3});
  // One StaticModule is shared by all runtimes below.
  auto smodule = StaticModule(
      m, /* is_frozen */ false, opts, {sample_input1, sample_input2});

  constexpr size_t kNumThreads = 2;
  // Per-thread list of (rows, cols) tensor shapes to run through the model.
  std::vector<std::vector<std::pair<int, int>>> all_inputs;
  for (size_t id = 0; id < kNumThreads; ++id) {
    // Use an int base to avoid a narrowing conversion (size_t -> int) in
    // the braced initializers below.
    const int base = static_cast<int>(id);
    std::vector<std::pair<int, int>> thread_input = {
        {base, base + 1},
        {base + 10, base + 11},
        {base + 20, base + 21},
        {base + 30, base + 31},
        {base + 40, base + 41},
        {base + 50, base + 51},
        {base + 60, base + 61},
        {base + 70, base + 71}};
    all_inputs.emplace_back(std::move(thread_input));
  }

  // Each thread builds its own StaticRuntime from the shared module and
  // checks fused execution against an eagerly computed reference.
  auto exec_runtime = [&](size_t tid) {
    const auto& inputs = all_inputs[tid];
    StaticRuntime runtime(smodule);
    for (const auto& inp : inputs) {
      auto a = at::randn({inp.first, inp.second});
      auto b = at::randn({inp.first, inp.second});
      auto expect = at::tanh(at::relu(a + b));
      auto actual = runtime({a, b}, {});
      EXPECT_TRUE(at::allclose(expect, actual.toTensor()));
    }
  };

  std::vector<std::thread> threads;
  for (size_t id = 0; id < kNumThreads; ++id) {
    threads.emplace_back(exec_runtime, id);
  }

  for (auto& t : threads) {
    t.join();
  }
}
139