1# Owner(s): ["oncall: jit"] 2 3import os 4import sys 5from typing import List 6 7import torch 8from torch.testing import FileCheck 9 10 11# Make the helper files in test/ importable 12pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 13sys.path.append(pytorch_test_dir) 14from torch.testing._internal.jit_utils import freeze_rng_state, JitTestCase 15 16 17if __name__ == "__main__": 18 raise RuntimeError( 19 "This test file is not meant to be run directly, use:\n\n" 20 "\tpython test/test_jit.py TESTNAME\n\n" 21 "instead." 22 ) 23 24 25class TestRemoveMutation(JitTestCase): 26 def test_aten_inplace(self): 27 def test_not_new_alias(x): 28 y = x[0] 29 y.add_(2) 30 return y 31 32 fn = torch.jit.script(test_not_new_alias) 33 graph = fn.graph 34 self.run_pass("remove_mutation", graph) 35 FileCheck().check("aten::add_").run(graph) 36 self.assertEqual(fn(torch.ones([2, 2])), test_not_new_alias(torch.ones([2, 2]))) 37 38 def test_no_lowering(): 39 x = torch.tensor([2, 2]) 40 x[0] = 3 41 return x 42 43 # there is no functional equivalent of x[0] = ... 44 fn = torch.jit.script(test_no_lowering) 45 graph = fn.graph 46 self.run_pass("remove_mutation", graph) 47 FileCheck().check("aten::copy_").run(graph) 48 self.assertEqual(fn(), test_no_lowering()) 49 50 def test_move_before_not_valid(): 51 y = torch.tensor([2, 2]) 52 z = y + 2 53 y.add_(2) 54 return y, z 55 56 fn = torch.jit.script(test_move_before_not_valid) 57 graph = fn.graph 58 self.run_pass("remove_mutation", graph) 59 FileCheck().check("aten::add_").run(graph) 60 self.assertEqual(fn(), test_move_before_not_valid()) 61 62 def test_successful(): 63 x = torch.tensor([2, 2]) 64 x.add_(1) 65 x.add_(3) 66 y = x + 4 67 return x, y 68 69 fn = torch.jit.script(test_successful) 70 graph = fn.graph 71 self.run_pass("remove_mutation", graph) 72 FileCheck().check_not("aten::add_").run(graph) 73 self.assertEqual(test_successful(), fn()) 74 75 def test_intermediary_use(): 76 x = torch.tensor([2, 2]) 77 x.add_(1) 78 y = x + 4 79 x.add_(3) 80 return x, y 81 82 fn = torch.jit.script(test_intermediary_use) 83 graph = fn.graph 84 FileCheck().check_count("aten::add_", 2).run(graph) 85 self.run_pass("remove_mutation", graph) 86 # Unable to remove the second add_ because of the y = x + 4 use 87 # In the future we could duplicating the value of x as a temporary and replacing 88 # its intermediary use (so long as aliasing is safe) 89 FileCheck().check_count("aten::add_", 1).run(graph) 90 self.assertEqual(test_intermediary_use(), fn()) 91 92 def test_if_output(self): 93 def foo(x, cond: bool): 94 if cond: 95 y = x + 5 96 else: 97 y = x + 2 98 y.add_(4) 99 return y 100 101 out_eager = foo(torch.tensor(5), True) 102 foo_script = torch.jit.script(foo) 103 FileCheck().check("aten::add_").run(foo_script.graph) 104 self.run_pass("remove_mutation", foo_script.graph) 105 FileCheck().check_not("aten::add_").run(foo_script.graph) 106 107 self.assertEqual(out_eager, foo_script(torch.tensor(5), True)) 108 109 def test_if_output_fail(self): 110 @torch.jit.script 111 def foo(cond: bool): 112 li = [] 113 if cond: 114 x = torch.tensor(1) 115 li.append(x) 116 else: 117 x = torch.tensor(2) 118 y = x.add_(2) 119 return y, li 120 121 self.run_pass("inline", foo.graph) 122 self.run_pass("remove_mutation", foo.graph) 123 FileCheck().check("aten::add_").run(foo.graph) 124 125 @torch.jit.script 126 def foo(cond: bool, y): 127 if cond: 128 x = y 129 else: 130 x = torch.tensor(2) 131 z = x.add_(2) 132 return z 133 134 self.run_pass("inline", foo.graph) 135 self.run_pass("remove_mutation", foo.graph) 136 FileCheck().check("aten::add_").run(foo.graph) 137 138 def test_special_mapped_op(self): 139 def test_successful(): 140 x = torch.tensor([2, 2]) 141 y = torch.tensor([2, 4]) 142 x.zero_() 143 y.fill_(3) 144 return x, y 145 146 fn = torch.jit.script(test_successful) 147 graph = fn.graph 148 self.run_pass("remove_mutation", graph) 149 FileCheck().check_not("aten::zero_").check_not("aten::fill_").run(graph) 150 self.assertEqual(test_successful(), fn()) 151 152 # full_like is not implemented for a tensor fill value 153 154 def test_successful(): 155 x = torch.tensor([2, 2]) 156 y = torch.tensor([2, 4]) 157 x.fill_(y) 158 return x + x 159 160 fn = torch.jit.script(test_successful) 161 graph = fn.graph 162 self.run_pass("remove_mutation", graph) 163 FileCheck().check_not("aten::fill_").run(graph) 164 165 def normal(): 166 # NOTE: For some unknown reason, the 167 # `torch._C._jit_pass_remove_mutation` call within `self.run_pass` 168 # replaces `torch.randn(..., dtype=None).normal_()` with an 169 # `aten::normal` call with dtype double, even if the default dtype 170 # is float. So we must explicitly set the dtype here 171 return torch.rand(2, 1, 3, 4, dtype=torch.float).normal_() 172 173 fn = torch.jit.script(normal) 174 graph = fn.graph 175 self.run_pass("remove_mutation", graph) 176 FileCheck().check_not("normal_").run(graph) 177 with freeze_rng_state(): 178 out_eager = normal() 179 with freeze_rng_state(): 180 out_script = fn() 181 self.assertEqual(out_eager, out_script) 182 183 def test_lists_append(self): 184 def successful_remove(): 185 return [i for i in range(5)] # noqa: C416 186 187 fn = torch.jit.script(successful_remove) 188 graph = fn.graph 189 self.run_pass("loop_unrolling", graph) 190 self.run_pass("remove_mutation", graph) 191 self.run_pass("constant_propagation", graph) 192 FileCheck().check("graph").check_next("Constant").check_next("return").run( 193 graph 194 ) 195 self.assertEqual(successful_remove(), successful_remove()) 196 197 def intermediary_use(): 198 a = [1, 2] 199 b = len(a) 200 a.append(3) 201 return a 202 203 fn = torch.jit.script(intermediary_use) 204 graph = fn.graph 205 FileCheck().check("append").run(graph) 206 self.run_pass("remove_mutation", graph) 207 # it is possible to remove the append here but don't currently have the logic for it 208 FileCheck().check_not("append").run(graph) 209 self.assertEqual(intermediary_use(), fn()) 210 211 def test_lists_insert(self): 212 def successful_remove(): 213 a: List[int] = [] 214 a.insert(0, 1) 215 a.insert(0, 2) 216 a.insert(-10, 3) 217 a.insert(-9, 4) 218 a.insert(10, 5) 219 return a 220 221 fn = torch.jit.script(successful_remove) 222 graph = fn.graph 223 torch._C._jit_pass_remove_mutation(graph) 224 torch._C._jit_pass_constant_propagation(graph) 225 FileCheck().check("graph").check_next("Constant").check_next("return").run( 226 graph 227 ) 228 self.assertEqual(successful_remove(), fn()) 229 230 def test_list_indexing_removal(self): 231 @torch.jit.script 232 def out_of_bounds(): 233 x = [1, 2] 234 x[4] = 3 235 return x 236 237 torch._C._jit_pass_remove_mutation(out_of_bounds.graph) 238 FileCheck().check("set_item").run(out_of_bounds.graph) 239 240 @torch.jit.script 241 def unknown(y: int): 242 x = [1, 2] 243 x[y] = 3 244 return x 245 246 torch._C._jit_pass_remove_mutation(out_of_bounds.graph) 247 FileCheck().check("set_item").run(out_of_bounds.graph) 248 249 def successful(): 250 x = [1, 2, 3] 251 x[0] = 4 252 x[-1] = 0 253 return x 254 255 scripted_fn = torch.jit.script(successful) 256 torch._C._jit_pass_remove_mutation(scripted_fn.graph) 257 FileCheck().check_not("set_item").run(scripted_fn.graph) 258 self.checkScript(successful, ()) 259 260 def successful(): 261 x = [1, 2, 3] 262 x[0] = 4 263 x[-1] = 0 264 return x 265 266 scripted_fn = torch.jit.script(successful) 267 torch._C._jit_pass_remove_mutation(scripted_fn.graph) 268 FileCheck().check_not("set_item").run(scripted_fn.graph) 269 self.checkScript(successful, ()) 270 271 def successful(): 272 x = [1] 273 x[-1] = 3 274 return x 275 276 scripted_fn = torch.jit.script(successful) 277 torch._C._jit_pass_remove_mutation(scripted_fn.graph) 278 FileCheck().check_not("set_item").run(scripted_fn.graph) 279 self.checkScript(successful, ()) 280 281 def test_common_pytorch_list_ops(self): 282 for op in ["cat", "stack", "vstack", "hstack", "dstack"]: 283 284 class OpMod(torch.nn.Module): 285 def __init__(self, op): 286 super().__init__() 287 self.op = torch_op 288 289 def forward(self): 290 x = torch.tensor([1, 2, 3, 4]) 291 x.add_(3) 292 y = [x, x] 293 return self.op(y) + 3 294 295 torch_op = getattr(torch, op) 296 mod = OpMod(torch_op) 297 mod_script = torch.jit.script(mod) 298 self.run_pass("remove_mutation", mod_script.forward.graph) 299 FileCheck().check_not("aten::add_").run(mod_script.forward.graph) 300 self.assertEqual(mod(), mod_script()) 301 302 # test that the output doesnt alias the input 303 for inputs in [torch.rand(2, 2)], [torch.rand(2, 2) for _ in range(2)]: 304 result = torch_op(inputs) 305 sums = [ten.sum() for ten in result] 306 307 for inp in inputs: 308 inp.fill_(10) 309 310 self.assertEqual(sums, [ten.sum() for ten in result]) 311 312 @torch.jit.script 313 def test_multiple_uses(): 314 x = torch.tensor([1, 2, 3, 4]) 315 x.add_(3) 316 y = [x, x] 317 return torch.cat(y), y 318 319 self.run_pass("remove_mutation", mod_script.forward.graph) 320 FileCheck().check("aten::add_").run(test_multiple_uses.graph) 321