#include <gtest/gtest.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/static/ProcessedNodeInputs.h>
#include <torch/csrc/jit/runtime/static/fusion.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/csrc/jit/runtime/static/memory_planner.h>
#include <torch/csrc/jit/runtime/static/ops.h>
#include <torch/csrc/jit/runtime/static/passes.h>
#include <memory>

#include "deep_wide_pt.h"
#include "test_utils.h"

using namespace torch;
using namespace torch::jit;
using namespace torch::jit::test;

C10_DECLARE_bool(static_runtime_disable_debug_memory_overlap_check);

namespace {

StaticModule makeStaticModuleFromScript(const std::string& script) {
  Module m("module");
  m.define(script);
  return StaticModule(m);
}

bool testCanEnableStaticRuntime(const std::string& jit_script) {
  script::Module module("module");
  module.define(jit_script);

  Method method = module.get_method("forward");
  auto graph = module.get_method("forward").graph();

  // here we do not freeze graph
  return canEnableStaticRuntime(graph);
}

bool testCanEnableStaticRuntimeWithIR(const std::string& ir) {
  auto graph = std::make_shared<Graph>();
  parseIR(ir, graph.get(), {});
  return canEnableStaticRuntime(graph);
}

bool testModuleHasOp(const std::string& jit_script, const char* op_name) {
  script::Module module("module");
  module.define(jit_script);

  return forwardHasOp(module, op_name);
}

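// TorchScript snippets with in-place/out= ops, shared by several tests below.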
const auto reshape_inplace_script = R"JIT(
  def forward(self, inp: Tensor, shape: List[int]):
      a = inp + inp
      b = a.reshape(shape)
      c = b.sigmoid_()
      d = c + c
      e = a + a
      f = b + b
      return (d, e, f)
)JIT";

const auto reshape_inplace_script_1 = R"JIT(
  def forward(self, inp: Tensor, shape: List[int], flag: bool):
    if flag:
      a = inp + inp
      b = a.reshape(shape)
      c = b.sigmoid()
    else:
      a = inp * inp
      b = a.sigmoid_()
      c = b.reshape(shape)
    d = c + c
    e = a + a
    f = b + b
    return (d, e, f)
)JIT";

const auto sigmoid_inplace_script = R"JIT(
  def forward(self, inp: Tensor):
      a = torch.sigmoid(inp, out=inp).clone()
      return (a)
)JIT";

} // namespace

// Test that StaticModule::value_group groups values of the graph into
// 1) Inputs/Constants and their aliases 2) Outputs and their aliases.
TEST(StaticModule, ValueGroup) {
  const std::string src = R"IR(
    graph(%input0 : Tensor, %input1 : Tensor):
      # Constants.
      %0 : int = prim::Constant[value=1]()
      # Internal values.
      %1 : Tensor = aten::add(%input0, %input1, %0)
      # This includes aliases of output.
      %2 : Tensor = aten::add(%input0, %1, %0)
      # This includes output.
      %3 : (Tensor) = prim::TupleConstruct(%2)
      return (%3)
    )IR";
  auto input_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(src, input_graph.get());
  torch::jit::StaticModule sm(input_graph);
  const Graph& graph = sm.graph();
  std::vector<const Node*> nodes(graph.nodes().begin(), graph.nodes().end());
  auto* root_block = sm.root_block();
  const auto& value_group = sm.block_info(root_block).value_group();

  std::vector<const Value*> expected_input_aliases{
      graph.inputs()[0], graph.inputs()[1], nodes[0]->output()};
  for (auto* value : expected_input_aliases) {
    EXPECT_TRUE(value_group.isExternalAlias(value));
  }

  std::vector<const Value*> expected_output_aliases{
      graph.outputs()[0], nodes[2]->output()};
  for (auto* value : expected_output_aliases) {
    EXPECT_TRUE(value_group.isOutputAlias(value));
  }
  EXPECT_FALSE(value_group.isAlwaysAlive(nodes[1]->output()));
  EXPECT_TRUE(value_group.isAlwaysAlive(graph.inputs()[0]));
  EXPECT_TRUE(value_group.isAlwaysAlive(graph.inputs()[1]));
  EXPECT_TRUE(value_group.isAlwaysAlive(graph.outputs()[0]));
}

TEST(StaticModule, IsOptimizableContainerType_NonOptimizableInputs) {
  // Cannot use out variants for list/tuple construction here because
  // inputs are not produced by nodes with out variants.
  const std::string src = R"JIT(
        def forward(self, a, b):
            a_alias = a.view(a.size())
            non_optimizable_list = [a_alias]
            non_optimizable_tuple = (b, )
            return non_optimizable_list, non_optimizable_tuple
    )JIT";

  auto sm = makeStaticModuleFromScript(src);
  const auto& graph = sm.graph();
  auto* root_block = sm.root_block();
  const auto& block_info = sm.block_info(root_block);

  for (const Node* n : graph.nodes()) {
    EXPECT_FALSE(block_info.node_is_optimizable_container_type(n));
  }
}

TEST(StaticModule, IsOptimizableContainerType_WrongType) {
  // Cannot use out variants for list/tuple construction here because
  // types are not Tensors
  const std::string src = R"JIT(
        def forward(self, x: int, y: int):
            a = 1 + x
            b = 2 + y
            non_optimizable_list = [a]
            non_optimizable_tuple = (b, )
            return non_optimizable_list, non_optimizable_tuple
    )JIT";

  auto sm = makeStaticModuleFromScript(src);
  const auto& graph = sm.graph();
  auto* root_block = sm.root_block();
  const auto& block_info = sm.block_info(root_block);

  for (const Node* n : graph.nodes()) {
    EXPECT_FALSE(block_info.node_is_optimizable_container_type(n));
  }
}

TEST(StaticModule, IsOptimizableContainerType_CanUseOutVariant) {
  // This container should be optimizable since aten::relu has an
  // out variant and the container contains Tensors.
  const std::string src = R"JIT(
        def forward(self, x):
            a = torch.relu(x)
            optimizable_list = [a]
            return optimizable_list
    )JIT";
  auto sm = makeStaticModuleFromScript(src);
  const auto& graph = sm.graph();
  auto* root_block = sm.root_block();
  const auto& block_info = sm.block_info(root_block);

  for (const Node* n : graph.nodes()) {
    if (n->kind() == c10::prim::ListConstruct) {
      EXPECT_TRUE(block_info.node_is_optimizable_container_type(n));
    } else {
      EXPECT_FALSE(block_info.node_is_optimizable_container_type(n));
    }
  }
}

// Test operator() with rvalue inputs
TEST(StaticModule, RValueInputs) {
  const std::string src = R"JIT(
    def forward(self, x):
        y = torch.relu(x)
        return y.clone()
  )JIT";
  auto sm = makeStaticModuleFromScript(src);

  std::vector<IValue> input{at::randn({1})};

  auto expected = sm(input, {});
  auto actual = sm(std::move(input), {});

  EXPECT_TRUE(expected.isTensor());
  EXPECT_TRUE(actual.isTensor());
  EXPECT_TRUE(expected.toTensor().equal(actual.toTensor()));
}

TEST(StaticRuntime, ModuleHasOp) {
  EXPECT_TRUE(testModuleHasOp(reshape_inplace_script, "aten::sigmoid_"));
  EXPECT_TRUE(testModuleHasOp(reshape_inplace_script_1, "aten::reshape"));
  EXPECT_TRUE(testModuleHasOp(sigmoid_inplace_script, "aten::clone"));
  EXPECT_FALSE(testModuleHasOp(reshape_inplace_script_1, "aten::add_"));
}

TEST(StaticRuntime, ReplaceWithCopy_replaces_reshape) {
  auto ExpectToReplaceWithCopy = [](const std::string& jit_script) {
    auto graph = getGraphFromScript(jit_script);
    EXPECT_TRUE(graphHasOp(graph, "aten::reshape"));
    EXPECT_FALSE(graphHasOp(graph, "static_runtime::reshape_copy"));

    ReplaceWithCopy(graph);

    // aten::reshape -> static_runtime::reshape_copy
    EXPECT_FALSE(graphHasOp(graph, "aten::reshape"));
    EXPECT_TRUE(graphHasOp(graph, "static_runtime::reshape_copy"));
  };

  ExpectToReplaceWithCopy(R"JIT(
    def forward(self, inp: Tensor, shape: List[int]):
        a = inp.reshape(shape)
        return (a)
  )JIT");
  ExpectToReplaceWithCopy(R"JIT(
    def forward(self, inp: Tensor, shape: List[int]):
        a = inp * 2
        b = inp * 3
        c = inp.reshape(shape)
        return (a, b, c)
  )JIT");
  ExpectToReplaceWithCopy(R"JIT(
    def forward(self, cond: bool, x):
        if cond:
            y = x.reshape(x.shape)
        else:
            y = x.clone()
        return y.clone()
  )JIT");
}

TEST(
    StaticRuntime,
    ReplaceWithCopy_does_not_replace_reshape_if_input_has_writters) {
  auto ExpectNotToReplaceWithCopy = [](const std::string& jit_script) {
    auto graph = getGraphFromScript(jit_script);
    EXPECT_TRUE(graphHasOp(graph, "aten::reshape"));
    EXPECT_FALSE(graphHasOp(graph, "static_runtime::reshape_copy"));

    ReplaceWithCopy(graph);

    // No Replacement
    EXPECT_TRUE(graphHasOp(graph, "aten::reshape"));
    EXPECT_FALSE(graphHasOp(graph, "static_runtime::reshape_copy"));
  };

  ExpectNotToReplaceWithCopy(R"JIT(
    def forward(self, inp: Tensor, shape: List[int]):
        a = inp.reshape(shape)
        inp *= 2
        return (a)
  )JIT");
  ExpectNotToReplaceWithCopy(R"JIT(
    def forward(self, inp: Tensor, shape: List[int]):
        a = inp.reshape(shape)
        a *= 2
        return (a)
  )JIT");
  ExpectNotToReplaceWithCopy(R"JIT(
    def forward(self, inp: Tensor, inp2: Tensor, shape: List[int]):
        a = inp.reshape(shape)
        a *= 2
        b = a.reshape(shape)
        return (b)
  )JIT");
  ExpectNotToReplaceWithCopy(R"JIT(
    def forward(self, inp: Tensor, shape: List[int]):
        a = inp.reshape(shape)
        b = a.reshape(shape)
        c = b.reshape(shape)
        d = c.reshape(shape)
        e = b.sigmoid_()
        return (d)
  )JIT");
  ExpectNotToReplaceWithCopy(reshape_inplace_script);
}

TEST(StaticRuntime, CanEnableStaticRuntime) {
  const auto while_script = R"JIT(
    def forward(self, a: Tensor, x: int):
        c = 0
        while c < x:
            a = a * a
            c += 2
        return a
  )JIT";

  const auto for_script = R"JIT(
    def forward(self, a: Tensor, x: int):
        for c in range(x):
            a = a * a
        return a
  )JIT";

  const auto if_script = R"JIT(
    def forward(self, a: Tensor, b: bool):
        if b:
            return a
        else:
            return a * a
  )JIT";

  const auto is_script_tensors = R"JIT(
    def forward(self, a: Tensor, b: Tensor):
        return a is b
  )JIT";

  const auto is_script_none = R"JIT(
    def forward(self, a: Optional[Tensor]):
        return a is None
  )JIT";

  const auto is_not_script_tensors = R"JIT(
    def forward(self, a: Tensor, b: Tensor):
        return a is not b
  )JIT";

  const auto is_not_script_none = R"JIT(
    def forward(self, a: Optional[Tensor]):
        return a is not None
  )JIT";

  EXPECT_TRUE(testCanEnableStaticRuntime(reshape_inplace_script));
  EXPECT_TRUE(testCanEnableStaticRuntime(for_script));
  EXPECT_TRUE(testCanEnableStaticRuntime(while_script));
  EXPECT_TRUE(testCanEnableStaticRuntime(if_script));
  EXPECT_FALSE(testCanEnableStaticRuntime(is_script_tensors));
  EXPECT_TRUE(testCanEnableStaticRuntime(is_script_none));
  EXPECT_FALSE(testCanEnableStaticRuntime(is_not_script_tensors));
  EXPECT_TRUE(testCanEnableStaticRuntime(is_not_script_none));
}

TEST(StaticRuntime, CanEnableStaticRuntimeCallMethod) {
  const auto call_method = R"IR(
      graph(%x : Tensor):
          %1 : Tensor = prim::CallMethod[name="offsets"](%x)
          return (%1)
  )IR";
  EXPECT_FALSE(testCanEnableStaticRuntimeWithIR(call_method));
}

TEST(StaticRuntime, CanEnableStaticRuntimeSubBlocks) {
  const auto src = R"JIT(
    def forward(self, a: Tensor, b: Tensor, cond: bool):
        if cond:
            # aten::__is__ on tensors is blocked
            return a is b
        return False
  )JIT";

  EXPECT_FALSE(testCanEnableStaticRuntime(src));
}

TEST(StaticRuntime, NestedOutput) {
  // dict of tuple of list
  const auto nested_output_script_0 = R"JIT(
    def forward(self, a, b):
      c = (a + b).relu().nan_to_num().float()
      d = a.flatten().nan_to_num() * b.flatten().nan_to_num()
      e = d.float().relu()
      f = ([c], [d])
      g = ([e], [f])
      return ({"prediction":(f, g)})
  )JIT";

  // tuple of lists
  const auto nested_output_script_1 = R"JIT(
    def forward(self, a, b):
      c = (a + b).relu().nan_to_num().float()
      d = a.flatten().nan_to_num() * b.flatten().nan_to_num()
      e = d.float().relu()
      f = [c]
      g = [e]
      return (f, g)
  )JIT";

  // list of tuple of dict
  const auto nested_output_script_2 = R"JIT(
    def forward(self, a, b):
      c = (a + b).relu().nan_to_num().float()
      d = b * c
      e = a.flatten().nan_to_num() * b.flatten().nan_to_num()
      f = e.float().relu()
      g = ({"d": d}, {"b": b})
      h = ({"e": e}, {"f": f})
      return [g, h]
  )JIT";

  // list of dicts
  const auto nested_output_script_3 = R"JIT(
    def forward(self, a, b):
      c = (a + b).relu().nan_to_num().float()
      d = b * c
      e = a.flatten().nan_to_num() * b.flatten().nan_to_num()
      f = e.float().relu()
      g = {"d": d, "b": b}
      h = {"e": e, "f": f}
      return [g, h]
  )JIT";

  auto run_test = [&](std::vector<int64_t> shapes) {
    auto a = at::randn(shapes);
    auto b = at::randn(shapes);

    std::vector<IValue> args{a, b};
    testStaticRuntime(nested_output_script_0, args);
    testStaticRuntime(nested_output_script_1, args);
    testStaticRuntime(nested_output_script_2, args);
    testStaticRuntime(nested_output_script_3, args);

    if (shapes.size() > 0 && shapes[0] != 0) {
      shapes[0] *= 3;
      testStaticRuntime(
          nested_output_script_0, args, {at::randn(shapes), at::randn(shapes)});
      testStaticRuntime(
          nested_output_script_1, args, {at::randn(shapes), at::randn(shapes)});
    }
  };
  run_test({2, 3, 1, 2});
  run_test({2, 6});
}

// test memory reuse
TEST(StaticRuntime, LongModel) {
  torch::jit::Module mod = getLongScriptModel();
  auto a = torch::randn({2, 2});
  auto b = torch::randn({2, 2});
  auto c = torch::randn({2, 2});

  // run jit graph executor
  std::vector<at::IValue> input_ivalues({a, b, c});
  at::Tensor output_1 = mod.forward(input_ivalues).toTensor();

  // run static runtime
  std::vector<c10::IValue> input_tensors({a, b, c});
  torch::jit::StaticModule smod(mod);
  at::Tensor output_2 = smod(input_tensors, {}).toTensor();
  smod.runtime().check_for_memory_leak();
  EXPECT_TRUE(
      torch::allclose(output_1, output_2, /*rtol=*/1e-5, /*atol=*/1e-7));
}

TEST(StaticRuntime, TrivialModel) {
  torch::jit::Module mod = getTrivialScriptModel();
  auto a = torch::randn({2, 2});
  auto b = torch::randn({2, 2});
  auto c = torch::randn({2, 2});

  // run jit graph executor
  std::vector<at::IValue> input_ivalues({a, b, c});
  at::Tensor output_1 = mod.forward(input_ivalues).toTensor();

  // run static runtime
  std::vector<c10::IValue> input_tensors({a, b, c});
  torch::jit::StaticModule smod(mod);
  at::Tensor output_2 = smod(input_tensors, {}).toTensor();
  smod.runtime().check_for_memory_leak();
  EXPECT_TRUE(
      torch::allclose(output_1, output_2, /*rtol=*/1e-5, /*atol=*/1e-7));
}

TEST(StaticRuntime, DeepWide) {
  const int embedding_size = 32;
  const int num_features = 50;
  torch::jit::Module mod = getDeepAndWideSciptModel();
  torch::jit::StaticModule smod(mod);

  for (int batch_size : {1, 8, 32}) {
    for (int i = 0; i < 2; ++i) {
      auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
      auto user_emb = torch::randn({batch_size, 1, embedding_size});
      auto wide = torch::randn({batch_size, num_features});

      // run jit graph executor
      std::vector<at::IValue> inputs({ad_emb_packed, user_emb, wide});
      auto output_1 = getTensor(mod.forward(inputs));

      // run static runtime
      std::vector<c10::IValue> input_tensors({ad_emb_packed, user_emb, wide});
      auto outputs = smod(input_tensors, {}).toTupleRef().elements();
      ASSERT_TRUE(outputs.size() > 0);
      at::Tensor output_2 = outputs[0].toTensor();
      smod.runtime().check_for_memory_leak();
      EXPECT_TRUE(
          torch::allclose(output_1, output_2, /*rtol=*/1e-5, /*atol=*/1e-5));
    }
  }
}

TEST(StaticRuntime, KWargsAPI_1) {
  const int embedding_size = 32;
  const int num_features = 50;
  auto module = getDeepAndWideSciptModel();
  torch::jit::StaticModule smod(module);

  for (int batch_size : {1, 8, 32}) {
    for (int i = 0; i < 2; ++i) {
      auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
      auto user_emb = torch::randn({batch_size, 1, embedding_size});
      auto wide = torch::randn({batch_size, num_features});
      {
        std::vector<at::IValue> inputs({ad_emb_packed, user_emb, wide});

        // run jit graph executor
        at::Tensor output_1 = getTensor(module.forward(inputs));

        // run static runtime
        c10::IValue output_ivalue = smod(inputs, {});
        smod.runtime().check_for_memory_leak();

        at::Tensor output_2 = getTensor(output_ivalue);
        EXPECT_TRUE(
            torch::allclose(output_1, output_2, /*rtol=*/1e-5, /*atol=*/1e-5));

        // check for output aliasing
        EXPECT_EQ(output_ivalue.use_count(), 1);
        output_ivalue = IValue();

        EXPECT_EQ(output_2.getIntrusivePtr().use_count(), 1);
      }

      // check for input aliasing (deep & wide does not have ops
      // that create aliases of input tensors)
      EXPECT_EQ(ad_emb_packed.getIntrusivePtr().use_count(), 1);
      EXPECT_EQ(user_emb.getIntrusivePtr().use_count(), 1);
      EXPECT_EQ(wide.getIntrusivePtr().use_count(), 1);
    }
  }
}

TEST(StaticRuntime, KWargsAPI_2) {
  const int embedding_size = 32;
  const int num_features = 50;
  auto module = getDeepAndWideSciptModel();
  torch::jit::StaticModule smod(module);

  for (int batch_size : {1, 8, 32}) {
    for (int i = 0; i < 2; ++i) {
      auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
      auto user_emb = torch::randn({batch_size, 1, embedding_size});
      auto wide = torch::randn({batch_size, num_features});
      {
        // run jit graph executor
        std::vector<at::IValue> args({ad_emb_packed, user_emb, wide});
        at::Tensor output_1 = getTensor(module.forward(args));

        std::unordered_map<std::string, c10::IValue> kwargs(
            {{"ad_emb_packed", ad_emb_packed},
             {"user_emb", user_emb},
             {"wide", wide}});

        // run static runtime
        c10::IValue output_ivalue = smod(std::vector<IValue>{}, kwargs);
        smod.runtime().check_for_memory_leak();

        at::Tensor output_2 = getTensor(output_ivalue);
        EXPECT_TRUE(
            torch::allclose(output_1, output_2, /*rtol=*/1e-5, /*atol=*/1e-5));

        // check for output aliasing
        EXPECT_EQ(output_ivalue.use_count(), 1);
        output_ivalue = IValue();

        EXPECT_EQ(output_2.getIntrusivePtr().use_count(), 1);
      }

      EXPECT_EQ(ad_emb_packed.getIntrusivePtr().use_count(), 1);
      EXPECT_EQ(user_emb.getIntrusivePtr().use_count(), 1);
      EXPECT_EQ(wide.getIntrusivePtr().use_count(), 1);
    }
  }
}

TEST(StaticRuntime, KWargsAPI_Optional) {
  const auto src = R"JIT(
    def forward(self, x, y, z: Optional[Tensor] = None):
        return x + y
  )JIT";

  torch::jit::Module mod("mod");
  mod.define(src);
  torch::jit::StaticModule smod(mod);
  const auto kwargs = std::unordered_map<std::string, IValue>{
      {"x", at::randn({1})}, {"y", at::randn({1})}};

  auto expected = mod.forward({}, kwargs).toTensor();
  auto actual = smod({}, kwargs).toTensor();

  EXPECT_TRUE(expected.equal(actual));
}

TEST(StaticRuntime, CleanUpMemory) {
  const int embedding_size = 32;
  const int num_features = 50;
  torch::jit::Module mod = getDeepAndWideSciptModel();

  for (auto enable_out_variant : {true, false}) {
    for (auto optimize_memory : {true, false}) {
      for (auto manage_output_tensors : {true, false}) {
        if (manage_output_tensors && !enable_out_variant) {
          // when manage_output_tensors is enabled, enable_out_variant
          // must be enabled too
          continue;
        }
        if (optimize_memory && !enable_out_variant) {
          // when optimize_memory is enabled, enable_out_variant must be
          // enabled too
          continue;
        }
        VLOG(1) << "enable_out_variant: " << enable_out_variant
                << ", optimize_memory: " << optimize_memory
                << ", manage_output_tensors: " << manage_output_tensors;
        torch::jit::StaticModuleOptions opts{
            enable_out_variant, optimize_memory, manage_output_tensors};
        torch::jit::StaticModule smod(mod, false, opts);
        torch::jit::StaticRuntime runtime(smod);

        for (int batch_size : {1, 8, 32}) {
          for (int i = 0; i < 2; ++i) {
            auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
            auto user_emb = torch::randn({batch_size, 1, embedding_size});
            auto wide = torch::randn({batch_size, num_features});

            // run jit graph executor
            std::vector<at::IValue> inputs({ad_emb_packed, user_emb, wide});
            auto output_1 = getTensor(mod.forward(inputs));

            // run static runtime
            std::vector<c10::IValue> input_tensors(
                {ad_emb_packed, user_emb, wide});
            auto outputs = runtime(input_tensors, {}).toTupleRef().elements();
            ASSERT_TRUE(outputs.size() > 0);
            auto output_2 = outputs[0].toTensor();
            runtime.check_for_memory_leak();
            EXPECT_TRUE(torch::allclose(
                output_1, output_2, /*rtol=*/1e-5, /*atol=*/1e-5));
            if (manage_output_tensors) {
              runtime.deallocateOutputTensors();
              runtime.checkOutputTensorMemoryLeaks();
            }
          }
        }
      }
    }
  }
}

TEST(StaticRuntime, ManageOutputTensors) {
  const std::string test_graph = R"IR(
    graph(%0 : Tensor):
      # With manage_output_tensor enabled, this tensor is managed.
      %1 : Tensor = aten::abs(%0)
      # The output container object is never managed.
      %2 : (Tensor) = prim::TupleConstruct(%1)
      return (%2)
  )IR";
  auto a = at::randn({2, 2});
  auto b = at::randn({3, 6});
  std::vector<at::IValue> args{a};
  std::vector<at::IValue> args2{b};
  testStaticRuntime(test_graph, args);
  testStaticRuntime(test_graph, args, args2);
}

TEST(
    StaticRuntime,
    ManageOutputTensorsReturnsOutputContainingManagedOutputTensor) {
  const std::string test_graph = R"IR(
    graph(%0 : Tensor):
      # With manage_output_tensor enabled, this tensor is managed.
      %1 : Tensor = aten::abs(%0)
      # The output container object is never managed.
      %2 : (Tensor) = prim::TupleConstruct(%1)
      return (%2)
  )IR";
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(test_graph, g.get());
  torch::jit::StaticModuleOptions opts{
      /*enable_out_variant=*/true,
      /*optimize_memory=*/true,
      /*manage_output_tensors=*/true};
  auto a = at::randn({2, 2});
  std::vector<at::IValue> args{a};
  torch::jit::StaticModule smod(g, opts);
  torch::jit::StaticRuntime runtime(smod);
  // Profile run.
  {
    IValue tuple = runtime(args, {});
    ASSERT_TRUE(tuple.isTuple());
    ASSERT_EQ(tuple.toTupleRef().elements().size(), 1);
    // Do not manage input value.
    EXPECT_FALSE(runtime.isManagedOutputTensor(args[0]));
    // Do not manage direct output value.
    EXPECT_FALSE(runtime.isManagedOutputTensor(tuple));
    IValue element = tuple.toTupleRef().elements()[0];
    // Tensor to be managed, but not yet from the profile run.
    EXPECT_FALSE(runtime.isManagedOutputTensor(element));
    tuple = IValue();
    runtime.deallocateOutputTensors();
    runtime.checkOutputTensorMemoryLeaks();
  }
  // Second run that manages output tensors.
  {
    IValue tuple = runtime(args, {});
    ASSERT_TRUE(tuple.isTuple());
    ASSERT_EQ(tuple.toTupleRef().elements().size(), 1);
    // Do not manage input value.
    EXPECT_FALSE(runtime.isManagedOutputTensor(args[0]));
    // Do not manage direct output value.
    EXPECT_FALSE(runtime.isManagedOutputTensor(tuple));
    IValue element = tuple.toTupleRef().elements()[0];
    // The contained tensor is now managed after the profile run.
    EXPECT_TRUE(runtime.isManagedOutputTensor(element));
    tuple = IValue();
    runtime.deallocateOutputTensors();
    runtime.checkOutputTensorMemoryLeaks();
  }
}

TEST(StaticRuntime, ManageOutputTensorsWithDeallocateOutputTensors) {
  const int embedding_size = 32;
  const int num_features = 50;
  torch::jit::Module mod = getDeepAndWideSciptModel();

  torch::jit::StaticModuleOptions opts{
      /*enable_out_variant=*/true,
      /*optimize_memory=*/true,
      /*manage_output_tensors=*/true};
  torch::jit::StaticModule smod(mod, false, opts);
  torch::jit::StaticRuntime runtime(smod);
  // Re-enter the runtime with inputs of the same and of different shapes.
  for (int batch_size : {8, 8, 24, 8}) {
    auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
    auto user_emb = torch::randn({batch_size, 1, embedding_size});
    auto wide = torch::randn({batch_size, num_features});
    std::vector<c10::IValue> input_tensors({ad_emb_packed, user_emb, wide});
    runtime(input_tensors, {});
    runtime.check_for_memory_leak();
    runtime.deallocateOutputTensors();
    runtime.checkOutputTensorMemoryLeaks();
  }
}

TEST(StaticRuntime, ManageOutputTensorsWithoutDeallocateOutputTensors) {
  const int embedding_size = 32;
  const int num_features = 50;
  torch::jit::Module mod = getDeepAndWideSciptModel();

  torch::jit::StaticModuleOptions opts{
      /*enable_out_variant=*/true,
      /*optimize_memory=*/true,
      /*manage_output_tensors=*/true};
  torch::jit::StaticModule smod(mod, false, opts);
  torch::jit::StaticRuntime runtime(smod);
  int batch_size = 8;
  auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
  auto user_emb = torch::randn({batch_size, 1, embedding_size});
  auto wide = torch::randn({batch_size, num_features});
  std::vector<c10::IValue> input_tensors({ad_emb_packed, user_emb, wide});
  // Profile run.
  runtime(input_tensors, {});
  runtime.deallocateOutputTensors();
  // Run again to allocate output Tensors without deallocating them.
  runtime(input_tensors, {});
  // Memory leak checking fails.
  EXPECT_THROW(runtime.checkOutputTensorMemoryLeaks(), std::exception);
  // Calling the runtime without deallocation fails too.
  EXPECT_THROW(runtime(input_tensors, {}), std::exception);
  // After deallocation, everything works fine.
  runtime.deallocateOutputTensors();
  runtime.checkOutputTensorMemoryLeaks();
  runtime(input_tensors, {});
}

TEST(StaticRuntime, DisableManageOutputTensors) {
  const std::string test_graph = R"IR(
    graph(%0 : Tensor):
      # With manage_output_tensor enabled, this tensor is managed.
      %1 : Tensor = aten::abs(%0)
      # The output container object is never managed.
      %2 : (Tensor) = prim::TupleConstruct(%1)
      return (%2)
  )IR";
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(test_graph, g.get());
  torch::jit::StaticModuleOptions opts{
      /*enable_out_variant=*/true,
      /*optimize_memory=*/true,
      /*manage_output_tensors=*/true};
  auto a = at::randn({2, 2});
  std::vector<at::IValue> args{a};
  torch::jit::StaticModule smod(g, opts);
  torch::jit::StaticRuntime runtime(smod);
  // Profile run.
  {
    IValue tuple = runtime(args, {});
    IValue element = tuple.toTupleRef().elements()[0];
    EXPECT_FALSE(runtime.isManagedOutputTensor(element));
    tuple = IValue();
    runtime.deallocateOutputTensors();
    runtime.checkOutputTensorMemoryLeaks();
  }
  // Second run that manages output tensors.
  {
    IValue tuple = runtime(args, {});
    IValue element = tuple.toTupleRef().elements()[0];
    EXPECT_TRUE(runtime.isManagedOutputTensor(element));
    tuple = IValue();
    runtime.deallocateOutputTensors();
    runtime.checkOutputTensorMemoryLeaks();
  }

  // Reset the runtime and start profiling again.
  runtime.disableManageOutputTensors();

  IValue copied_output_tensor;
  IValue original_output_tensor;
  // New profile run.
  {
    IValue tuple = runtime(args, {});
    IValue element = tuple.toTupleRef().elements()[0];
    EXPECT_FALSE(runtime.isManagedOutputTensor(element));
    copied_output_tensor = element.deepcopy();
    original_output_tensor = element;
    tuple = IValue();
    // No-op since manage_output_tensor is disabled now.
    runtime.deallocateOutputTensors();
    runtime.checkOutputTensorMemoryLeaks();
  }
  // Ensure that `original_output_tensor` is no longer managed: even after
  // calling `runtime.deallocateOutputTensors()`, `original_output_tensor`
  // still contains a valid value.
  EXPECT_TRUE(
      original_output_tensor.toTensor().equal(copied_output_tensor.toTensor()));

  // Ensure that the second optimized run does not manage the output tensor
  // either.
  {
    IValue tuple = runtime(args, {});
    IValue element = tuple.toTupleRef().elements()[0];
    EXPECT_FALSE(runtime.isManagedOutputTensor(element));
    copied_output_tensor = element.deepcopy();
    original_output_tensor = element;
    tuple = IValue();
    // No-op since manage_output_tensor is disabled now.
    runtime.deallocateOutputTensors();
    runtime.checkOutputTensorMemoryLeaks();
  }
  // Ensure that `original_output_tensor` is no longer managed: even after
  // calling `runtime.deallocateOutputTensors()`, `original_output_tensor`
  // still contains a valid value.
  EXPECT_TRUE(
      original_output_tensor.toTensor().equal(copied_output_tensor.toTensor()));
}

TEST(StaticRuntime, FusionPass) {
  const int embedding_size = 32;
  const int num_features = 50;
  for (int batch_size : {1, 8, 32}) {
    for (int i = 0; i < 2; ++i) {
      torch::jit::Module module = getDeepAndWideSciptModel();
      auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
      auto user_emb = torch::randn({batch_size, 1, embedding_size});
      auto wide = torch::randn({batch_size, num_features});

      // run jit graph executor
      std::vector<at::IValue> inputs({ad_emb_packed, user_emb, wide});
      auto output_1 = getTensor(module.forward(inputs));

      Method method = module.get_method("forward");
      auto graph = method.graph();
      fuseStaticSubgraphs(graph, 2);
      bool hit = false;
      for (const auto& n : module.get_method("forward").graph()->nodes()) {
        if (n->kind() == torch::jit::prim::StaticSubgraph) {
          hit = true;
        }
      }
      EXPECT_TRUE(hit);
      auto output_2 = getTensor(module.forward(inputs));
      EXPECT_TRUE(
          torch::allclose(output_1, output_2, /*rtol=*/1e-5, /*atol=*/1e-5));
    }
  }
}

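// Helper that packs the given input indices into a ProcessedNodeInputs, which
// ProcessedNode uses to locate its inputs in the tests below.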
static ProcessedNodeInputs createProcessedNodeInputs(
    c10::ArrayRef<uint16_t> inputs) {
  ProcessedNodeInputs result(inputs.size());
  for (const auto idx : c10::irange(inputs.size())) {
    result[idx] = inputs[idx];
  }
  return result;
}

TEST(
    ProcessedNode,
    VerifyNoMemoryOverlapWithImmutableInputsWithImmutableArguments) {
  const auto sigmoid_script = R"JIT(
    def forward(self, inp: Tensor):
        b = torch.sigmoid(inp).clone()
        return (b)
  )JIT";
  script::Module module("module");
  // Not using out= variant.
  module.define(sigmoid_script);
  torch::jit::StaticModule smodule(module);
  Node* sigmoid_node = getNodeWithKind(smodule, "aten::sigmoid");
  std::array<IValue, 2> values = {torch::randn({2, 3}), torch::randn({3, 1})};
  ProcessedFunction fn(
      sigmoid_node,
      /*enable_out_variant=*/true,
      /*check_memory_overlap=*/false);
  StaticNodeInfo static_node_info(
      sigmoid_node, &fn, createProcessedNodeInputs({0}), 1);
  ProcessedNode pnode(static_node_info, values.data());
  EXPECT_TRUE(pnode.verify_no_memory_overlap(/*force_check=*/true));

  pnode.Output(0) = values[0];
  EXPECT_FALSE(pnode.verify_no_memory_overlap(/*force_check=*/true));
}

TEST(ProcessedNode, VerifyNoMemoryOverlapWithImmutableInputsWithInplaceOps) {
  script::Module module("module");
  // Using out= variant.
  module.define(sigmoid_inplace_script);
  torch::jit::StaticModule smodule(module);
  Node* sigmoid_node = getNodeWithKind(smodule, "aten::sigmoid");
  std::array<IValue, 2> values = {torch::randn({2, 3}), torch::randn({3, 1})};
  ProcessedFunction fn(
      sigmoid_node,
      /*enable_out_variant=*/true,
      /*check_memory_overlap=*/false);
  StaticNodeInfo static_node_info(
      sigmoid_node, &fn, createProcessedNodeInputs({0}), 1);
  ProcessedNode pnode(static_node_info, values.data());

  ASSERT_EQ(&pnode.Output(0), &values[1]);
  EXPECT_TRUE(pnode.verify_no_memory_overlap());

  pnode.Output(0) = values[0];
  EXPECT_TRUE(pnode.verify_no_memory_overlap());
}

TEST(ProcessedNode, VerifyNoMemoryOverlapWithOverlappingOutputs) {
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(
      R"IR(
    graph(%0):
      %1 : Tensor, %2 : Tensor = prim::ListUnpack(%0)
      return (%1, %2))IR",
      g.get());
  torch::jit::StaticModule smodule(g);
  Node* list_unpack_node = getNodeWithKind(smodule, "prim::ListUnpack");
  {
    std::array<IValue, 3> values = {
        at::randn({2, 3}), at::empty({1, 3}), at::empty({4, 5})};
    ProcessedFunction fn(
        list_unpack_node,
        /*enable_out_variant=*/true,
        /*check_memory_overlap=*/false);
    StaticNodeInfo list_unpack_static_node_info(
        list_unpack_node, &fn, createProcessedNodeInputs({0}), 1);
    ProcessedNode list_unpack_pnode(
        list_unpack_static_node_info, values.data());
    ASSERT_EQ(list_unpack_pnode.outputs().size(), 2);
    EXPECT_TRUE(
        list_unpack_pnode.verify_no_memory_overlap(/*force_check=*/true));
  }
  {
    std::array<IValue, 3> values = {
        at::randn({2, 3}), at::empty({1, 3}), at::empty({4, 5})};
    ProcessedFunction fn(
        list_unpack_node,
        /*enable_out_variant=*/true,
        /*check_memory_overlap=*/false);
    StaticNodeInfo list_unpack_static_node_info(
        list_unpack_node, &fn, createProcessedNodeInputs({0}), 1);
    ProcessedNode list_unpack_pnode(
        list_unpack_static_node_info, values.data());
    auto b = at::randn({2, 3});
    list_unpack_pnode.Output(0) = b;
    list_unpack_pnode.Output(1) = b;
    EXPECT_FALSE(
        list_unpack_pnode.verify_no_memory_overlap(/*force_check=*/true));
  }
}

namespace test {
at::Tensor bad_add(const at::Tensor& self, int64_t b) {
  if (b == 0) {
    return self;
  }
  return at::native::add(self, b);
}

at::Tensor good_add(const at::Tensor& self, int64_t b) {
  if (b == 0) {
    return self;
  }
  return at::native::add(self, b);
}
} // namespace test

// test::bad_add has the schema with incorrect alias annotation.
// test::good_add has the correct alias annotation.
TORCH_LIBRARY_FRAGMENT(test, m) {
  m.def("bad_add(Tensor self, int b=0) -> Tensor");
  m.def("good_add(Tensor(a) self, int b=0) -> Tensor(a)");
}
TORCH_LIBRARY_IMPL(test, CPU, m) {
  m.impl("bad_add", ::test::bad_add);
  m.impl("good_add", ::test::good_add);
}

TEST(StaticRuntime, BadSchemaAliasInfo) {
  FLAGS_static_runtime_disable_debug_memory_overlap_check = true;
  const std::string src = R"IR(
      graph(%x: Tensor, %s: int):
          %c0 : int = prim::Constant[value=0]()
          %c1 : int = prim::Constant[value=1]()
          %a = aten::add(%x, %x, %c1)
          %b1 = test::bad_add(%a, %s) # b1 aliases a
          %t : (Tensor) = prim::TupleConstruct(%b1)
          return (%t)
  )IR";

  const auto x1 = at::randn({2, 2});
  // big enough to trigger resize of the internal buffer
  const auto x2 = at::randn({3, 6});
  testStaticRuntime(src, {x1, 0}, {x2, 10});
  // This test doesn't pass yet. This is the corner case mentioned in Step 2 of
  // [Check and correct bad schema alias info at runtime]
  // testStaticRuntime(src, {x1, 10}, {x2, 0});
  FLAGS_static_runtime_disable_debug_memory_overlap_check = false;
}

// This test repeats the last test, but with the correct schema alias
// annotations
TEST(StaticRuntime, GoodSchemaAliasInfo) {
  // Commenting out the prim::TupleConstruct reproduces the failure of
  // DCHECK(!isManagedOutputTensor(*outputs_[0]));
  const std::string src = R"IR(
      graph(%x: Tensor, %s: int):
          %c0 : int = prim::Constant[value=0]()
          %c1 : int = prim::Constant[value=1]()
          %a = aten::add(%x, %x, %c1)
          %b1 = test::good_add(%a, %s) # b1 aliases a
          # return (%b1)
          %t : (Tensor) = prim::TupleConstruct(%b1)
          return (%t)
  )IR";

  const auto x1 = at::randn({2, 2});
  // big enough to trigger resize of the internal buffer
  const auto x2 = at::randn({3, 6});
  testStaticRuntime(src, {x1, 0}, {x2, 10});
  testStaticRuntime(src, {x1, 10}, {x2, 0});
}

TEST(ProcessedFunction, ProcessedFunction) {
  const auto script = R"JIT(
    def forward(self, inp: Tensor):
        b = torch.sigmoid(inp).clone()
        c = torch.transpose(b, 0, 1)
        return (c)
  )JIT";
  script::Module module("module");
  module.define(script);
  torch::jit::StaticModule smodule(module);

  Node* sigmoid_node = getNodeWithKind(smodule, "aten::sigmoid");
  ProcessedFunction sigmoid_fn(
      sigmoid_node,
      /*enable_out_variant=*/true,
      /*check_memory_overlap=*/false);
  EXPECT_EQ(sigmoid_fn.kind(), ProcessedFunction::Kind::kOutVariant);
  EXPECT_FALSE(sigmoid_fn.checkMemoryOverlap());

  Node* transpose_node = getNodeWithKind(smodule, "aten::transpose");
  ProcessedFunction transpose_fn(
      transpose_node,
      /*enable_out_variant=*/true,
      /*check_memory_overlap=*/false);
  EXPECT_EQ(transpose_fn.kind(), ProcessedFunction::Kind::kNativeFunction);
  EXPECT_FALSE(transpose_fn.checkMemoryOverlap());
}

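// ManagedTensorRanges computes, for each node in a block, which managed
// tensors are no longer needed after that node runs (so their storage can be
// reused), taking aliases into account via the AliasDb.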
TEST(ManagedTensorRanges, NoAliases) {
  const std::string src = R"IR(
    graph(%x : Tensor):
        %y : Tensor = aten::mul(%x, %x)
        %z : Tensor = aten::mul(%y, %x)
        %output : Tensor = aten::mul(%z, %z)
        return (%output)
  )IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(src, graph.get(), vmap);

  auto* y = vmap["y"];
  auto* z = vmap["z"];

  FastSet<const Value*> managed_tensors = {y, z};
  AliasDb alias_db(graph);
  auto ranges = ManagedTensorRanges(*graph->block(), alias_db, managed_tensors);

  std::vector<Node*> nodes(
      graph->block()->nodes().begin(), graph->block()->nodes().end());
  ASSERT_EQ(nodes.size(), 3);

  EXPECT_FALSE(ranges.nodeFreesManagedTensors(nodes[0]));

  EXPECT_TRUE(ranges.nodeFreesManagedTensors(nodes[1]));
  EXPECT_EQ(
      ranges.availableTensorValuesAfterNode(nodes[1]),
      std::vector<const Value*>{y});

  EXPECT_TRUE(ranges.nodeFreesManagedTensors(nodes[2]));
  EXPECT_EQ(
      ranges.availableTensorValuesAfterNode(nodes[2]),
      std::vector<const Value*>{z});
}

TEST(ManagedTensorRanges, AliasExtendingLifetimes) {
  const std::string src = R"IR(
    graph(%x : Tensor):
        %y : Tensor = aten::mul(%x, %x)
        %y_size : int[] = aten::size(%y)
        %z1 : Tensor = aten::mul(%y, %y)
        %y_alias : Tensor = aten::view(%y, %y_size)
        %z2 : Tensor = aten::mul(%y_alias, %y_alias)
        %output : Tensor = aten::mul(%z1, %z2)
        return (%output)
  )IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(src, graph.get(), vmap);

  auto* y = vmap["y"];
  auto* z1 = vmap["z1"];
  auto* z2 = vmap["z2"];

  FastSet<const Value*> managed_tensors = {y, z1, z2};
  AliasDb alias_db(graph);
  auto ranges = ManagedTensorRanges(*graph->block(), alias_db, managed_tensors);

  std::vector<Node*> nodes(
      graph->block()->nodes().begin(), graph->block()->nodes().end());
  ASSERT_EQ(nodes.size(), 6);

  for (const auto i : c10::irange(4)) {
    EXPECT_FALSE(ranges.nodeFreesManagedTensors(nodes[i]));
  }

  EXPECT_TRUE(ranges.nodeFreesManagedTensors(nodes[4]));
  EXPECT_EQ(
      ranges.availableTensorValuesAfterNode(nodes[4]),
      std::vector<const Value*>{y});

  EXPECT_TRUE(ranges.nodeFreesManagedTensors(nodes[5]));
  const auto& available_after_5 =
      ranges.availableTensorValuesAfterNode(nodes[5]);
  // We don't care about the order, so convert to set. But make sure
  // there are no duplicates.
  FastSet<const Value*> available_after_5_set(
      available_after_5.begin(), available_after_5.end());
  EXPECT_EQ(available_after_5_set.size(), available_after_5.size());
  EXPECT_EQ(available_after_5_set, FastSet<const Value*>({z1, z2}));
}

TEST(ManagedTensorRanges, LifetimeOverlap) {
  const std::string src = R"IR(
    graph(%a : Tensor):
        %b : Tensor = aten::mul(%a, %a)
        %c : Tensor = aten::mul(%b, %b)
        %c_size : int[] = aten::size(%c)
        %c_alias : Tensor = aten::view(%c, %c_size)
        %d : Tensor = aten::mul(%a, %a)
        %e : Tensor = aten::mul(%c_alias, %c_alias)
        %output : Tensor = aten::mul(%e, %e)
        return (%output)
  )IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(src, graph.get(), vmap);
  auto* b = vmap["b"];
  auto* c = vmap["c"];
  auto* d = vmap["d"];
  auto* e = vmap["e"];

  AliasDb alias_db(graph);
  auto ranges = ManagedTensorRanges(*graph->block(), alias_db, {b, c, d, e});
  const std::vector<std::pair<Value*, Value*>> overlapping_values{
      {b, c}, {c, d}, {c, e}};

  const std::vector<std::pair<Value*, Value*>> disjoint_values{{b, d}, {b, e}};

  for (const auto& values : overlapping_values) {
    EXPECT_TRUE(ranges.lifetimesOverlap(values.first, values.second));
    EXPECT_TRUE(ranges.lifetimesOverlap(values.second, values.first));
  }
  for (const auto& values : disjoint_values) {
    EXPECT_FALSE(ranges.lifetimesOverlap(values.first, values.second));
    EXPECT_FALSE(ranges.lifetimesOverlap(values.second, values.first));
  }
}

TEST(ManagedTensorRanges, OverlappingLifetimesContainers) {
  const std::string src = R"IR(
    graph(%a : Tensor):
        %b : Tensor = aten::mul(%a, %a)
        %c : Tensor = aten::mul(%b, %b)
        %tuple : (Tensor, Tensor) = prim::TupleConstruct(%b, %c)
        %b_alias : Tensor, %c_alias : Tensor = prim::TupleUnpack(%tuple)
        %d : Tensor = aten::mul(%b_alias, %c_alias)
        %output : Tensor = aten::mul(%d, %d)
        return (%output)
  )IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(src, graph.get(), vmap);
  auto* b = vmap["b"];
  auto* c = vmap["c"];
  auto* d = vmap["d"];

  AliasDb alias_db(graph);
  auto ranges = ManagedTensorRanges(*graph->block(), alias_db, {b, c, d});

  EXPECT_TRUE(ranges.lifetimesOverlap(b, c));
  EXPECT_TRUE(ranges.lifetimesOverlap(b, d));
  EXPECT_TRUE(ranges.lifetimesOverlap(c, d));
}

TEST(ManagedTensorRanges, OverlappingLifetimesOutputs) {
  const std::string src = R"IR(
    graph(%a : Tensor):
        %output : Tensor = aten::mul(%a, %a)
        %b : Tensor = aten::mul(%a, %a)
        return (%output)
  )IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(src, graph.get(), vmap);
  auto* b = vmap["b"];
  auto* output = vmap["output"];

  AliasDb alias_db(graph);
  auto ranges = ManagedTensorRanges(*graph->block(), alias_db, {b, output});

  EXPECT_TRUE(ranges.lifetimesOverlap(b, output));
}

namespace {

// For checking the correctness of assignStorageToManagedTensors, the following
// conditions must hold
// 1. All managed tensors are assigned to some storage group, and a tensor
//    may not be assigned to more than 1 storage group.
// 2. Managed tensors with overlapping lifetimes should not be in the same
//    storage group.
// 3. The number of reused tensors is >= min_reused_tensors.
void checkStorageGroups(
    const std::vector<StorageGroup>& storage_groups,
    const ManagedTensorRanges& ranges,
    const FastMap<const Value*, at::Tensor*>& tensor_value_to_tensor,
    size_t min_reused_tensors) {
  // Some extra bookkeeping; construct the set of managed Tensor* and
  // invert the tensor_value_to_tensor map. StorageGroup stores
  // Tensor*, so this will make everything a little easier.
  FastMap<at::Tensor*, const Value*> tensor_to_tensor_value;
  FastSet<at::Tensor*> managed_tensors;
  for (auto& key_value : tensor_value_to_tensor) {
    ASSERT_EQ(
        tensor_to_tensor_value.find(key_value.second),
        tensor_to_tensor_value.end());
    tensor_to_tensor_value.emplace(key_value.second, key_value.first);
    managed_tensors.insert(key_value.second);
  }

  // Condition (1)
  FastSet<at::Tensor*> actual_assigned_tensors;
  for (const auto& storage_group : storage_groups) {
    for (auto* tensor : storage_group.group()) {
      ASSERT_EQ(
          actual_assigned_tensors.find(tensor), actual_assigned_tensors.end());
      actual_assigned_tensors.insert(tensor);
    }
  }
  ASSERT_EQ(actual_assigned_tensors, managed_tensors);

  // Condition (2)
  size_t num_reused = 0;
  for (const auto& storage_group : storage_groups) {
    const auto& group = storage_group.group();
    num_reused += group.size() - 1;
    for (const auto i : c10::irange(group.size() - 1)) {
      for (const auto j : c10::irange(i + 1, group.size())) {
        const auto* v1 = tensor_to_tensor_value.at(group[i]);
        const auto* v2 = tensor_to_tensor_value.at(group[j]);
        EXPECT_FALSE(ranges.lifetimesOverlap(v1, v2));
      }
    }
  }

  // Condition (3)
  EXPECT_GE(num_reused, min_reused_tensors);
}

// A convenience function for testing assignStorageToManagedTensors. It
// takes in an IR graph as well as a map from managed tensor name to tensor
// value. It constructs all of the necessary data structures, invokes
// assignStorageToManagedTensors, and verifies correctness with
// checkStorageGroups.
void testAssignStorageToManagedTensors(
    const std::string& src,
    FastMap<std::string, at::Tensor> managed_tensor_name_to_tensor,
    size_t min_reused_tensors) {
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(src, graph.get(), vmap);

  FastSet<const Value*> managed_tensor_values;
  FastMap<const Value*, at::Tensor*> tensor_value_to_tensor;

  for (auto& key_value : managed_tensor_name_to_tensor) {
    const auto& tensor_name = key_value.first;
    auto vmap_it = vmap.find(tensor_name);
    ASSERT_TRUE(vmap_it != vmap.end());
    managed_tensor_values.insert(vmap_it->second);
    tensor_value_to_tensor.emplace(vmap_it->second, &key_value.second);
  }
  ASSERT_EQ(managed_tensor_values.size(), tensor_value_to_tensor.size());

  AliasDb alias_db(graph);
  auto ranges =
      ManagedTensorRanges(*graph->block(), alias_db, managed_tensor_values);
  auto groups = assignStorageToManagedTensors(
      graph->block()->nodes(), ranges, tensor_value_to_tensor);

  checkStorageGroups(
      groups, ranges, tensor_value_to_tensor, min_reused_tensors);
}

} // namespace

TEST(AssignStorageToManagedTensors, NoAliases) {
  const auto src = R"IR(
    graph(%a : Tensor):
      %b : Tensor = aten::mul(%a, %a)
      %c : Tensor = aten::mul(%b, %b)
      %d : Tensor = aten::mul(%c, %c)
      %e : Tensor = aten::mul(%b, %d)
      %output : Tensor = aten::mul(%e, %e)
      return (%output)
  )IR";
  FastMap<std::string, at::Tensor> managed_tensor_name_to_tensor{
      {"b", at::randn({1})},
      {"c", at::randn({1})},
      {"d", at::randn({1})},
      {"e", at::randn({1})}};
  const size_t min_reused_tensors = 1;
  testAssignStorageToManagedTensors(
      src, std::move(managed_tensor_name_to_tensor), min_reused_tensors);
}

TEST(AssignStorageToManagedTensors, Aliases) {
  const auto src = R"IR(
    graph(%a : Tensor):
      %b : Tensor = aten::mul(%a, %a)
      %c : Tensor = aten::mul(%b, %b)
      %d : Tensor = aten::mul(%c, %c)
      %c_size : int[] = aten::size(%c)
      %c_alias : Tensor = aten::view(%c, %c_size)
      %e : Tensor = aten::mul(%b, %d)
      %f : Tensor = aten::mul(%c_alias, %c_alias)
      %output : Tensor = aten::mul(%e, %f)
      return (%output)
  )IR";
  FastMap<std::string, at::Tensor> managed_tensor_name_to_tensor{
      {"b", at::randn({1})},
      {"c", at::randn({1})},
      {"d", at::randn({1})},
      {"e", at::randn({1})},
      {"f", at::randn({1})}};
  const size_t min_reused_tensors = 1;
  testAssignStorageToManagedTensors(
      src, std::move(managed_tensor_name_to_tensor), min_reused_tensors);
}

namespace {
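// Schema-only registration: variadic_outputs is never executed in these
// tests; it just provides a node with multiple outputs for storage
// assignment.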
TORCH_LIBRARY_FRAGMENT(static_runtime_tests,m)1415 TORCH_LIBRARY_FRAGMENT(static_runtime_tests, m) {
1416   m.def(torch::schema(
1417       "static_runtime_tests::variadic_outputs(Tensor a) -> ...",
1418       at::AliasAnalysisKind::PURE_FUNCTION));
1419 }
1420 } // namespace
1421 
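// %x and %y are never used after they are produced; storage assignment must
// still handle them.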
TEST(AssignStorageToManagedTensors,MultipleUnused)1422 TEST(AssignStorageToManagedTensors, MultipleUnused) {
1423   const auto src = R"IR(
1424     graph(%a : Tensor):
1425         %z : Tensor = aten::mul(%a, %a)
1426         %out: Tensor = aten::mul(%z, %z)
1427         %x : Tensor, %y : Tensor = static_runtime_tests::variadic_outputs(%a)
1428         return (%out)
1429   )IR";
1430   FastMap<std::string, at::Tensor> managed_tensor_name_to_tensor{
1431       {"z", at::randn({1})}, {"x", at::randn({1})}, {"y", at::randn({1})}};
1432   const size_t min_reused_tensors = 1;
1433   testAssignStorageToManagedTensors(
1434       src, std::move(managed_tensor_name_to_tensor), min_reused_tensors);
1435 }
1436 
1437 namespace {
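// Builds a StaticModule from the given script and expects the call with these
// args/kwargs to throw a c10::Error.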
testStaticModuleThrows(const std::string & src,const std::vector<IValue> & args,const std::unordered_map<std::string,IValue> & kwargs)1438 void testStaticModuleThrows(
1439     const std::string& src,
1440     const std::vector<IValue>& args,
1441     const std::unordered_map<std::string, IValue>& kwargs) {
1442   auto static_module = makeStaticModuleFromScript(src);
1443   EXPECT_THROW(static_module(args, kwargs), c10::Error);
1444 }
1445 } // namespace
1446 
TEST(StaticModule,IncorrectTypesPassed)1447 TEST(StaticModule, IncorrectTypesPassed) {
1448   const std::string args_bool_script = R"JIT(
1449     def forward(self, x: bool):
1450         return x
1451   )JIT";
1452   testStaticModuleThrows(args_bool_script, {at::randn({1})}, {});
1453 
1454   const std::string args_tensor_script = R"JIT(
1455     def forward(self, x: Tensor):
1456         return x
1457   )JIT";
1458   testStaticModuleThrows(args_tensor_script, {false}, {});
1459 
1460   const std::string kwargs_bool_script = R"JIT(
1461     def forward(self, x: bool = True):
1462         return x
1463   )JIT";
1464   testStaticModuleThrows(kwargs_bool_script, {}, {{"x", at::randn({1})}});
1465 
1466   const std::string kwargs_tensor_script = R"JIT(
1467     def forward(self, x: Tensor = torch.randn((1, ))):
1468         return x
1469   )JIT";
1470   testStaticModuleThrows(kwargs_tensor_script, {}, {{"x", 1.0}});
1471 }
1472 
TEST(StaticModule,TooManyArgs)1473 TEST(StaticModule, TooManyArgs) {
1474   const std::string args_src = R"JIT(
1475     def forward(self, x: int):
1476         return x
1477   )JIT";
1478   testStaticModuleThrows(args_src, {0, 1}, {});
1479 
1480   const std::string kwargs_src = R"JIT(
1481     def forward(self, x: int = 1):
1482         return x
1483   )JIT";
1484   testStaticModuleThrows(kwargs_src, {}, {{"y", 0}, {"x", 1}});
1485 }
1486 
TEST(StaticModule,NotEnoughArgs)1487 TEST(StaticModule, NotEnoughArgs) {
1488   const std::string args_src = R"JIT(
1489     def forward(self, x: int):
1490         return x
1491   )JIT";
1492   testStaticModuleThrows(args_src, {}, {});
1493 
1494   const std::string kwargs_src = R"JIT(
1495     def forward(self, *, x: int):
1496         return x
1497   )JIT";
1498   testStaticModuleThrows(kwargs_src, {}, {});
1499 }
1500 
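// A constant returned directly from the top-level graph should be wrapped in a
// static_runtime::create_owned_ref node.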
TEST(CreateOwnedRefsForSpecialValues,TopLevel)1501 TEST(CreateOwnedRefsForSpecialValues, TopLevel) {
1502   const auto src = R"IR(
1503     graph():
1504         %c: int = prim::Constant[value=42]()
1505         return (%c)
1506   )IR";
1507 
1508   auto graph = getGraphFromIR(src);
1509   CreateOwnedRefsForSpecialValues(*graph);
1510   EXPECT_TRUE(hasNodeWithKind(graph, "static_runtime::create_owned_ref"));
1511 }
1512 
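// A value defined in the outer scope but returned from both sub-blocks of the
// prim::If should also receive an owned ref.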
TEST(CreateOwnedRefsForSpecialValues,ValueFromOuterScope)1513 TEST(CreateOwnedRefsForSpecialValues, ValueFromOuterScope) {
1514   const auto src = R"IR(
1515     graph(%cond: bool, %1: int):
1516         %c: int = aten::add(%1, %1)
1517         %x: int = prim::If(%cond)
1518           block0():
1519             -> (%c)
1520           block1():
1521             -> (%c)
1522         return (%x)
1523   )IR";
1524 
1525   auto graph = getGraphFromIR(src);
1526   CreateOwnedRefsForSpecialValues(*graph);
1527   EXPECT_TRUE(hasNodeWithKind(graph, "static_runtime::create_owned_ref"));
1528 }
1529 
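// The prim::If and its sub-blocks start with no outputs; after the pass, each
// should have exactly one.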
TEST(ForceNonEmptyOutputs,TwoSubBlocks)1530 TEST(ForceNonEmptyOutputs, TwoSubBlocks) {
1531   const auto src = R"IR(
1532     graph(%cond: bool):
1533         %lst : int[] = prim::ListConstruct()
1534         %1 : int = prim::Constant[value=1]()
1535         %2 : int = prim::Constant[value=2]()
1536         prim::If(%cond)
1537           block0():
1538             aten::append(%lst, %1)
1539             -> ()
1540           block1():
1541             aten::append(%lst, %2)
1542             -> ()
1543         return (%lst)
1544   )IR";
1545 
1546   auto graph = getGraphFromIR(src);
1547   ForceNonEmptyOutputs(*graph);
1548 
1549   for (auto* node : graph->nodes()) {
1550     if (node->blocks().empty()) {
1551       continue;
1552     }
1553     EXPECT_EQ(node->outputs().size(), 1);
1554     for (auto* sub_block : node->blocks()) {
1555       EXPECT_EQ(sub_block->outputs().size(), 1);
1556     }
1557   }
1558 }
1559 
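// Summing over the last dimension of permute(x, (0, 2, 1)) equals summing x
// over dim 1, so the permute can be dropped and the sum dim rewritten.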
TEST(EliminateExtraPermuteOps,FusesSumCorrectly)1560 TEST(EliminateExtraPermuteOps, FusesSumCorrectly) {
1561   const auto src = R"JIT(
1562     def forward(self, x):
1563         y = torch.permute(x, (0, 2, 1))
1564         z = torch.sum(y, dim=-1)
1565         return z
1566   )JIT";
1567   torch::jit::Module mod("m");
1568   mod.define(src);
1569 
1570   auto graph = mod.get_method("forward").graph();
1571   // turn the ListConstruct(%constant) into proper constant lists
1572   ConstantPropagation(graph);
1573   EliminateExtraPermuteOps(graph);
1574 
1575   EXPECT_FALSE(hasNodeWithKind(graph, "aten::permute"));
1576   auto* sum = getNodeWithKind(graph, "aten::sum");
1577   ASSERT_NE(sum, nullptr);
1578   auto dim = toIValue(sum->input(1));
1579   ASSERT_TRUE(dim.has_value() && dim->isIntList());
1580   EXPECT_EQ(dim->toIntList(), c10::List<int64_t>{1});
1581 }
1582 
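// The sum here is over dim 1, not the permuted-to-last dimension, so the
// permute must be preserved.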
TEST(EliminateExtraPermuteOps,DoesNotFuseSumWrongDim)1583 TEST(EliminateExtraPermuteOps, DoesNotFuseSumWrongDim) {
1584   const auto src = R"JIT(
1585     def forward(self, x):
1586         y = torch.permute(x, (0, 2, 1))
1587         z = torch.sum(y, dim=1)
1588         return z
1589   )JIT";
1590   torch::jit::Module mod("m");
1591   mod.define(src);
1592 
1593   auto graph = mod.get_method("forward").graph();
1594   // turn the ListConstruct(%constant) into proper constant lists
1595   ConstantPropagation(graph);
1596   EliminateExtraPermuteOps(graph);
1597 
1598   EXPECT_TRUE(hasNodeWithKind(graph, "aten::permute"));
1599 }
1600 
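// The reduction dim is a runtime argument, so the pass cannot prove the
// permute is redundant.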
TEST(EliminateExtraPermuteOps,DoesNotFuseSumNonConstantDim)1601 TEST(EliminateExtraPermuteOps, DoesNotFuseSumNonConstantDim) {
1602   const auto src = R"JIT(
1603     def forward(self, x, dim: int):
1604         y = torch.permute(x, (0, 2, 1))
1605         z = torch.sum(y, dim=dim)
1606         return z
1607   )JIT";
1608   torch::jit::Module mod("m");
1609   mod.define(src);
1610 
1611   auto graph = mod.get_method("forward").graph();
1612   // turn the ListConstruct(%constant) into proper constant lists
1613   ConstantPropagation(graph);
1614   EliminateExtraPermuteOps(graph);
1615 
1616   EXPECT_TRUE(hasNodeWithKind(graph, "aten::permute"));
1617 }
1618 
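// permute -> softmax over the last dim -> inverse permute is equivalent to
// softmax over dim 1 of the original tensor, so both permutes should be removed.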
TEST(EliminateExtraPermuteOps,FusesSoftmaxCorrectly)1619 TEST(EliminateExtraPermuteOps, FusesSoftmaxCorrectly) {
1620   const auto src = R"JIT(
1621     def forward(self, x):
1622         a = torch.permute(x, [0, 2, 1])
1623         b = torch.softmax(a, 2)
1624         c = torch.permute(b, [0, 2, 1])
1625         return c.clone()
1626   )JIT";
1627   torch::jit::Module mod("m");
1628   mod.define(src);
1629   auto graph = mod.get_method("forward").graph();
1630   ConstantPropagation(graph);
1631   EliminateExtraPermuteOps(graph);
1632   graph->dump();
1633 
1634   EXPECT_FALSE(hasNodeWithKind(graph, "aten::permute"));
1635   auto* softmax = getNodeWithKind(graph, "aten::softmax");
1636   ASSERT_NE(softmax, nullptr);
1637   auto dim = toIValue(softmax->input(1));
1638   ASSERT_TRUE(dim.has_value() && dim->isInt());
1639   EXPECT_EQ(dim->toInt(), 1);
1640 
1641   std::vector<IValue> args{at::randn({3, 4, 5})};
1642   testStaticRuntime(src, args, /*args2=*/{}, /*use_allclose=*/true);
1643 }
1644 
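// [0, 1, 2] is the identity permutation and does not move the softmax
// dimension, so no fusion is expected.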
TEST(EliminateExtraPermuteOps,DoesNotFuseSoftmaxWrongPermuteDim)1645 TEST(EliminateExtraPermuteOps, DoesNotFuseSoftmaxWrongPermuteDim) {
1646   const auto src = R"JIT(
1647     def forward(self, x):
1648         a = torch.permute(x, [0, 1, 2])
1649         b = torch.softmax(a, 2)
1650         c = torch.permute(b, [0, 1, 2])
1651         return c.clone()
1652   )JIT";
1653   torch::jit::Module mod("m");
1654   mod.define(src);
1655   auto graph = mod.get_method("forward").graph();
1656   ConstantPropagation(graph);
1657   EliminateExtraPermuteOps(graph);
1658   EXPECT_TRUE(hasNodeWithKind(graph, "aten::permute"));
1659 }
1660 
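// softmax runs over dim 0, which the surrounding permutes do not touch, so no
// fusion is expected.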
TEST(EliminateExtraPermuteOps,DoesNotFuseSoftmaxWrongSoftmaxDim)1661 TEST(EliminateExtraPermuteOps, DoesNotFuseSoftmaxWrongSoftmaxDim) {
1662   const auto src = R"JIT(
1663     def forward(self, x):
1664         a = torch.permute(x, [0, 2, 1])
1665         b = torch.softmax(a, 0)
1666         c = torch.permute(b, [0, 2, 1])
1667         return c.clone()
1668   )JIT";
1669   torch::jit::Module mod("m");
1670   mod.define(src);
1671   auto graph = mod.get_method("forward").graph();
1672   ConstantPropagation(graph);
1673   EliminateExtraPermuteOps(graph);
1674   EXPECT_TRUE(hasNodeWithKind(graph, "aten::permute"));
1675 }
1676 
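// Splitting along a dim and then squeezing each chunk along that same dim
// should be collapsed into a single fused op.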
TEST(UseSplitAndSqueeze,Fusion)1677 TEST(UseSplitAndSqueeze, Fusion) {
1678   const auto src = R"IR(
1679     graph(%x: Tensor):
1680       %dim: int = prim::Constant[value=1]()
1681       %split_size: int = prim::Constant[value=1]()
1682       %split: Tensor[] = aten::split(%x, %split_size, %dim)
1683       %a: Tensor, %b: Tensor = prim::ListUnpack(%split)
1684       %c: Tensor = aten::squeeze(%a, %dim)
1685       %d: Tensor = aten::squeeze(%b, %dim)
1686       return (%c, %d)
1687   )IR";
1688   auto graph = getGraphFromIR(src);
1689   UseSplitAndSqueeze(graph);
1690   EXPECT_TRUE(
1691       hasNodeWithKind(graph, "static_runtime::fused_split_and_squeeze_copy"));
1692   EXPECT_FALSE(hasNodeWithKind(graph, "aten::split"));
1693   EXPECT_FALSE(hasNodeWithKind(graph, "aten::squeeze"));
1694   EXPECT_FALSE(hasNodeWithKind(graph, "prim::ListUnpack"));
1695 }
1696 
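// x[0:] copies the whole list, so the resulting aten::slice is a no-op and
// should be removed.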
TEST(EliminateNoOpSlice,IntegerStart)1697 TEST(EliminateNoOpSlice, IntegerStart) {
1698   const auto src = R"JIT(
1699     def forward(self, x: List[int]) -> List[int]:
1700         return x[0:]
1701   )JIT";
1702   torch::jit::Module mod("m");
1703   mod.define(src);
1704   auto graph = mod.get_method("forward").graph();
1705   EXPECT_TRUE(hasNodeWithKind(graph, "aten::slice"));
1706   EliminateNoOpSlice(graph);
1707   EXPECT_FALSE(hasNodeWithKind(graph, "aten::slice"));
1708 }
1709 
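// x[:] lowers to aten::slice with a None start; it should be eliminated as well.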
TEST(EliminateNoOpSlice,NoneStart)1710 TEST(EliminateNoOpSlice, NoneStart) {
1711   const auto src = R"JIT(
1712     def forward(self, x: List[int]) -> List[int]:
1713         return x[:]
1714   )JIT";
1715   torch::jit::Module mod("m");
1716   mod.define(src);
1717   auto graph = mod.get_method("forward").graph();
1718   EliminateNoOpSlice(graph);
1719   EXPECT_FALSE(hasNodeWithKind(graph, "aten::slice"));
1720 }
1721 
1722 #ifdef FBCODE_CAFFE2
1723 // FuseClampNaNToNum pass is disabled externally to avoid MSVC errors in CI
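// clamp with constant min and max followed by nan_to_num should be fused into
// a single static_runtime::clamp_nan_to_num op.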
TEST(FuseClampNaNToNum,FusionHappens)1724 TEST(FuseClampNaNToNum, FusionHappens) {
1725   const auto src = R"JIT(
1726     def forward(self, x):
1727         y = torch.clamp(x, min=0.0, max=1.0)
1728         z = y.nan_to_num()
1729         return z.clone()
1730   )JIT";
1731   torch::jit::Module mod("m");
1732   mod.define(src);
1733   auto graph = mod.get_method("forward").graph();
1734   FuseClampNaNToNum(graph);
1735   EXPECT_FALSE(hasNodeWithKind(graph, "aten::clamp"));
1736   EXPECT_FALSE(hasNodeWithKind(graph, "aten::nan_to_num"));
1737   EXPECT_TRUE(hasNodeWithKind(graph, "static_runtime::clamp_nan_to_num"));
1738   // Correctness of the op is exercised in StaticRuntime.clamp_nan_to_num
1739 }
1740 
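// Fusion requires both clamp bounds to be constants; runtime or missing bounds
// should leave the original ops in place.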
TEST(FuseClampNaNToNum,NoFusion)1741 TEST(FuseClampNaNToNum, NoFusion) {
1742   const auto src1 = R"JIT(
1743     def forward(self, x, a: float, b: float):
1744         y = torch.clamp(x, a, b)
1745         z = y.nan_to_num()
1746         return z.clone()
1747   )JIT";
1748 
1749   const auto src2 = R"JIT(
1750     def forward(self, x):
1751         y = torch.clamp(x, min=0.0)
1752         z = y.nan_to_num()
1753         return z.clone()
1754   )JIT";
1755 
1756   const auto src3 = R"JIT(
1757     def forward(self, x):
1758         y = torch.clamp(x, max=0.0)
1759         z = y.nan_to_num()
1760         return z.clone()
1761   )JIT";
1762 
1763   const auto src4 = R"JIT(
1764     def forward(self, x):
1765         y = torch.clamp(x)
1766         z = y.nan_to_num()
1767         return z.clone()
1768   )JIT";
1769 
1770 
1771   auto checkScript = [](const char* src) {
1772     torch::jit::Module mod("m");
1773     mod.define(src);
1774     auto graph = mod.get_method("forward").graph();
1775     FuseClampNaNToNum(graph);
1776     EXPECT_TRUE(hasNodeWithKind(graph, "aten::clamp"));
1777     EXPECT_TRUE(hasNodeWithKind(graph, "aten::nan_to_num"));
1778     EXPECT_FALSE(hasNodeWithKind(graph, "static_runtime::clamp_nan_to_num"));
1779   };
1780 
1781   checkScript(src1);
1782   checkScript(src2);
1783   checkScript(src3);
1784   checkScript(src4);
1785 }
1786 #endif
1787