#include <gtest/gtest.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/torch.h>
#include <thread>

#include "test_utils.h"

using namespace torch;
using namespace torch::jit;
using namespace torch::jit::test;

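// Checks the basic fusion path: a simple pointwise chain (add -> relu -> tanh)
// should produce the same results with TensorExpr fusion enabled as in eager mode.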
TEST(CpuFusion, Simple) {
  const auto simple_script = R"JIT(
    def forward(self, a, b):
        return (a + b).relu().tanh()
  )JIT";

  Module m("module");
  m.define(simple_script);

  StaticModuleOptions opts; // start with the defaults.
  opts.enable_tensorexpr_fusion = true;

  auto input1 = at::randn({2, 3});
  auto input2 = at::ones({2, 3});

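  // The sample inputs passed to StaticModule drive fusion: the TensorExpr
  // kernel is specialized to the properties observed here (e.g. dtype and
  // memory layout), so the choice of samples determines the fused fast path.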
  auto smodule = StaticModule(m, /* is_frozen */ false, opts, {input1, input2});
  StaticRuntime runtime(smodule);

  // Test with the sample inputs.
  {
    auto actual = runtime({input1, input2}, {});
    auto expect = at::tanh(at::relu(input1 + input2));
    EXPECT_TRUE(at::allclose(expect, actual.toTensor()));
  }

  // Test with inputs of a different shape.
  {
    auto new_input1 = at::randn({5, 14});
    auto new_input2 = at::randn({5, 14});
    auto actual = runtime({new_input1, new_input2}, {});
    auto expect = at::tanh(at::relu(new_input1 + new_input2));
    EXPECT_TRUE(at::allclose(expect, actual.toTensor()));
  }
}

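// Checks that inputs whose strides differ from the compile-time sample inputs
// take the fallback (unfused) graph and still produce correct results.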
TEST(CpuFusion, FallbackGraph) {
  const auto simple_script = R"JIT(
    def forward(self, a, b):
        return (a + b).relu().tanh()
  )JIT";

  Module m("module");
  m.define(simple_script);

  StaticModuleOptions opts; // start with the defaults.
  opts.enable_tensorexpr_fusion = true;

  auto sample_input1 = at::randn({2, 3});
  auto sample_input2 = at::ones({2, 3});
  auto smodule = StaticModule(
      m, /* is_frozen */ false, opts, {sample_input1, sample_input2});

  StaticRuntime runtime(smodule);

  // The sample inputs above were contiguous. Now, use a strided input
  // to trigger running the fallback graph.
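  // (at::narrow returns a view, so narrowing dim 1 below produces a
  // non-contiguous tensor.)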
  {
    auto input1 = at::narrow(at::randn({2, 6}), 1, 0, 3);
    auto input2 = at::ones({2, 3});
    auto expect = at::tanh(at::relu(input1 + input2));
    auto actual = runtime({input1, input2}, {});
    EXPECT_TRUE(at::allclose(expect, actual.toTensor()));
  }

  // Test with strided inputs of a different size.
  {
    auto input1 = at::narrow(at::randn({10, 30}), 1, 0, 25);
    auto input2 = at::randn({10, 25});
    auto expect = at::tanh(at::relu(input1 + input2));
    auto actual = runtime({input1, input2}, {});
    EXPECT_TRUE(at::allclose(expect, actual.toTensor()));
  }
}

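// Verifies that one StaticModule can be shared across threads, with each
// thread driving its own StaticRuntime instance concurrently.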
TEST(CpuFusion, ParallelRuntimes) {
  const auto simple_script = R"JIT(
    def forward(self, a, b):
        return (a + b).relu().tanh()
  )JIT";

  Module m("module");
  m.define(simple_script);

  StaticModuleOptions opts; // start with the defaults.
  opts.enable_tensorexpr_fusion = true;

  auto sample_input1 = at::randn({2, 3});
  auto sample_input2 = at::ones({2, 3});
  auto smodule = StaticModule(
      m, /* is_frozen */ false, opts, {sample_input1, sample_input2});

  constexpr size_t kNumThreads = 2;
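  // Build a distinct list of (rows, cols) tensor shapes for each thread.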
  std::vector<std::vector<std::pair<int, int>>> all_inputs;
  for (size_t id = 0; id < kNumThreads; ++id) {
    // Cast to int up front to avoid narrowing conversions in the braced
    // initializer below.
    const int i = static_cast<int>(id);
    std::vector<std::pair<int, int>> thread_input = {
        {i, i + 1},
        {i + 10, i + 11},
        {i + 20, i + 21},
        {i + 30, i + 31},
        {i + 40, i + 41},
        {i + 50, i + 51},
        {i + 60, i + 61},
        {i + 70, i + 71}};
    all_inputs.emplace_back(std::move(thread_input));
  }

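  // Each thread creates its own StaticRuntime from the shared StaticModule;
  // individual StaticRuntime instances are not meant to be used from multiple
  // threads at once.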
  auto exec_runtime = [&](size_t tid) {
    const auto& inputs = all_inputs[tid];
    StaticRuntime runtime(smodule);
    for (const auto& inp : inputs) {
      auto a = at::randn({inp.first, inp.second});
      auto b = at::randn({inp.first, inp.second});
      auto expect = at::tanh(at::relu(a + b));
      auto actual = runtime({a, b}, {});
      EXPECT_TRUE(at::allclose(expect, actual.toTensor()));
    }
  };

  std::vector<std::thread> threads;
  for (size_t id = 0; id < kNumThreads; ++id) {
    threads.emplace_back(exec_runtime, id);
  }

  for (auto& t : threads) {
    t.join();
  }
}