1# Owner(s): ["oncall: jit"] 2 3import torch 4from torch._C import parse_ir 5from torch.testing._internal.common_utils import TemporaryFileName 6from torch.testing._internal.jit_utils import JitTestCase 7 8 9if __name__ == "__main__": 10 raise RuntimeError( 11 "This test file is not meant to be run directly, use:\n\n" 12 "\tpython test/test_jit.py TESTNAME\n\n" 13 "instead." 14 ) 15 16 17class TestAliasAnalysis(JitTestCase): 18 def test_becomes_wildcard_annotations(self): 19 graph_str = """ 20 graph(%a.1 : Tensor, %b.1 : Tensor): 21 %11 : NoneType = prim::Constant() 22 %8 : int = prim::Constant[value=0]() 23 %7 : int = prim::Constant[value=1]() 24 %x.1 : Tensor = aten::add(%a.1, %b.1, %7) 25 %y.1 : Tensor[] = aten::split(%x.1, %7, %8) 26 return () 27 """ 28 graph = parse_ir(graph_str) 29 alias_db = graph.alias_db() 30 split_node = graph.findNode("aten::split") 31 # split input enters wildcard set, list initalized as containing wildcard set 32 self.assertTrue( 33 alias_db.may_contain_alias(next(split_node.inputs()), split_node.output()) 34 ) 35 # because %x.1 enters wildcard set, it now aliases other members of wildcard set (graph inputs) 36 self.assertTrue( 37 alias_db.may_contain_alias(next(split_node.inputs()), next(graph.inputs())) 38 ) 39 40 def test_nested_list_construct_not_wildcard(self): 41 @torch.jit.script 42 def foo(x): 43 y = torch.rand([2, 2]) 44 return [y] 45 46 graph = foo.graph 47 graph.alias_db() 48 alias_db = graph.alias_db() 49 ten_construct = graph.findNode("aten::rand").output() 50 output = next(graph.outputs()) 51 self.assertTrue(alias_db.may_contain_alias(ten_construct, output)) 52 self.assertFalse( 53 alias_db.may_contain_alias(next(graph.inputs()), ten_construct) 54 ) 55 56 def test_recursive_calls(self): 57 @torch.jit.script 58 def foo(x, y): 59 x.add_(1) 60 return x + y 61 62 @torch.jit.script 63 def caller(): 64 a = torch.rand([2, 2]) 65 b = torch.ones([2, 2]) 66 out1 = foo(a, b) 67 c = torch.rand([1]) 68 d = torch.ones([2]) 69 out2 = foo(d, c) 70 return out1, out2 71 72 isFrozen = False 73 descend_function_calls = True 74 alias_db = caller.graph.alias_db(isFrozen, descend_function_calls) 75 func_calls = caller.graph.findAllNodes("prim::CallFunction") 76 self.assertEqual(len(func_calls), 2) 77 for node in func_calls: 78 inps = list(node.inputs()) 79 self.assertTrue(alias_db.has_writers(inps[1])) 80 self.assertFalse(alias_db.has_writers(inps[2])) 81 82 class Mod(torch.nn.Module): 83 def forward(self): 84 a = torch.rand([2, 2]) 85 b = torch.ones([2, 2]) 86 out1 = self.foo2(a, b) 87 c = torch.rand([1]) 88 d = torch.ones([2]) 89 out2 = self.foo2(d, c) 90 return out1, out2 91 92 def foo2(self, x, y): 93 x.add_(1) 94 return x + y 95 96 mod = torch.jit.script(Mod()) 97 alias_db = mod.graph.alias_db(isFrozen, descend_function_calls) 98 func_calls = mod.graph.findAllNodes("prim::CallMethod") 99 self.assertEqual(len(func_calls), 2) 100 for node in func_calls: 101 inps = list(node.inputs()) 102 self.assertTrue(alias_db.has_writers(inps[1])) 103 self.assertFalse(alias_db.has_writers(inps[2])) 104 105 def test_multiple_compilation_units(self): 106 # This is a repro of an internal issue we saw. 107 # Here, we have a large number (40) of modules each with the same name (MyModuleCUTest). 108 # AliasDB uses some hash tables that hash on types; each of these 40 modules are not 109 # identical because they have different compilation units, but they have the same name. 110 # Therefore, if we hash only on the module name (which we previously did), we will have 111 # hash collisions for all of these module types. 112 # 113 # flat_hash_map has very bad performance (exponential) for this hash collision behavior. 114 # This OOMs prior to the fix. 115 N = 40 116 117 class MultiTmpFile: 118 def __init__(self, N): 119 self.N = N 120 self.ctxs = [ 121 TemporaryFileName(mode="w", suffix=".py") for _ in range(N) 122 ] 123 124 def __enter__(self): 125 return [x.__enter__() for x in self.ctxs] 126 127 def __exit__(self, exc_type, exc_value, traceback): 128 return [x.__exit__(exc_type, exc_value, traceback) for x in self.ctxs] 129 130 class ModuleWrapper(torch.nn.Module): 131 def __init__(self, module_list): 132 super().__init__() 133 self.module_list = module_list 134 135 def forward(self, x): 136 for mod in self.module_list: 137 x = mod(x) 138 return x 139 140 with MultiTmpFile(N) as fnames: 141 module_list = torch.nn.ModuleList() 142 global MyModuleCUTest 143 144 class MyModuleCUTest(torch.nn.Module): 145 def forward(self, x): 146 return x + 2 147 148 for _, fname in enumerate(fnames): 149 mod = torch.jit.script(MyModuleCUTest()) 150 torch.jit.save(mod, fname) 151 loaded_mod = torch.jit.load(fname) 152 module_list.append(loaded_mod) 153 154 mod = ModuleWrapper(module_list) 155 mod = torch.jit.script(mod) 156 mod(torch.zeros((2, 2))) 157