1 #include <gtest/gtest.h>
2 #include <torch/csrc/jit/ir/alias_analysis.h>
3 #include <torch/csrc/jit/ir/irparser.h>
4 #include <torch/csrc/jit/runtime/static/ProcessedNodeInputs.h>
5 #include <torch/csrc/jit/runtime/static/fusion.h>
6 #include <torch/csrc/jit/runtime/static/impl.h>
7 #include <torch/csrc/jit/runtime/static/memory_planner.h>
8 #include <torch/csrc/jit/runtime/static/ops.h>
9 #include <torch/csrc/jit/runtime/static/passes.h>
10 #include <memory>
11
12 #include "deep_wide_pt.h"
13 #include "test_utils.h"
14
15 using namespace torch;
16 using namespace torch::jit;
17 using namespace torch::jit::test;
18
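// Note: this flag is presumably defined elsewhere in the static runtime
// sources; the StaticRuntime.BadSchemaAliasInfo test below toggles it to
// bypass the debug memory-overlap check.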
19 C10_DECLARE_bool(static_runtime_disable_debug_memory_overlap_check);
20
21 namespace {
22
StaticModule makeStaticModuleFromScript(const std::string& script) {
24 Module m("module");
25 m.define(script);
26 return StaticModule(m);
27 }
28
bool testCanEnableStaticRuntime(const std::string& jit_script) {
30 script::Module module("module");
31 module.define(jit_script);
32
33 Method method = module.get_method("forward");
34 auto graph = module.get_method("forward").graph();
35
  // Note that we do not freeze the graph here.
37 return canEnableStaticRuntime(graph);
38 }
39
bool testCanEnableStaticRuntimeWithIR(const std::string& ir) {
41 auto graph = std::make_shared<Graph>();
42 parseIR(ir, graph.get(), {});
43 return canEnableStaticRuntime(graph);
44 }
45
bool testModuleHasOp(const std::string& jit_script, const char* op_name) {
47 script::Module module("module");
48 module.define(jit_script);
49
50 return forwardHasOp(module, op_name);
51 }
52
53 const auto reshape_inplace_script = R"JIT(
54 def forward(self, inp: Tensor, shape: List[int]):
55 a = inp + inp
56 b = a.reshape(shape)
57 c = b.sigmoid_()
58 d = c + c
59 e = a + a
60 f = b + b
61 return (d, e, f)
62 )JIT";
63
64 const auto reshape_inplace_script_1 = R"JIT(
65 def forward(self, inp: Tensor, shape: List[int], flag: bool):
66 if flag:
67 a = inp + inp
68 b = a.reshape(shape)
69 c = b.sigmoid()
70 else:
71 a = inp * inp
72 b = a.sigmoid_()
73 c = b.reshape(shape)
74 d = c + c
75 e = a + a
76 f = b + b
77 return (d, e, f)
78 )JIT";
79
80 const auto sigmoid_inplace_script = R"JIT(
81 def forward(self, inp: Tensor):
82 a = torch.sigmoid(inp, out=inp).clone()
83 return (a)
84 )JIT";
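
// The scripts above are shared fixtures: their in-place ops (sigmoid_ and the
// out= form of sigmoid) are what the ModuleHasOp, CanEnableStaticRuntime,
// ReplaceWithCopy, and ProcessedNode tests below exercise.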
85
86 } // namespace
87
88 // Test that StaticModule::value_group groups values of the graph into
89 // 1) Inputs/Constants and their aliases 2) Outputs and their aliases.
TEST(StaticModule, ValueGroup) {
91 const std::string src = R"IR(
92 graph(%input0 : Tensor, %input1 : Tensor):
93 # Constants.
94 %0 : int = prim::Constant[value=1]()
95 # Internal values.
96 %1 : Tensor = aten::add(%input0, %input1, %0)
97 # This includes aliases of output.
98 %2 : Tensor = aten::add(%input0, %1, %0)
99 # This includes output.
100 %3 : (Tensor) = prim::TupleConstruct(%2)
101 return (%3)
102 )IR";
103 auto input_graph = std::make_shared<torch::jit::Graph>();
104 torch::jit::parseIR(src, input_graph.get());
105 torch::jit::StaticModule sm(input_graph);
106 const Graph& graph = sm.graph();
107 std::vector<const Node*> nodes(graph.nodes().begin(), graph.nodes().end());
108 auto* root_block = sm.root_block();
109 const auto& value_group = sm.block_info(root_block).value_group();
110
111 std::vector<const Value*> expected_input_aliases{
112 graph.inputs()[0], graph.inputs()[1], nodes[0]->output()};
113 for (auto* value : expected_input_aliases) {
114 EXPECT_TRUE(value_group.isExternalAlias(value));
115 }
116
117 std::vector<const Value*> expected_output_aliases{
118 graph.outputs()[0], nodes[2]->output()};
119 for (auto* value : expected_output_aliases) {
120 EXPECT_TRUE(value_group.isOutputAlias(value));
121 }
122 EXPECT_FALSE(value_group.isAlwaysAlive(nodes[1]->output()));
123 EXPECT_TRUE(value_group.isAlwaysAlive(graph.inputs()[0]));
124 EXPECT_TRUE(value_group.isAlwaysAlive(graph.inputs()[1]));
125 EXPECT_TRUE(value_group.isAlwaysAlive(graph.outputs()[0]));
126 }
127
TEST(StaticModule, IsOptimizableContainerType_NonOptimizableInputs) {
129 // Cannot use out variants for list/tuple construction here because
130 // inputs are not produced by nodes with out variants.
131 const std::string src = R"JIT(
132 def forward(self, a, b):
133 a_alias = a.view(a.size())
134 non_optimizable_list = [a_alias]
135 non_optimizable_tuple = (b, )
136 return non_optimizable_list, non_optimizable_tuple
137 )JIT";
138
139 auto sm = makeStaticModuleFromScript(src);
140 const auto& graph = sm.graph();
141 auto* root_block = sm.root_block();
142 const auto& block_info = sm.block_info(root_block);
143
144 for (const Node* n : graph.nodes()) {
145 EXPECT_FALSE(block_info.node_is_optimizable_container_type(n));
146 }
147 }
148
TEST(StaticModule, IsOptimizableContainerType_WrongType) {
150 // Cannot use out variants for list/tuple construction here because
151 // types are not Tensors
152 const std::string src = R"JIT(
153 def forward(self, x: int, y: int):
154 a = 1 + x
155 b = 2 + y
156 non_optimizable_list = [a]
157 non_optimizable_tuple = (b, )
158 return non_optimizable_list, non_optimizable_tuple
159 )JIT";
160
161 auto sm = makeStaticModuleFromScript(src);
162 const auto& graph = sm.graph();
163 auto* root_block = sm.root_block();
164 const auto& block_info = sm.block_info(root_block);
165
166 for (const Node* n : graph.nodes()) {
167 EXPECT_FALSE(block_info.node_is_optimizable_container_type(n));
168 }
169 }
170
TEST(StaticModule, IsOptimizableContainerType_CanUseOutVariant) {
  // This container should be optimizable since aten::relu has an
  // out variant and the container contains Tensors.
174 const std::string src = R"JIT(
175 def forward(self, x):
176 a = torch.relu(x)
177 optimizable_list = [a]
178 return optimizable_list
179 )JIT";
180 auto sm = makeStaticModuleFromScript(src);
181 const auto& graph = sm.graph();
182 auto* root_block = sm.root_block();
183 const auto& block_info = sm.block_info(root_block);
184
185 for (const Node* n : graph.nodes()) {
186 if (n->kind() == c10::prim::ListConstruct) {
187 EXPECT_TRUE(block_info.node_is_optimizable_container_type(n));
188 } else {
189 EXPECT_FALSE(block_info.node_is_optimizable_container_type(n));
190 }
191 }
192 }
193
194 // Test operator() with rvalue inputs
TEST(StaticModule, RValueInputs) {
196 const std::string src = R"JIT(
197 def forward(self, x):
198 y = torch.relu(x)
199 return y.clone()
200 )JIT";
201 auto sm = makeStaticModuleFromScript(src);
202
203 std::vector<IValue> input{at::randn({1})};
204
205 auto expected = sm(input, {});
206 auto actual = sm(std::move(input), {});
207
208 EXPECT_TRUE(expected.isTensor());
209 EXPECT_TRUE(actual.isTensor());
210 EXPECT_TRUE(expected.toTensor().equal(actual.toTensor()));
211 }
212
TEST(StaticRuntime, ModuleHasOp) {
214 EXPECT_TRUE(testModuleHasOp(reshape_inplace_script, "aten::sigmoid_"));
215 EXPECT_TRUE(testModuleHasOp(reshape_inplace_script_1, "aten::reshape"));
216 EXPECT_TRUE(testModuleHasOp(sigmoid_inplace_script, "aten::clone"));
217 EXPECT_FALSE(testModuleHasOp(reshape_inplace_script_1, "aten::add_"));
218 }
219
TEST(StaticRuntime, ReplaceWithCopy_replaces_reshape) {
221 auto ExpectToReplaceWithCopy = [](const std::string& jit_script) {
222 auto graph = getGraphFromScript(jit_script);
223 EXPECT_TRUE(graphHasOp(graph, "aten::reshape"));
224 EXPECT_FALSE(graphHasOp(graph, "static_runtime::reshape_copy"));
225
226 ReplaceWithCopy(graph);
227
228 // aten::reshape -> static_runtime::reshape_copy
229 EXPECT_FALSE(graphHasOp(graph, "aten::reshape"));
230 EXPECT_TRUE(graphHasOp(graph, "static_runtime::reshape_copy"));
231 };
232
233 ExpectToReplaceWithCopy(R"JIT(
234 def forward(self, inp: Tensor, shape: List[int]):
235 a = inp.reshape(shape)
236 return (a)
237 )JIT");
238 ExpectToReplaceWithCopy(R"JIT(
239 def forward(self, inp: Tensor, shape: List[int]):
240 a = inp * 2
241 b = inp * 3
242 c = inp.reshape(shape)
243 return (a, b, c)
244 )JIT");
245 ExpectToReplaceWithCopy(R"JIT(
246 def forward(self, cond: bool, x):
247 if cond:
248 y = x.reshape(x.shape)
249 else:
250 y = x.clone()
251 return y.clone()
252 )JIT");
253 }
254
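// ReplaceWithCopy must not fire when the reshape input or its output is
// mutated later in the graph: the copy would be taken too early and miss
// those writes.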
TEST(
    StaticRuntime,
    ReplaceWithCopy_does_not_replace_reshape_if_input_has_writers) {
258 auto ExpectNotToReplaceWithCopy = [](const std::string& jit_script) {
259 auto graph = getGraphFromScript(jit_script);
260 EXPECT_TRUE(graphHasOp(graph, "aten::reshape"));
261 EXPECT_FALSE(graphHasOp(graph, "static_runtime::reshape_copy"));
262
263 ReplaceWithCopy(graph);
264
265 // No Replacement
266 EXPECT_TRUE(graphHasOp(graph, "aten::reshape"));
267 EXPECT_FALSE(graphHasOp(graph, "static_runtime::reshape_copy"));
268 };
269
270 ExpectNotToReplaceWithCopy(R"JIT(
271 def forward(self, inp: Tensor, shape: List[int]):
272 a = inp.reshape(shape)
273 inp *= 2
274 return (a)
275 )JIT");
276 ExpectNotToReplaceWithCopy(R"JIT(
277 def forward(self, inp: Tensor, shape: List[int]):
278 a = inp.reshape(shape)
279 a *= 2
280 return (a)
281 )JIT");
282 ExpectNotToReplaceWithCopy(R"JIT(
283 def forward(self, inp: Tensor, inp2: Tensor, shape: List[int]):
284 a = inp.reshape(shape)
285 a *= 2
286 b = a.reshape(shape)
287 return (b)
288 )JIT");
289 ExpectNotToReplaceWithCopy(R"JIT(
290 def forward(self, inp: Tensor, shape: List[int]):
291 a = inp.reshape(shape)
292 b = a.reshape(shape)
293 c = b.reshape(shape)
294 d = c.reshape(shape)
295 e = b.sigmoid_()
296 return (d)
297 )JIT");
298 ExpectNotToReplaceWithCopy(reshape_inplace_script);
299 }
300
TEST(StaticRuntime, CanEnableStaticRuntime) {
302 const auto while_script = R"JIT(
303 def forward(self, a: Tensor, x: int):
304 c = 0
305 while c < x:
306 a = a * a
307 c += 2
308 return a
309 )JIT";
310
311 const auto for_script = R"JIT(
312 def forward(self, a: Tensor, x: int):
313 for c in range(x):
314 a = a * a
315 return a
316 )JIT";
317
318 const auto if_script = R"JIT(
319 def forward(self, a: Tensor, b: bool):
320 if b:
321 return a
322 else:
323 return a * a
324 )JIT";
325
326 const auto is_script_tensors = R"JIT(
327 def forward(self, a: Tensor, b: Tensor):
328 return a is b
329 )JIT";
330
331 const auto is_script_none = R"JIT(
332 def forward(self, a: Optional[Tensor]):
333 return a is None
334 )JIT";
335
336 const auto is_not_script_tensors = R"JIT(
337 def forward(self, a: Tensor, b: Tensor):
338 return a is not b
339 )JIT";
340
341 const auto is_not_script_none = R"JIT(
342 def forward(self, a: Optional[Tensor]):
343 return a is not None
344 )JIT";
345
346 EXPECT_TRUE(testCanEnableStaticRuntime(reshape_inplace_script));
347 EXPECT_TRUE(testCanEnableStaticRuntime(for_script));
348 EXPECT_TRUE(testCanEnableStaticRuntime(while_script));
349 EXPECT_TRUE(testCanEnableStaticRuntime(if_script));
350 EXPECT_FALSE(testCanEnableStaticRuntime(is_script_tensors));
351 EXPECT_TRUE(testCanEnableStaticRuntime(is_script_none));
352 EXPECT_FALSE(testCanEnableStaticRuntime(is_not_script_tensors));
353 EXPECT_TRUE(testCanEnableStaticRuntime(is_not_script_none));
}

TEST(StaticRuntime, CanEnableStaticRuntimeCallMethod) {
357 const auto call_method = R"IR(
358 graph(%x : Tensor):
359 %1 : Tensor = prim::CallMethod[name="offsets"](%x)
360 return (%1)
361 )IR";
362 EXPECT_FALSE(testCanEnableStaticRuntimeWithIR(call_method));
363 }
364
TEST(StaticRuntime, CanEnableStaticRuntimeSubBlocks) {
366 const auto src = R"JIT(
367 def forward(self, a: Tensor, b: Tensor, cond: bool):
368 if cond:
369 # aten::__is__ on tensors is blocked
370 return a is b
371 return False
372 )JIT";
373
374 EXPECT_FALSE(testCanEnableStaticRuntime(src));
375 }
376
TEST(StaticRuntime, NestedOutput) {
378 // dict of tuple of list
379 const auto nested_output_script_0 = R"JIT(
380 def forward(self, a, b):
381 c = (a + b).relu().nan_to_num().float()
382 d = a.flatten().nan_to_num() * b.flatten().nan_to_num()
383 e = d.float().relu()
384 f = ([c], [d])
385 g = ([e], [f])
386 return ({"prediction":(f, g)})
387 )JIT";
388
389 // tuple of lists
390 const auto nested_output_script_1 = R"JIT(
391 def forward(self, a, b):
392 c = (a + b).relu().nan_to_num().float()
393 d = a.flatten().nan_to_num() * b.flatten().nan_to_num()
394 e = d.float().relu()
395 f = [c]
396 g = [e]
397 return (f, g)
398 )JIT";
399
400 // list of tuple of dict
401 const auto nested_output_script_2 = R"JIT(
402 def forward(self, a, b):
403 c = (a + b).relu().nan_to_num().float()
404 d = b * c
405 e = a.flatten().nan_to_num() * b.flatten().nan_to_num()
406 f = e.float().relu()
407 g = ({"d": d}, {"b": b})
408 h = ({"e": e}, {"f": f})
409 return [g, h]
410 )JIT";
411
  // list of dicts
413 const auto nested_output_script_3 = R"JIT(
414 def forward(self, a, b):
415 c = (a + b).relu().nan_to_num().float()
416 d = b * c
417 e = a.flatten().nan_to_num() * b.flatten().nan_to_num()
418 f = e.float().relu()
419 g = {"d": d, "b": b}
420 h = {"e": e, "f": f}
421 return [g, h]
422 )JIT";
423
424 auto run_test = [&](std::vector<int64_t> shapes) {
425 auto a = at::randn(shapes);
426 auto b = at::randn(shapes);
427
428 std::vector<IValue> args{a, b};
429 testStaticRuntime(nested_output_script_0, args);
430 testStaticRuntime(nested_output_script_1, args);
431 testStaticRuntime(nested_output_script_2, args);
432 testStaticRuntime(nested_output_script_3, args);
433
434 if (shapes.size() > 0 && shapes[0] != 0) {
435 shapes[0] *= 3;
436 testStaticRuntime(
437 nested_output_script_0, args, {at::randn(shapes), at::randn(shapes)});
438 testStaticRuntime(
439 nested_output_script_1, args, {at::randn(shapes), at::randn(shapes)});
440 }
441 };
442 run_test({2, 3, 1, 2});
443 run_test({2, 6});
444 }
445
446 // test memory reuse
TEST(StaticRuntime, LongModel) {
448 torch::jit::Module mod = getLongScriptModel();
449 auto a = torch::randn({2, 2});
450 auto b = torch::randn({2, 2});
451 auto c = torch::randn({2, 2});
452
453 // run jit graph executor
454 std::vector<at::IValue> input_ivalues({a, b, c});
455 at::Tensor output_1 = mod.forward(input_ivalues).toTensor();
456
457 // run static runtime
458 std::vector<c10::IValue> input_tensors({a, b, c});
459 torch::jit::StaticModule smod(mod);
460 at::Tensor output_2 = smod(input_tensors, {}).toTensor();
461 smod.runtime().check_for_memory_leak();
462 EXPECT_TRUE(
463 torch::allclose(output_1, output_2, /*rtol=*/1e-5, /*atol=*/1e-7));
464 }
465
TEST(StaticRuntime, TrivialModel) {
467 torch::jit::Module mod = getTrivialScriptModel();
468 auto a = torch::randn({2, 2});
469 auto b = torch::randn({2, 2});
470 auto c = torch::randn({2, 2});
471
472 // run jit graph executor
473 std::vector<at::IValue> input_ivalues({a, b, c});
474 at::Tensor output_1 = mod.forward(input_ivalues).toTensor();
475
476 // run static runtime
477 std::vector<c10::IValue> input_tensors({a, b, c});
478 torch::jit::StaticModule smod(mod);
479 at::Tensor output_2 = smod(input_tensors, {}).toTensor();
480 smod.runtime().check_for_memory_leak();
481 EXPECT_TRUE(
482 torch::allclose(output_1, output_2, /*rtol=*/1e-5, /*atol=*/1e-7));
483 }
484
TEST(StaticRuntime, DeepWide) {
486 const int embedding_size = 32;
487 const int num_features = 50;
488 torch::jit::Module mod = getDeepAndWideSciptModel();
489 torch::jit::StaticModule smod(mod);
490
491 for (int batch_size : {1, 8, 32}) {
492 for (int i = 0; i < 2; ++i) {
493 auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
494 auto user_emb = torch::randn({batch_size, 1, embedding_size});
495 auto wide = torch::randn({batch_size, num_features});
496
497 // run jit graph executor
498 std::vector<at::IValue> inputs({ad_emb_packed, user_emb, wide});
499 auto output_1 = getTensor(mod.forward(inputs));
500
501 // run static runtime
502 std::vector<c10::IValue> input_tensors({ad_emb_packed, user_emb, wide});
503 auto outputs = smod(input_tensors, {}).toTupleRef().elements();
504 ASSERT_TRUE(outputs.size() > 0);
505 at::Tensor output_2 = outputs[0].toTensor();
506 smod.runtime().check_for_memory_leak();
507 EXPECT_TRUE(
508 torch::allclose(output_1, output_2, /*rtol=*/1e-5, /*atol=*/1e-5));
509 }
510 }
511 }
512
TEST(StaticRuntime, KWargsAPI_1) {
514 const int embedding_size = 32;
515 const int num_features = 50;
516 auto module = getDeepAndWideSciptModel();
517 torch::jit::StaticModule smod(module);
518
519 for (int batch_size : {1, 8, 32}) {
520 for (int i = 0; i < 2; ++i) {
521 auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
522 auto user_emb = torch::randn({batch_size, 1, embedding_size});
523 auto wide = torch::randn({batch_size, num_features});
524 {
525 std::vector<at::IValue> inputs({ad_emb_packed, user_emb, wide});
526
527 // run jit graph executor
528 at::Tensor output_1 = getTensor(module.forward(inputs));
529
530 // run static runtime
531 c10::IValue output_ivalue = smod(inputs, {});
532 smod.runtime().check_for_memory_leak();
533
534 at::Tensor output_2 = getTensor(output_ivalue);
535 EXPECT_TRUE(
536 torch::allclose(output_1, output_2, /*rtol=*/1e-5, /*atol=*/1e-5));
537
538 // check for output aliasing
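        // A use_count of 1 means the runtime did not keep an extra reference
        // to the returned IValue; once it is reset, the extracted tensor
        // should likewise be the sole owner of its storage.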
539 EXPECT_EQ(output_ivalue.use_count(), 1);
540 output_ivalue = IValue();
541
542 EXPECT_EQ(output_2.getIntrusivePtr().use_count(), 1);
543 }
544
545 // check for input aliasing (deep & wide does not have ops
546 // that create aliases of input tensors)
547 EXPECT_EQ(ad_emb_packed.getIntrusivePtr().use_count(), 1);
548 EXPECT_EQ(user_emb.getIntrusivePtr().use_count(), 1);
549 EXPECT_EQ(wide.getIntrusivePtr().use_count(), 1);
550 }
551 }
552 }
553
TEST(StaticRuntime, KWargsAPI_2) {
555 const int embedding_size = 32;
556 const int num_features = 50;
557 auto module = getDeepAndWideSciptModel();
558 torch::jit::StaticModule smod(module);
559
560 for (int batch_size : {1, 8, 32}) {
561 for (int i = 0; i < 2; ++i) {
562 auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
563 auto user_emb = torch::randn({batch_size, 1, embedding_size});
564 auto wide = torch::randn({batch_size, num_features});
565 {
566 // run jit graph executor
567 std::vector<at::IValue> args({ad_emb_packed, user_emb, wide});
568 at::Tensor output_1 = getTensor(module.forward(args));
569
570 std::unordered_map<std::string, c10::IValue> kwargs(
571 {{"ad_emb_packed", ad_emb_packed},
572 {"user_emb", user_emb},
573 {"wide", wide}});
574
575 // run static runtime
576 c10::IValue output_ivalue = smod(std::vector<IValue>{}, kwargs);
577 smod.runtime().check_for_memory_leak();
578
579 at::Tensor output_2 = getTensor(output_ivalue);
580 EXPECT_TRUE(
581 torch::allclose(output_1, output_2, /*rtol=*/1e-5, /*atol=*/1e-5));
582
583 // check for output aliasing
584 EXPECT_EQ(output_ivalue.use_count(), 1);
585 output_ivalue = IValue();
586
587 EXPECT_EQ(output_2.getIntrusivePtr().use_count(), 1);
588 }
589
590 EXPECT_EQ(ad_emb_packed.getIntrusivePtr().use_count(), 1);
591 EXPECT_EQ(user_emb.getIntrusivePtr().use_count(), 1);
592 EXPECT_EQ(wide.getIntrusivePtr().use_count(), 1);
593 }
594 }
595 }
596
TEST(StaticRuntime, KWargsAPI_Optional) {
598 const auto src = R"JIT(
599 def forward(self, x, y, z: Optional[Tensor] = None):
600 return x + y
601 )JIT";
602
603 torch::jit::Module mod("mod");
604 mod.define(src);
605 torch::jit::StaticModule smod(mod);
606 const auto kwargs = std::unordered_map<std::string, IValue>{
607 {"x", at::randn({1})}, {"y", at::randn({1})}};
608
609 auto expected = mod.forward({}, kwargs).toTensor();
610 auto actual = smod({}, kwargs).toTensor();
611
612 EXPECT_TRUE(expected.equal(actual));
613 }
614
TEST(StaticRuntime, CleanUpMemory) {
616 const int embedding_size = 32;
617 const int num_features = 50;
618 torch::jit::Module mod = getDeepAndWideSciptModel();
619
620 for (auto enable_out_variant : {true, false}) {
621 for (auto optimize_memory : {true, false}) {
622 for (auto manage_output_tensors : {true, false}) {
623 if (manage_output_tensors && !enable_out_variant) {
624 // when manage_output_tensors is enabled, enable_out_variant
625 // must be enabled too
626 continue;
627 }
628 if (optimize_memory && !enable_out_variant) {
629 // when optimize_memory is enabled, enable_out_variant must be
630 // enabled too
631 continue;
632 }
633 VLOG(1) << "enable_out_variant: " << enable_out_variant
634 << ", optimize_memory: " << optimize_memory
635 << ", manage_output_tensors: " << manage_output_tensors;
636 torch::jit::StaticModuleOptions opts{
637 enable_out_variant, optimize_memory, manage_output_tensors};
638 torch::jit::StaticModule smod(mod, false, opts);
639 torch::jit::StaticRuntime runtime(smod);
640
641 for (int batch_size : {1, 8, 32}) {
642 for (int i = 0; i < 2; ++i) {
643 auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
644 auto user_emb = torch::randn({batch_size, 1, embedding_size});
645 auto wide = torch::randn({batch_size, num_features});
646
647 // run jit graph executor
648 std::vector<at::IValue> inputs({ad_emb_packed, user_emb, wide});
649 auto output_1 = getTensor(mod.forward(inputs));
650
651 // run static runtime
652 std::vector<c10::IValue> input_tensors(
653 {ad_emb_packed, user_emb, wide});
654 auto outputs = runtime(input_tensors, {}).toTupleRef().elements();
655 ASSERT_TRUE(outputs.size() > 0);
656 auto output_2 = outputs[0].toTensor();
657 runtime.check_for_memory_leak();
658 EXPECT_TRUE(torch::allclose(
659 output_1, output_2, /*rtol=*/1e-5, /*atol=*/1e-5));
660 if (manage_output_tensors) {
661 runtime.deallocateOutputTensors();
662 runtime.checkOutputTensorMemoryLeaks();
663 }
664 }
665 }
666 }
667 }
668 }
669 }
670
TEST(StaticRuntime, ManageOutputTensors) {
672 const std::string test_graph = R"IR(
673 graph(%0 : Tensor):
674 # With manage_output_tensor enabled, this tensor is managed.
675 %1 : Tensor = aten::abs(%0)
676 # The output container object is never managed.
677 %2 : (Tensor) = prim::TupleConstruct(%1)
678 return (%2)
679 )IR";
680 auto a = at::randn({2, 2});
681 auto b = at::randn({3, 6});
682 std::vector<at::IValue> args{a};
683 std::vector<at::IValue> args2{b};
684 testStaticRuntime(test_graph, args);
685 testStaticRuntime(test_graph, args, args2);
686 }
687
TEST(
    StaticRuntime,
    ManageOutputTensorsReturnsOutputContainingManagedOutputTensor) {
691 const std::string test_graph = R"IR(
692 graph(%0 : Tensor):
693 # With manage_output_tensor enabled, this tensor is managed.
694 %1 : Tensor = aten::abs(%0)
695 # The output container object is never managed.
696 %2 : (Tensor) = prim::TupleConstruct(%1)
697 return (%2)
698 )IR";
699 auto g = std::make_shared<torch::jit::Graph>();
700 torch::jit::parseIR(test_graph, g.get());
701 torch::jit::StaticModuleOptions opts{
702 /*enable_out_variant=*/true,
703 /*optimize_memory=*/true,
704 /*manage_output_tensors=*/true};
705 auto a = at::randn({2, 2});
706 std::vector<at::IValue> args{a};
707 torch::jit::StaticModule smod(g, opts);
708 torch::jit::StaticRuntime runtime(smod);
709 // Profile run.
710 {
711 IValue tuple = runtime(args, {});
712 ASSERT_TRUE(tuple.isTuple());
713 ASSERT_EQ(tuple.toTupleRef().elements().size(), 1);
714 // Do not manage input value.
715 EXPECT_FALSE(runtime.isManagedOutputTensor(args[0]));
716 // Do not manage direct output value.
717 EXPECT_FALSE(runtime.isManagedOutputTensor(tuple));
718 IValue element = tuple.toTupleRef().elements()[0];
719 // Tensor to be managed, but not yet from the profile run.
720 EXPECT_FALSE(runtime.isManagedOutputTensor(element));
721 tuple = IValue();
722 runtime.deallocateOutputTensors();
723 runtime.checkOutputTensorMemoryLeaks();
724 }
725 // Second run that manages output tensors.
726 {
727 IValue tuple = runtime(args, {});
728 ASSERT_TRUE(tuple.isTuple());
729 ASSERT_EQ(tuple.toTupleRef().elements().size(), 1);
730 // Do not manage input value.
731 EXPECT_FALSE(runtime.isManagedOutputTensor(args[0]));
732 // Do not manage direct output value.
733 EXPECT_FALSE(runtime.isManagedOutputTensor(tuple));
734 IValue element = tuple.toTupleRef().elements()[0];
    // This tensor should now be managed, since the profile run has completed.
736 EXPECT_TRUE(runtime.isManagedOutputTensor(element));
737 tuple = IValue();
738 runtime.deallocateOutputTensors();
739 runtime.checkOutputTensorMemoryLeaks();
740 }
741 }
742
TEST(StaticRuntime, ManageOutputTensorsWithDeallocateOutputTensors) {
744 const int embedding_size = 32;
745 const int num_features = 50;
746 torch::jit::Module mod = getDeepAndWideSciptModel();
747
748 torch::jit::StaticModuleOptions opts{
749 /*enable_out_variant=*/true,
750 /*optimize_memory=*/true,
751 /*manage_output_tensors=*/true};
752 torch::jit::StaticModule smod(mod, false, opts);
753 torch::jit::StaticRuntime runtime(smod);
  // Re-enter the runtime with inputs of the same and of different shapes.
755 for (int batch_size : {8, 8, 24, 8}) {
756 auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
757 auto user_emb = torch::randn({batch_size, 1, embedding_size});
758 auto wide = torch::randn({batch_size, num_features});
759 std::vector<c10::IValue> input_tensors({ad_emb_packed, user_emb, wide});
760 runtime(input_tensors, {});
761 runtime.check_for_memory_leak();
762 runtime.deallocateOutputTensors();
763 runtime.checkOutputTensorMemoryLeaks();
764 }
765 }
766
TEST(StaticRuntime, ManageOutputTensorsWithoutDeallocateOutputTensors) {
768 const int embedding_size = 32;
769 const int num_features = 50;
770 torch::jit::Module mod = getDeepAndWideSciptModel();
771
772 torch::jit::StaticModuleOptions opts{
773 /*enable_out_variant=*/true,
774 /*optimize_memory=*/true,
775 /*manage_output_tensors=*/true};
776 torch::jit::StaticModule smod(mod, false, opts);
777 torch::jit::StaticRuntime runtime(smod);
778 int batch_size = 8;
779 auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
780 auto user_emb = torch::randn({batch_size, 1, embedding_size});
781 auto wide = torch::randn({batch_size, num_features});
782 std::vector<c10::IValue> input_tensors({ad_emb_packed, user_emb, wide});
783 // Profile run.
784 runtime(input_tensors, {});
785 runtime.deallocateOutputTensors();
786 // Run again to allocate output Tensors without deallocating them.
787 runtime(input_tensors, {});
788 // Memory leak checking fails.
789 EXPECT_THROW(runtime.checkOutputTensorMemoryLeaks(), std::exception);
790 // Calling the runtime without deallocation fails too.
791 EXPECT_THROW(runtime(input_tensors, {}), std::exception);
792 // After deallocation, everything works fine.
793 runtime.deallocateOutputTensors();
794 runtime.checkOutputTensorMemoryLeaks();
795 runtime(input_tensors, {});
796 }
797
TEST(StaticRuntime, DisableManageOutputTensors) {
799 const std::string test_graph = R"IR(
800 graph(%0 : Tensor):
801 # With manage_output_tensor enabled, this tensor is managed.
802 %1 : Tensor = aten::abs(%0)
803 # The output container object is never managed.
804 %2 : (Tensor) = prim::TupleConstruct(%1)
805 return (%2)
806 )IR";
807 auto g = std::make_shared<torch::jit::Graph>();
808 torch::jit::parseIR(test_graph, g.get());
809 torch::jit::StaticModuleOptions opts{
810 /*enable_out_variant=*/true,
811 /*optimize_memory=*/true,
812 /*manage_output_tensors=*/true};
813 auto a = at::randn({2, 2});
814 std::vector<at::IValue> args{a};
815 torch::jit::StaticModule smod(g, opts);
816 torch::jit::StaticRuntime runtime(smod);
817 // Profile run.
818 {
819 IValue tuple = runtime(args, {});
820 IValue element = tuple.toTupleRef().elements()[0];
821 EXPECT_FALSE(runtime.isManagedOutputTensor(element));
822 tuple = IValue();
823 runtime.deallocateOutputTensors();
824 runtime.checkOutputTensorMemoryLeaks();
825 }
826 // Second run that manages output tensors.
827 {
828 IValue tuple = runtime(args, {});
829 IValue element = tuple.toTupleRef().elements()[0];
830 EXPECT_TRUE(runtime.isManagedOutputTensor(element));
831 tuple = IValue();
832 runtime.deallocateOutputTensors();
833 runtime.checkOutputTensorMemoryLeaks();
834 }
835
836 // Reset the runtime and start profiling again.
837 runtime.disableManageOutputTensors();
838
839 IValue copied_output_tensor;
840 IValue original_output_tensor;
841 // New profile run.
842 {
843 IValue tuple = runtime(args, {});
844 IValue element = tuple.toTupleRef().elements()[0];
845 EXPECT_FALSE(runtime.isManagedOutputTensor(element));
846 copied_output_tensor = element.deepcopy();
847 original_output_tensor = element;
848 tuple = IValue();
849 // No-op since manage_output_tensor is disabled now.
850 runtime.deallocateOutputTensors();
851 runtime.checkOutputTensorMemoryLeaks();
852 }
853 // Ensure that `original_output_tensor` is no longer managed: even after
854 // calling `runtime.deallocateOutputTensors();` `original_output_tensor` still
855 // contains a valid value.
856 EXPECT_TRUE(
857 original_output_tensor.toTensor().equal(copied_output_tensor.toTensor()));
858
859 // Ensure that the second optimized run does not manage the output tensor
860 // either.
861 {
862 IValue tuple = runtime(args, {});
863 IValue element = tuple.toTupleRef().elements()[0];
864 EXPECT_FALSE(runtime.isManagedOutputTensor(element));
865 copied_output_tensor = element.deepcopy();
866 original_output_tensor = element;
867 tuple = IValue();
868 // No-op since manage_output_tensor is disabled now.
869 runtime.deallocateOutputTensors();
870 runtime.checkOutputTensorMemoryLeaks();
871 }
872 // Ensure that `original_output_tensor` is no longer managed: even after
873 // calling `runtime.deallocateOutputTensors();` `original_output_tensor` still
874 // contains a valid value.
875 EXPECT_TRUE(
876 original_output_tensor.toTensor().equal(copied_output_tensor.toTensor()));
877 }
878
TEST(StaticRuntime, FusionPass) {
880 const int embedding_size = 32;
881 const int num_features = 50;
882 for (int batch_size : {1, 8, 32}) {
883 for (int i = 0; i < 2; ++i) {
884 torch::jit::Module module = getDeepAndWideSciptModel();
885 auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
886 auto user_emb = torch::randn({batch_size, 1, embedding_size});
887 auto wide = torch::randn({batch_size, num_features});
888
889 // run jit graph executor
890 std::vector<at::IValue> inputs({ad_emb_packed, user_emb, wide});
891 auto output_1 = getTensor(module.forward(inputs));
892
893 Method method = module.get_method("forward");
894 auto graph = method.graph();
895 fuseStaticSubgraphs(graph, 2);
896 bool hit = false;
897 for (const auto& n : module.get_method("forward").graph()->nodes()) {
898 if (n->kind() == torch::jit::prim::StaticSubgraph) {
899 hit = true;
900 }
901 }
902 EXPECT_TRUE(hit);
903 auto output_2 = getTensor(module.forward(inputs));
904 EXPECT_TRUE(
905 torch::allclose(output_1, output_2, /*rtol=*/1e-5, /*atol=*/1e-5));
906 }
907 }
908 }
909
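// Test helper: packs the given indices into a ProcessedNodeInputs.
// ProcessedNode resolves its inputs by indexing the IValue array passed to
// its constructor with these indices.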
static ProcessedNodeInputs createProcessedNodeInputs(
    c10::ArrayRef<uint16_t> inputs) {
912 ProcessedNodeInputs result(inputs.size());
913 for (const auto idx : c10::irange(inputs.size())) {
914 result[idx] = inputs[idx];
915 }
916 return result;
917 }
918
TEST(
    ProcessedNode,
    VerifyNoMemoryOverlapWithImmutableInputsWithImmutableArguments) {
922 const auto sigmoid_script = R"JIT(
923 def forward(self, inp: Tensor):
924 b = torch.sigmoid(inp).clone()
925 return (b)
926 )JIT";
927 script::Module module("module");
928 // Not using out= variant.
929 module.define(sigmoid_script);
930 torch::jit::StaticModule smodule(module);
931 Node* sigmoid_node = getNodeWithKind(smodule, "aten::sigmoid");
932 std::array<IValue, 2> values = {torch::randn({2, 3}), torch::randn({3, 1})};
933 ProcessedFunction fn(
934 sigmoid_node,
935 /*enable_out_variant=*/true,
936 /*check_memory_overlap=*/false);
937 StaticNodeInfo static_node_info(
938 sigmoid_node, &fn, createProcessedNodeInputs({0}), 1);
939 ProcessedNode pnode(static_node_info, values.data());
940 EXPECT_TRUE(pnode.verify_no_memory_overlap(/* force_check*/ true));
941
942 pnode.Output(0) = values[0];
943 EXPECT_FALSE(pnode.verify_no_memory_overlap(/* force_check*/ true));
944 }
945
TEST(ProcessedNode, VerifyNoMemoryOverlapWithImmutableInputsWithInplaceOps) {
947 script::Module module("module");
948 // Using out= variant.
949 module.define(sigmoid_inplace_script);
950 torch::jit::StaticModule smodule(module);
951 Node* sigmoid_node = getNodeWithKind(smodule, "aten::sigmoid");
952 std::array<IValue, 2> values = {torch::randn({2, 3}), torch::randn({3, 1})};
953 ProcessedFunction fn(
954 sigmoid_node,
955 /*enable_out_variant=*/true,
956 /*check_memory_overlap=*/false);
957 StaticNodeInfo static_node_info(
958 sigmoid_node, &fn, createProcessedNodeInputs({0}), 1);
959 ProcessedNode pnode(static_node_info, values.data());
960
961 ASSERT_EQ(&pnode.Output(0), &values[1]);
962 EXPECT_TRUE(pnode.verify_no_memory_overlap());
963
964 pnode.Output(0) = values[0];
965 EXPECT_TRUE(pnode.verify_no_memory_overlap());
966 }
967
TEST(ProcessedNode, VerifyNoMemoryOverlapWithOverlappingOutputs) {
969 auto g = std::make_shared<torch::jit::Graph>();
970 torch::jit::parseIR(
971 R"IR(
972 graph(%0):
973 %1 : Tensor, %2 : Tensor = prim::ListUnpack(%0)
974 return (%1, %2))IR",
975 g.get());
976 torch::jit::StaticModule smodule(g);
977 Node* list_unpack_node = getNodeWithKind(smodule, "prim::ListUnpack");
978 {
979 std::array<IValue, 3> values = {
980 at::randn({2, 3}), at::empty({1, 3}), at::empty({4, 5})};
981 ProcessedFunction fn(
982 list_unpack_node,
983 /*enable_out_variant=*/true,
984 /*check_memory_overlap */ false);
985 StaticNodeInfo list_unpack_static_node_info(
986 list_unpack_node, &fn, createProcessedNodeInputs({0}), 1);
987 ProcessedNode list_unpack_pnode(
988 list_unpack_static_node_info, values.data());
989 ASSERT_EQ(list_unpack_pnode.outputs().size(), 2);
990 EXPECT_TRUE(
991 list_unpack_pnode.verify_no_memory_overlap(/* force_check*/ true));
992 }
993 {
994 std::array<IValue, 3> values = {
995 at::randn({2, 3}), at::empty({1, 3}), at::empty({4, 5})};
996 ProcessedFunction fn(
997 list_unpack_node,
998 /*enable_out_variant=*/true,
999 /*check_memory_overlap */ false);
1000 StaticNodeInfo list_unpack_static_node_info(
1001 list_unpack_node, &fn, createProcessedNodeInputs({0}), 1);
1002 ProcessedNode list_unpack_pnode(
1003 list_unpack_static_node_info, values.data());
1004 auto b = at::randn({2, 3});
1005 list_unpack_pnode.Output(0) = b;
1006 list_unpack_pnode.Output(1) = b;
1007 EXPECT_FALSE(
1008 list_unpack_pnode.verify_no_memory_overlap(/* force_check*/ true));
1009 }
1010 }
1011
1012 namespace test {
at::Tensor bad_add(const at::Tensor& self, int64_t b) {
1014 if (b == 0) {
1015 return self;
1016 }
1017 return at::native::add(self, b);
1018 }
1019
at::Tensor good_add(const at::Tensor& self, int64_t b) {
1021 if (b == 0) {
1022 return self;
1023 }
1024 return at::native::add(self, b);
1025 }
1026 } // namespace test
1027
1028 // test::bad_add has the schema with incorrect alias annotation.
1029 // test::good_add has the correct alias annotation.
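// Both kernels return `self` unchanged when b == 0, so the output can alias
// the input. bad_add's schema fails to declare that aliasing, while good_add's
// "Tensor(a) ... -> Tensor(a)" annotation declares it; the two tests below
// probe how the runtime copes with each.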
TORCH_LIBRARY_FRAGMENT(test, m) {
1031 m.def("bad_add(Tensor self, int b=0) -> Tensor");
1032 m.def("good_add(Tensor(a) self, int b=0) -> Tensor(a)");
1033 }
TORCH_LIBRARY_IMPL(test, CPU, m) {
1035 m.impl("bad_add", ::test::bad_add);
1036 m.impl("good_add", ::test::good_add);
1037 }
1038
TEST(StaticRuntime, BadSchemaAliasInfo) {
1040 FLAGS_static_runtime_disable_debug_memory_overlap_check = true;
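  // The debug memory-overlap check would flag the aliasing that bad_add
  // introduces behind the planner's back, so it is disabled here and
  // restored at the end of the test.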
1041 const std::string src = R"IR(
1042 graph(%x: Tensor, %s: int):
1043 %c0 : int = prim::Constant[value=0]()
1044 %c1 : int = prim::Constant[value=1]()
1045 %a = aten::add(%x, %x, %c1)
1046 %b1 = test::bad_add(%a, %s) # b1 aliases a
1047 %t : (Tensor) = prim::TupleConstruct(%b1)
1048 return (%t)
1049 )IR";
1050
1051 const auto x1 = at::randn({2, 2});
1052 // big enough to trigger resize of the internal buffer
1053 const auto x2 = at::randn({3, 6});
1054 testStaticRuntime(src, {x1, 0}, {x2, 10});
1055 // This test doesn't pass yet. This is the corner case mentioned in Step 2 of
1056 // [Check and correct bad schema alias info at runtime]
1057 // testStaticRuntime(src, {x1, 10}, {x2, 0});
1058 FLAGS_static_runtime_disable_debug_memory_overlap_check = false;
1059 }
1060
1061 // This test repeats the last test, but with the correct schema alias
1062 // annotations
TEST(StaticRuntime, GoodSchemaAliasInfo) {
  // Commenting out the prim::TupleConstruct below reproduces the failure of
  // DCHECK(!isManagedOutputTensor(*outputs_[0]));
1066 const std::string src = R"IR(
1067 graph(%x: Tensor, %s: int):
1068 %c0 : int = prim::Constant[value=0]()
1069 %c1 : int = prim::Constant[value=1]()
1070 %a = aten::add(%x, %x, %c1)
1071 %b1 = test::good_add(%a, %s) # b1 aliases a
1072 # return (%b1)
1073 %t : (Tensor) = prim::TupleConstruct(%b1)
1074 return (%t)
1075 )IR";
1076
1077 const auto x1 = at::randn({2, 2});
1078 // big enough to trigger resize of the internal buffer
1079 const auto x2 = at::randn({3, 6});
1080 testStaticRuntime(src, {x1, 0}, {x2, 10});
1081 testStaticRuntime(src, {x1, 10}, {x2, 0});
1082 }
1083
TEST(ProcessedFunction, ProcessedFunction) {
1085 const auto script = R"JIT(
1086 def forward(self, inp: Tensor):
1087 b = torch.sigmoid(inp).clone()
1088 c = torch.transpose(b, 0, 1)
1089 return (c)
1090 )JIT";
1091 script::Module module("module");
1092 module.define(script);
1093 torch::jit::StaticModule smodule(module);
1094
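  // aten::sigmoid has a static runtime out variant, so its ProcessedFunction
  // should report kOutVariant; aten::transpose only produces a view and is
  // expected to run as a native (non-out-variant) function instead.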
1095 Node* sigmoid_node = getNodeWithKind(smodule, "aten::sigmoid");
1096 ProcessedFunction sigmoid_fn(
1097 sigmoid_node,
1098 /*enable_out_variant=*/true,
1099 /*check_memory_overlap=*/false);
1100 EXPECT_EQ(sigmoid_fn.kind(), ProcessedFunction::Kind::kOutVariant);
1101 EXPECT_FALSE(sigmoid_fn.checkMemoryOverlap());
1102
1103 Node* transpose_node = getNodeWithKind(smodule, "aten::transpose");
1104 ProcessedFunction transpose_fn(
1105 transpose_node,
1106 /*enable_out_variant=*/true,
1107 /*check_memory_overlap=*/false);
1108 EXPECT_EQ(transpose_fn.kind(), ProcessedFunction::Kind::kNativeFunction);
1109 EXPECT_FALSE(transpose_fn.checkMemoryOverlap());
1110 }
1111
TEST(ManagedTensorRanges, NoAliases) {
1113 const std::string src = R"IR(
1114 graph(%x : Tensor):
1115 %y : Tensor = aten::mul(%x, %x)
1116 %z : Tensor = aten::mul(%y, %x)
1117 %output : Tensor = aten::mul(%z, %z)
1118 return (%output)
1119 )IR";
1120 auto graph = std::make_shared<Graph>();
1121 std::unordered_map<std::string, Value*> vmap;
1122 parseIR(src, graph.get(), vmap);
1123
1124 auto* y = vmap["y"];
1125 auto* z = vmap["z"];
1126
1127 FastSet<const Value*> managed_tensors = {y, z};
1128 AliasDb alias_db(graph);
1129 auto ranges = ManagedTensorRanges(*graph->block(), alias_db, managed_tensors);
1130
1131 std::vector<Node*> nodes(
1132 graph->block()->nodes().begin(), graph->block()->nodes().end());
1133 ASSERT_EQ(nodes.size(), 3);
1134
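  // %y is last used by nodes[1] and %z by nodes[2], so each managed tensor
  // should become available for reuse immediately after the node that last
  // uses it.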
1135 EXPECT_FALSE(ranges.nodeFreesManagedTensors(nodes[0]));
1136
1137 EXPECT_TRUE(ranges.nodeFreesManagedTensors(nodes[1]));
1138 EXPECT_EQ(
1139 ranges.availableTensorValuesAfterNode(nodes[1]),
1140 std::vector<const Value*>{y});
1141
1142 EXPECT_TRUE(ranges.nodeFreesManagedTensors(nodes[2]));
1143 EXPECT_EQ(
1144 ranges.availableTensorValuesAfterNode(nodes[2]),
1145 std::vector<const Value*>{z});
1146 }
1147
TEST(ManagedTensorRanges, AliasExtendingLifetimes) {
1149 const std::string src = R"IR(
1150 graph(%x : Tensor):
1151 %y : Tensor = aten::mul(%x, %x)
1152 %y_size : int[] = aten::size(%y)
1153 %z1 : Tensor = aten::mul(%y, %y)
1154 %y_alias : Tensor = aten::view(%y, %y_size)
1155 %z2 : Tensor = aten::mul(%y_alias, %y_alias)
1156 %output : Tensor = aten::mul(%z1, %z2)
1157 return (%output)
1158 )IR";
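  // %y_alias is a view of %y, so %y must stay alive until the alias's last
  // use in nodes[4] (%z2); the expectations below check that %y is only
  // freed after that node.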
1159 auto graph = std::make_shared<Graph>();
1160 std::unordered_map<std::string, Value*> vmap;
1161 parseIR(src, graph.get(), vmap);
1162
1163 auto* y = vmap["y"];
1164 auto* z1 = vmap["z1"];
1165 auto* z2 = vmap["z2"];
1166
1167 FastSet<const Value*> managed_tensors = {y, z1, z2};
1168 AliasDb alias_db(graph);
1169 auto ranges = ManagedTensorRanges(*graph->block(), alias_db, managed_tensors);
1170
1171 std::vector<Node*> nodes(
1172 graph->block()->nodes().begin(), graph->block()->nodes().end());
1173 ASSERT_EQ(nodes.size(), 6);
1174
1175 for (const auto i : c10::irange(4)) {
1176 EXPECT_FALSE(ranges.nodeFreesManagedTensors(nodes[i]));
1177 }
1178
1179 EXPECT_TRUE(ranges.nodeFreesManagedTensors(nodes[4]));
1180 EXPECT_EQ(
1181 ranges.availableTensorValuesAfterNode(nodes[4]),
1182 std::vector<const Value*>{y});
1183
1184 EXPECT_TRUE(ranges.nodeFreesManagedTensors(nodes[5]));
1185 const auto& available_after_5 =
1186 ranges.availableTensorValuesAfterNode(nodes[5]);
1187 // We don't care about the order, so convert to set. But make sure
1188 // there are no duplicates.
1189 FastSet<const Value*> available_after_5_set(
1190 available_after_5.begin(), available_after_5.end());
1191 EXPECT_EQ(available_after_5_set.size(), available_after_5.size());
1192 EXPECT_EQ(available_after_5_set, FastSet<const Value*>({z1, z2}));
1193 }
1194
TEST(ManagedTensorRanges, LifetimeOverlap) {
1196 const std::string src = R"IR(
1197 graph(%a : Tensor):
1198 %b : Tensor = aten::mul(%a, %a)
1199 %c : Tensor = aten::mul(%b, %b)
1200 %c_size : int[] = aten::size(%c)
1201 %c_alias : Tensor = aten::view(%c, %c_size)
1202 %d : Tensor = aten::mul(%a, %a)
1203 %e : Tensor = aten::mul(%c_alias, %c_alias)
1204 %output : Tensor = aten::mul(%e, %e)
1205 return (%output)
1206 )IR";
1207 auto graph = std::make_shared<Graph>();
1208 std::unordered_map<std::string, Value*> vmap;
1209 parseIR(src, graph.get(), vmap);
1210 auto* b = vmap["b"];
1211 auto* c = vmap["c"];
1212 auto* d = vmap["d"];
1213 auto* e = vmap["e"];
1214
1215 AliasDb alias_db(graph);
1216 auto ranges = ManagedTensorRanges(*graph->block(), alias_db, {b, c, d, e});
1217 const std::vector<std::pair<Value*, Value*>> overlapping_values{
1218 {b, c}, {c, d}, {c, e}};
1219
1220 const std::vector<std::pair<Value*, Value*>> disjoint_values{{b, d}, {b, e}};
1221
1222 for (const auto& values : overlapping_values) {
1223 EXPECT_TRUE(ranges.lifetimesOverlap(values.first, values.second));
1224 EXPECT_TRUE(ranges.lifetimesOverlap(values.second, values.first));
1225 }
1226 for (const auto& values : disjoint_values) {
1227 EXPECT_FALSE(ranges.lifetimesOverlap(values.first, values.second));
1228 EXPECT_FALSE(ranges.lifetimesOverlap(values.second, values.first));
1229 }
1230 }
1231
TEST(ManagedTensorRanges, OverlappingLifetimesContainers) {
1233 const std::string src = R"IR(
1234 graph(%a : Tensor):
1235 %b : Tensor = aten::mul(%a, %a)
1236 %c : Tensor = aten::mul(%b, %b)
1237 %tuple : (Tensor, Tensor) = prim::TupleConstruct(%b, %c)
1238 %b_alias : Tensor, %c_alias : Tensor = prim::TupleUnpack(%tuple)
1239 %d : Tensor = aten::mul(%b_alias, %c_alias)
1240 %output : Tensor = aten::mul(%d, %d)
1241 return (%output)
1242 )IR";
1243 auto graph = std::make_shared<Graph>();
1244 std::unordered_map<std::string, Value*> vmap;
1245 parseIR(src, graph.get(), vmap);
1246 auto* b = vmap["b"];
1247 auto* c = vmap["c"];
1248 auto* d = vmap["d"];
1249
1250 AliasDb alias_db(graph);
1251 auto ranges = ManagedTensorRanges(*graph->block(), alias_db, {b, c, d});
1252
1253 EXPECT_TRUE(ranges.lifetimesOverlap(b, c));
1254 EXPECT_TRUE(ranges.lifetimesOverlap(b, d));
1255 EXPECT_TRUE(ranges.lifetimesOverlap(c, d));
1256 }
1257
TEST(ManagedTensorRanges, OverlappingLifetimesOutputs) {
1259 const std::string src = R"IR(
1260 graph(%a : Tensor):
1261 %output : Tensor = aten::mul(%a, %a)
1262 %b : Tensor = aten::mul(%a, %a)
1263 return (%output)
1264 )IR";
1265 auto graph = std::make_shared<Graph>();
1266 std::unordered_map<std::string, Value*> vmap;
1267 parseIR(src, graph.get(), vmap);
1268 auto* b = vmap["b"];
1269 auto* output = vmap["output"];
1270
1271 AliasDb alias_db(graph);
1272 auto ranges = ManagedTensorRanges(*graph->block(), alias_db, {b, output});
1273
1274 EXPECT_TRUE(ranges.lifetimesOverlap(b, output));
1275 }
1276
1277 namespace {
1278
// For checking the correctness of assignStorageToManagedTensors, the following
1280 // conditions must hold
1281 // 1. All managed tensors are assigned to some storage group, and a tensor
1282 // may not be assigned to more than 1 storage group.
1283 // 2. Managed tensors with overlapping lifetimes should not be in the same
1284 // storage group.
1285 // 3. The number of reused tensors is >= min_reused_tensors.
void checkStorageGroups(
    const std::vector<StorageGroup>& storage_groups,
    const ManagedTensorRanges& ranges,
    const FastMap<const Value*, at::Tensor*>& tensor_value_to_tensor,
    size_t min_reused_tensors) {
1291 // Some extra bookkeeping; construct the set of managed Tensor* and
1292 // invert the tensor_value_to_tensor map. StorageGroup stores
1293 // Tensor*, so this will make everything a little easier.
1294 FastMap<at::Tensor*, const Value*> tensor_to_tensor_value;
1295 FastSet<at::Tensor*> managed_tensors;
1296 for (auto& key_value : tensor_value_to_tensor) {
1297 ASSERT_EQ(
1298 tensor_to_tensor_value.find(key_value.second),
1299 tensor_to_tensor_value.end());
1300 tensor_to_tensor_value.emplace(key_value.second, key_value.first);
1301 managed_tensors.insert(key_value.second);
1302 }
1303
1304 // Condition (1)
1305 FastSet<at::Tensor*> actual_assigned_tensors;
1306 for (const auto& storage_group : storage_groups) {
1307 for (auto* tensor : storage_group.group()) {
1308 ASSERT_EQ(
1309 actual_assigned_tensors.find(tensor), actual_assigned_tensors.end());
1310 actual_assigned_tensors.insert(tensor);
1311 }
1312 }
1313 ASSERT_EQ(actual_assigned_tensors, managed_tensors);
1314
1315 // Condition (2)
1316 size_t num_reused = 0;
1317 for (const auto& storage_group : storage_groups) {
1318 const auto& group = storage_group.group();
1319 num_reused += group.size() - 1;
1320 for (const auto i : c10::irange(group.size() - 1)) {
1321 for (const auto j : c10::irange(i + 1, group.size())) {
1322 const auto* v1 = tensor_to_tensor_value.at(group[i]);
1323 const auto* v2 = tensor_to_tensor_value.at(group[j]);
1324 EXPECT_FALSE(ranges.lifetimesOverlap(v1, v2));
1325 }
1326 }
1327 }
1328
1329 // Condition (3)
1330 EXPECT_GE(num_reused, min_reused_tensors);
1331 }
1332
1333 // A convenience function for testing assignStorageToManagedTensors. It
1334 // takes in an IR graph as well as a map from managed tensor name to tensor
1335 // value. It constructs all of the necessary data structures, invokes
// assignStorageToManagedTensors, and verifies correctness with
1337 // checkStorageGroups.
void testAssignStorageToManagedTensors(
    const std::string& src,
    FastMap<std::string, at::Tensor> managed_tensor_name_to_tensor,
    size_t min_reused_tensors) {
1342 auto graph = std::make_shared<Graph>();
1343 std::unordered_map<std::string, Value*> vmap;
1344 parseIR(src, graph.get(), vmap);
1345
1346 FastSet<const Value*> managed_tensor_values;
1347 FastMap<const Value*, at::Tensor*> tensor_value_to_tensor;
1348
1349 for (auto& key_value : managed_tensor_name_to_tensor) {
1350 const auto& tensor_name = key_value.first;
1351 auto vmap_it = vmap.find(tensor_name);
1352 ASSERT_TRUE(vmap_it != vmap.end());
1353 managed_tensor_values.insert(vmap_it->second);
1354 tensor_value_to_tensor.emplace(vmap_it->second, &key_value.second);
1355 }
1356 ASSERT_EQ(managed_tensor_values.size(), tensor_value_to_tensor.size());
1357
1358 AliasDb alias_db(graph);
1359 auto ranges =
1360 ManagedTensorRanges(*graph->block(), alias_db, managed_tensor_values);
1361 auto groups = assignStorageToManagedTensors(
1362 graph->block()->nodes(), ranges, tensor_value_to_tensor);
1363
1364 checkStorageGroups(
1365 groups, ranges, tensor_value_to_tensor, min_reused_tensors);
1366 }
1367
1368 } // namespace
1369
TEST(AssignStorageToManagedTensors, NoAliases) {
1371 const auto src = R"IR(
1372 graph(%a : Tensor):
1373 %b : Tensor = aten::mul(%a, %a)
1374 %c : Tensor = aten::mul(%b, %b)
1375 %d : Tensor = aten::mul(%c, %c)
1376 %e : Tensor = aten::mul(%b, %d)
1377 %output : Tensor = aten::mul(%e, %e)
1378 return (%output)
1379 )IR";
1380 FastMap<std::string, at::Tensor> managed_tensor_name_to_tensor{
1381 {"b", at::randn({1})},
1382 {"c", at::randn({1})},
1383 {"d", at::randn({1})},
1384 {"e", at::randn({1})}};
1385 const size_t min_reused_tensors = 1;
1386 testAssignStorageToManagedTensors(
1387 src, std::move(managed_tensor_name_to_tensor), min_reused_tensors);
1388 }
1389
TEST(AssignStorageToManagedTensors, Aliases) {
1391 const auto src = R"IR(
1392 graph(%a : Tensor):
1393 %b : Tensor = aten::mul(%a, %a)
1394 %c : Tensor = aten::mul(%b, %b)
1395 %d : Tensor = aten::mul(%c, %c)
1396 %c_size : int[] = aten::size(%c)
1397 %c_alias : Tensor = aten::view(%c, %c_size)
1398 %e : Tensor = aten::mul(%b, %d)
1399 %f : Tensor = aten::mul(%c_alias, %c_alias)
1400 %output : Tensor = aten::mul(%e, %f)
1401 return (%output)
1402 )IR";
1403 FastMap<std::string, at::Tensor> managed_tensor_name_to_tensor{
1404 {"b", at::randn({1})},
1405 {"c", at::randn({1})},
1406 {"d", at::randn({1})},
1407 {"e", at::randn({1})},
1408 {"f", at::randn({1})}};
1409 const size_t min_reused_tensors = 1;
1410 testAssignStorageToManagedTensors(
1411 src, std::move(managed_tensor_name_to_tensor), min_reused_tensors);
1412 }
1413
1414 namespace {
TORCH_LIBRARY_FRAGMENT(static_runtime_tests, m) {
1416 m.def(torch::schema(
1417 "static_runtime_tests::variadic_outputs(Tensor a) -> ...",
1418 at::AliasAnalysisKind::PURE_FUNCTION));
1419 }
1420 } // namespace
1421
TEST(AssignStorageToManagedTensors, MultipleUnused) {
1423 const auto src = R"IR(
1424 graph(%a : Tensor):
1425 %z : Tensor = aten::mul(%a, %a)
1426 %out: Tensor = aten::mul(%z, %z)
1427 %x : Tensor, %y : Tensor = static_runtime_tests::variadic_outputs(%a)
1428 return (%out)
1429 )IR";
1430 FastMap<std::string, at::Tensor> managed_tensor_name_to_tensor{
1431 {"z", at::randn({1})}, {"x", at::randn({1})}, {"y", at::randn({1})}};
1432 const size_t min_reused_tensors = 1;
1433 testAssignStorageToManagedTensors(
1434 src, std::move(managed_tensor_name_to_tensor), min_reused_tensors);
1435 }
1436
1437 namespace {
void testStaticModuleThrows(
    const std::string& src,
    const std::vector<IValue>& args,
    const std::unordered_map<std::string, IValue>& kwargs) {
1442 auto static_module = makeStaticModuleFromScript(src);
1443 EXPECT_THROW(static_module(args, kwargs), c10::Error);
1444 }
1445 } // namespace
1446
TEST(StaticModule, IncorrectTypesPassed) {
1448 const std::string args_bool_script = R"JIT(
1449 def forward(self, x: bool):
1450 return x
1451 )JIT";
1452 testStaticModuleThrows(args_bool_script, {at::randn({1})}, {});
1453
1454 const std::string args_tensor_script = R"JIT(
1455 def forward(self, x: Tensor):
1456 return x
1457 )JIT";
1458 testStaticModuleThrows(args_tensor_script, {false}, {});
1459
1460 const std::string kwargs_int_script = R"JIT(
1461 def forward(self, x: bool = True):
1462 return x
1463 )JIT";
1464 testStaticModuleThrows(kwargs_int_script, {}, {{"x", at::randn({1})}});
1465
1466 const std::string kwargs_tensor_script = R"JIT(
1467 def forward(self, x: Tensor = torch.randn((1, ))):
1468 return x
1469 )JIT";
1470 testStaticModuleThrows(kwargs_tensor_script, {}, {{"x", 1.0}});
1471 }
1472
TEST(StaticModule, TooManyArgs) {
1474 const std::string args_src = R"JIT(
1475 def forward(self, x: int):
1476 return x
1477 )JIT";
1478 testStaticModuleThrows(args_src, {0, 1}, {});
1479
1480 const std::string kwargs_src = R"JIT(
1481 def forward(self, x: int = 1):
1482 return x
1483 )JIT";
1484 testStaticModuleThrows(kwargs_src, {}, {{"y", 0}, {"x", 1}});
1485 }
1486
TEST(StaticModule, NotEnoughArgs) {
1488 const std::string args_src = R"JIT(
1489 def forward(self, x: int):
1490 return x
1491 )JIT";
1492 testStaticModuleThrows(args_src, {}, {});
1493
1494 const std::string kwargs_src = R"JIT(
1495 def forward(self, *, x: int):
1496 return x
1497 )JIT";
1498 testStaticModuleThrows(kwargs_src, {}, {});
1499 }
1500
TEST(CreateOwnedRefsForSpecialValues, TopLevel) {
1502 const auto src = R"IR(
1503 graph():
1504 %c: int = prim::Constant[value=42]()
1505 return (%c)
1506 )IR";
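  // The graph returns a constant directly; the pass is expected to wrap such
  // special values in static_runtime::create_owned_ref so the block returns
  // a value it owns.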
1507
1508 auto graph = getGraphFromIR(src);
1509 CreateOwnedRefsForSpecialValues(*graph);
1510 EXPECT_TRUE(hasNodeWithKind(graph, "static_runtime::create_owned_ref"));
1511 }
1512
TEST(CreateOwnedRefsForSpecialValues, ValueFromOuterScope) {
1514 const auto src = R"IR(
1515 graph(%cond: bool, %1: int):
1516 %c: int = aten::add(%1, %1)
1517 %x: int = prim::If(%c)
1518 block0():
1519 -> (%c)
1520 block1():
1521 -> (%c)
1522 return (%x)
1523 )IR";
1524
1525 auto graph = getGraphFromIR(src);
1526 CreateOwnedRefsForSpecialValues(*graph);
1527 EXPECT_TRUE(hasNodeWithKind(graph, "static_runtime::create_owned_ref"));
1528 }
1529
TEST(ForceNonEmptyOutputs, TwoSubBlocks) {
1531 const auto src = R"IR(
1532 graph(%cond: bool):
1533 %lst : int[] = prim::ListConstruct()
1534 %1 : int = prim::Constant[value=1]()
1535 %2 : int = prim::Constant[value=2]()
1536 prim::If(%cond)
1537 block0():
1538 aten::append(%lst, %1)
1539 -> ()
1540 block1():
1541 aten::append(%lst, %2)
1542 -> ()
1543 return (%lst)
1544 )IR";
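  // The prim::If sub-blocks above have no outputs; the pass should give the
  // node and each of its sub-blocks exactly one output, which the loop below
  // verifies.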
1545
1546 auto graph = getGraphFromIR(src);
1547 ForceNonEmptyOutputs(*graph);
1548
1549 for (auto* node : graph->nodes()) {
1550 if (node->blocks().empty()) {
1551 continue;
1552 }
1553 EXPECT_EQ(node->outputs().size(), 1);
1554 for (auto* sub_block : node->blocks()) {
1555 EXPECT_EQ(sub_block->outputs().size(), 1);
1556 }
1557 }
1558 }
1559
TEST(EliminateExtraPermuteOps, FusesSumCorrectly) {
1561 const auto src = R"JIT(
1562 def forward(self, x):
1563 y = torch.permute(x, (0, 2, 1))
1564 z = torch.sum(y, dim=-1)
1565 return z
1566 )JIT";
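  // Summing over the last dimension of a (0, 2, 1)-permuted tensor is the
  // same as summing the original tensor over dim 1, so the permute can be
  // removed and the sum's dim list rewritten to [1].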
1567 torch::jit::Module mod("m");
1568 mod.define(src);
1569
1570 auto graph = mod.get_method("forward").graph();
1571 // turn the ListConstruct(%constant) into proper constant lists
1572 ConstantPropagation(graph);
1573 EliminateExtraPermuteOps(graph);
1574
1575 EXPECT_FALSE(hasNodeWithKind(graph, "aten::permute"));
1576 auto* sum = getNodeWithKind(graph, "aten::sum");
1577 ASSERT_NE(sum, nullptr);
1578 auto dim = toIValue(sum->input(1));
1579 ASSERT_TRUE(dim.has_value() && dim->isIntList());
1580 EXPECT_EQ(dim->toIntList(), c10::List<int64_t>{1});
1581 }
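
// Why dim 1: permuting with (0, 2, 1) swaps the last two axes, so summing the
// permuted tensor over dim -1 (its last axis) equals summing the original
// tensor over dim 1. The pass drops the permute and rewrites the reduction
// dims accordingly; the resulting graph should look roughly like
//
//   graph(%x):
//       %dims: int[] = prim::Constant[value=[1]]()
//       ...
//       %z: Tensor = aten::sum(%x, %dims, ...)
//       return (%z)
//
// (a sketch; only the absence of aten::permute and the rewritten dims are
// actually asserted above).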

TEST(EliminateExtraPermuteOps, DoesNotFuseSumWrongDim) {
  const auto src = R"JIT(
    def forward(self, x):
        y = torch.permute(x, (0, 2, 1))
        z = torch.sum(y, dim=1)
        return z
  )JIT";
  torch::jit::Module mod("m");
  mod.define(src);

  auto graph = mod.get_method("forward").graph();
  // ConstantPropagation turns the ListConstruct of constant ints into a
  // proper constant list so the pass can read the permute dims.
  ConstantPropagation(graph);
  EliminateExtraPermuteOps(graph);

  EXPECT_TRUE(hasNodeWithKind(graph, "aten::permute"));
}

TEST(EliminateExtraPermuteOps, DoesNotFuseSumNonConstantDim) {
  const auto src = R"JIT(
    def forward(self, x, dim: int):
        y = torch.permute(x, (0, 2, 1))
        z = torch.sum(y, dim=dim)
        return z
  )JIT";
  torch::jit::Module mod("m");
  mod.define(src);

  auto graph = mod.get_method("forward").graph();
  // ConstantPropagation turns the ListConstruct of constant ints into a
  // proper constant list so the pass can read the permute dims.
  ConstantPropagation(graph);
  EliminateExtraPermuteOps(graph);

  EXPECT_TRUE(hasNodeWithKind(graph, "aten::permute"));
}
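
// The two negative cases above pin down that the permute+sum rewrite is
// deliberately narrow: it only fires when the reduction dim is a compile-time
// constant that matches the pattern it recognizes; anything else must leave
// the permute in the graph.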

TEST(EliminateExtraPermuteOps, FusesSoftmaxCorrectly) {
  const auto src = R"JIT(
    def forward(self, x):
        a = torch.permute(x, [0, 2, 1])
        b = torch.softmax(a, 2)
        c = torch.permute(b, [0, 2, 1])
        return c.clone()
  )JIT";
  torch::jit::Module mod("m");
  mod.define(src);
  auto graph = mod.get_method("forward").graph();
  ConstantPropagation(graph);
  EliminateExtraPermuteOps(graph);
  graph->dump();

  EXPECT_FALSE(hasNodeWithKind(graph, "aten::permute"));
  auto* softmax = getNodeWithKind(graph, "aten::softmax");
  ASSERT_NE(softmax, nullptr);
  auto dim = toIValue(softmax->input(1));
  ASSERT_TRUE(dim.has_value() && dim->isInt());
  EXPECT_EQ(dim->toInt(), 1);

  std::vector<IValue> args{at::randn({3, 4, 5})};
  testStaticRuntime(src, args, /*args2=*/{}, /*use_allclose=*/true);
}
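
// The permutation [0, 2, 1] is its own inverse, so permute -> softmax(dim=2)
// -> permute is equivalent to softmax over dim 1 of the original tensor. The
// pass removes both permutes and rewrites the softmax dim; on top of the
// structural checks, the testStaticRuntime call compares against the JIT
// interpreter numerically (allclose, since softmax is floating point).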

TEST(EliminateExtraPermuteOps, DoesNotFuseSoftmaxWrongPermuteDim) {
  const auto src = R"JIT(
    def forward(self, x):
        a = torch.permute(x, [0, 1, 2])
        b = torch.softmax(a, 2)
        c = torch.permute(b, [0, 1, 2])
        return c.clone()
  )JIT";
  torch::jit::Module mod("m");
  mod.define(src);
  auto graph = mod.get_method("forward").graph();
  ConstantPropagation(graph);
  EliminateExtraPermuteOps(graph);
  EXPECT_TRUE(hasNodeWithKind(graph, "aten::permute"));
}

TEST(EliminateExtraPermuteOps, DoesNotFuseSoftmaxWrongSoftmaxDim) {
  const auto src = R"JIT(
    def forward(self, x):
        a = torch.permute(x, [0, 2, 1])
        b = torch.softmax(a, 0)
        c = torch.permute(b, [0, 2, 1])
        return c.clone()
  )JIT";
  torch::jit::Module mod("m");
  mod.define(src);
  auto graph = mod.get_method("forward").graph();
  ConstantPropagation(graph);
  EliminateExtraPermuteOps(graph);
  EXPECT_TRUE(hasNodeWithKind(graph, "aten::permute"));
}
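
// As with sum, the softmax rewrite is pattern-based: an identity permutation
// and a softmax dim that is not the one moved by the permutation both fall
// outside the recognized pattern, so the permutes must survive in these two
// graphs.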

TEST(UseSplitAndSqueeze, Fusion) {
  const auto src = R"IR(
    graph(%x: Tensor):
        %dim: int = prim::Constant[value=1]()
        %split_size: int = prim::Constant[value=1]()
        %split: Tensor[] = aten::split(%x, %split_size, %dim)
        %a: Tensor, %b: Tensor = prim::ListUnpack(%split)
        %c: Tensor = aten::squeeze(%a, %dim)
        %d: Tensor = aten::squeeze(%b, %dim)
        return (%c, %d)
  )IR";
  auto graph = getGraphFromIR(src);
  UseSplitAndSqueeze(graph);
  EXPECT_TRUE(
      hasNodeWithKind(graph, "static_runtime::fused_split_and_squeeze_copy"));
  EXPECT_FALSE(hasNodeWithKind(graph, "aten::split"));
  EXPECT_FALSE(hasNodeWithKind(graph, "aten::squeeze"));
  EXPECT_FALSE(hasNodeWithKind(graph, "prim::ListUnpack"));
}
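
// Splitting along a dim with split_size 1 and then squeezing that dim on each
// chunk is a common pattern; the pass collapses split + ListUnpack + squeezes
// into one fused op. The assertions only check which node kinds remain, but
// the rewritten graph should look roughly like
//
//   graph(%x: Tensor):
//       ...
//       %c: Tensor, %d: Tensor =
//           static_runtime::fused_split_and_squeeze_copy(%x, %split_size, %dim)
//       return (%c, %d)
//
// (a sketch; the fused op's exact signature is not asserted here).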

TEST(EliminateNoOpSlice, IntegerStart) {
  const auto src = R"JIT(
    def forward(self, x: List[int]) -> List[int]:
        return x[0:]
  )JIT";
  torch::jit::Module mod("m");
  mod.define(src);
  auto graph = mod.get_method("forward").graph();
  EXPECT_TRUE(hasNodeWithKind(graph, "aten::slice"));
  EliminateNoOpSlice(graph);
  EXPECT_FALSE(hasNodeWithKind(graph, "aten::slice"));
}

TEST(EliminateNoOpSlice, NoneStart) {
  const auto src = R"JIT(
    def forward(self, x: List[int]) -> List[int]:
        return x[:]
  )JIT";
  torch::jit::Module mod("m");
  mod.define(src);
  auto graph = mod.get_method("forward").graph();
  EliminateNoOpSlice(graph);
  EXPECT_FALSE(hasNodeWithKind(graph, "aten::slice"));
}

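
// Both x[0:] and x[:] lower to an aten::slice that copies the whole list,
// which the pass deletes. Below is a hedged sketch of an additional negative
// case (hypothetical, not in the original suite): it assumes the pass only
// removes slices it can prove are no-ops, so a slice that actually drops
// elements should be left alone.
TEST(EliminateNoOpSlice, NonTrivialSliceIsKeptSketch) {
  const auto src = R"JIT(
    def forward(self, x: List[int]) -> List[int]:
        return x[1:]
  )JIT";
  torch::jit::Module mod("m");
  mod.define(src);
  auto graph = mod.get_method("forward").graph();
  EliminateNoOpSlice(graph);
  // x[1:] drops the first element, so this slice is not a no-op and is
  // expected to survive the pass (hypothetical expectation).
  EXPECT_TRUE(hasNodeWithKind(graph, "aten::slice"));
}
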
#ifdef FBCODE_CAFFE2
// The FuseClampNaNToNum pass is disabled outside of fbcode to avoid MSVC
// errors in CI, so these tests only run in the internal build.
TEST(FuseClampNaNToNum, FusionHappens) {
  const auto src = R"JIT(
    def forward(self, x):
        y = torch.clamp(x, min=0.0, max=1.0)
        z = y.nan_to_num()
        return z.clone()
  )JIT";
  torch::jit::Module mod("m");
  mod.define(src);
  auto graph = mod.get_method("forward").graph();
  FuseClampNaNToNum(graph);
  EXPECT_FALSE(hasNodeWithKind(graph, "aten::clamp"));
  EXPECT_FALSE(hasNodeWithKind(graph, "aten::nan_to_num"));
  EXPECT_TRUE(hasNodeWithKind(graph, "static_runtime::clamp_nan_to_num"));
  // Correctness of the op is exercised in StaticRuntime.clamp_nan_to_num
}
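
// The NoFusion cases below show that the rewrite only fires when clamp has
// both a min and a max given as compile-time constants; runtime bounds or a
// missing bound leave clamp and nan_to_num untouched, presumably because the
// fused kernel needs to know the clamp range up front.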

TEST(FuseClampNaNToNum, NoFusion) {
  const auto src1 = R"JIT(
    def forward(self, x, a: float, b: float):
        y = torch.clamp(x, a, b)
        z = y.nan_to_num()
        return z.clone()
  )JIT";

  const auto src2 = R"JIT(
    def forward(self, x):
        y = torch.clamp(x, min=0.0)
        z = y.nan_to_num()
        return z.clone()
  )JIT";

  const auto src3 = R"JIT(
    def forward(self, x):
        y = torch.clamp(x, max=0.0)
        z = y.nan_to_num()
        return z.clone()
  )JIT";

  const auto src4 = R"JIT(
    def forward(self, x):
        y = torch.clamp(x)
        z = y.nan_to_num()
        return z.clone()
  )JIT";

  auto checkScript = [](const char* src) {
    torch::jit::Module mod("m");
    mod.define(src);
    auto graph = mod.get_method("forward").graph();
    FuseClampNaNToNum(graph);
    EXPECT_TRUE(hasNodeWithKind(graph, "aten::clamp"));
    EXPECT_TRUE(hasNodeWithKind(graph, "aten::nan_to_num"));
    EXPECT_FALSE(hasNodeWithKind(graph, "static_runtime::clamp_nan_to_num"));
  };

  checkScript(src1);
  checkScript(src2);
  checkScript(src3);
  checkScript(src4);
}
#endif