# Owner(s): ["oncall: jit"]

import os
import sys
from typing import List

import torch
from torch.testing import FileCheck


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import freeze_rng_state, JitTestCase


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestRemoveMutation(JitTestCase):
    """Tests for the TorchScript ``remove_mutation`` pass.

    Each test scripts a small function, runs the pass over its graph, uses
    ``FileCheck`` to assert which mutating ops were (or were not) rewritten,
    and, where the function is runnable, compares the scripted result with
    eager execution.
    """

    def test_aten_inplace(self):
        def test_not_new_alias(x):
            y = x[0]
            y.add_(2)
            return y

        fn = torch.jit.script(test_not_new_alias)
        graph = fn.graph
        self.run_pass("remove_mutation", graph)
        # y is taken from the function input x, so the in-place add is kept.
        FileCheck().check("aten::add_").run(graph)
        self.assertEqual(
            fn(torch.ones([2, 2])), test_not_new_alias(torch.ones([2, 2]))
        )

        def test_no_lowering():
            x = torch.tensor([2, 2])
            x[0] = 3
            return x

        # there is no functional equivalent of x[0] = ...
        fn = torch.jit.script(test_no_lowering)
        graph = fn.graph
        self.run_pass("remove_mutation", graph)
        FileCheck().check("aten::copy_").run(graph)
        self.assertEqual(fn(), test_no_lowering())

        def test_move_before_not_valid():
            y = torch.tensor([2, 2])
            z = y + 2
            y.add_(2)
            return y, z

        # z reads y between its creation and the add_, so the add_ is kept.
        fn = torch.jit.script(test_move_before_not_valid)
        graph = fn.graph
        self.run_pass("remove_mutation", graph)
        FileCheck().check("aten::add_").run(graph)
        self.assertEqual(fn(), test_move_before_not_valid())

        def test_successful():
            x = torch.tensor([2, 2])
            x.add_(1)
            x.add_(3)
            y = x + 4
            return x, y

        fn = torch.jit.script(test_successful)
        graph = fn.graph
        self.run_pass("remove_mutation", graph)
        FileCheck().check_not("aten::add_").run(graph)
        self.assertEqual(test_successful(), fn())

        def test_intermediary_use():
            x = torch.tensor([2, 2])
            x.add_(1)
            y = x + 4
            x.add_(3)
            return x, y

        fn = torch.jit.script(test_intermediary_use)
        graph = fn.graph
        FileCheck().check_count("aten::add_", 2).run(graph)
        self.run_pass("remove_mutation", graph)
        # Unable to remove the second add_ because of the y = x + 4 use
        # In the future we could duplicate the value of x as a temporary and replace
        # its intermediary use (so long as aliasing is safe)
        FileCheck().check_count("aten::add_", 1).run(graph)
        self.assertEqual(test_intermediary_use(), fn())

    def test_if_output(self):
        def foo(x, cond: bool):
            if cond:
                y = x + 5
            else:
                y = x + 2
            y.add_(4)
            return y

        out_eager = foo(torch.tensor(5), True)
        foo_script = torch.jit.script(foo)
        FileCheck().check("aten::add_").run(foo_script.graph)
        self.run_pass("remove_mutation", foo_script.graph)
        FileCheck().check_not("aten::add_").run(foo_script.graph)

        self.assertEqual(out_eager, foo_script(torch.tensor(5), True))

    def test_if_output_fail(self):
        # x may have been stored in the returned list li, so add_ must stay.
        @torch.jit.script
        def foo(cond: bool):
            li = []
            if cond:
                x = torch.tensor(1)
                li.append(x)
            else:
                x = torch.tensor(2)
            y = x.add_(2)
            return y, li

        self.run_pass("inline", foo.graph)
        self.run_pass("remove_mutation", foo.graph)
        FileCheck().check("aten::add_").run(foo.graph)

        # x may be the function input y, so add_ must stay.
        @torch.jit.script
        def foo(cond: bool, y):
            if cond:
                x = y
            else:
                x = torch.tensor(2)
            z = x.add_(2)
            return z

        self.run_pass("inline", foo.graph)
        self.run_pass("remove_mutation", foo.graph)
        FileCheck().check("aten::add_").run(foo.graph)

    def test_special_mapped_op(self):
        def test_successful():
            x = torch.tensor([2, 2])
            y = torch.tensor([2, 4])
            x.zero_()
            y.fill_(3)
            return x, y

        fn = torch.jit.script(test_successful)
        graph = fn.graph
        self.run_pass("remove_mutation", graph)
        FileCheck().check_not("aten::zero_").check_not("aten::fill_").run(graph)
        self.assertEqual(test_successful(), fn())

        # full_like is not implemented for a tensor fill value
        # (only the graph is checked for this case; the function is not run)

        def test_successful():
            x = torch.tensor([2, 2])
            y = torch.tensor([2, 4])
            x.fill_(y)
            return x + x

        fn = torch.jit.script(test_successful)
        graph = fn.graph
        self.run_pass("remove_mutation", graph)
        FileCheck().check_not("aten::fill_").run(graph)

        def normal():
            # NOTE: For some unknown reason, the
            # `torch._C._jit_pass_remove_mutation` call within `self.run_pass`
            # replaces `torch.rand(..., dtype=None).normal_()` with an
            # `aten::normal` call with dtype double, even if the default dtype
            # is float. So we must explicitly set the dtype here
            return torch.rand(2, 1, 3, 4, dtype=torch.float).normal_()

        fn = torch.jit.script(normal)
        graph = fn.graph
        self.run_pass("remove_mutation", graph)
        FileCheck().check_not("normal_").run(graph)
        # Random ops: compare eager and scripted output under the same RNG state.
        with freeze_rng_state():
            out_eager = normal()
        with freeze_rng_state():
            out_script = fn()
        self.assertEqual(out_eager, out_script)

    def test_lists_append(self):
        def successful_remove():
            return [i for i in range(5)]  # noqa: C416

        fn = torch.jit.script(successful_remove)
        graph = fn.graph
        self.run_pass("loop_unrolling", graph)
        self.run_pass("remove_mutation", graph)
        self.run_pass("constant_propagation", graph)
        # After unrolling + mutation removal the whole list folds to a constant.
        FileCheck().check("graph").check_next("Constant").check_next("return").run(
            graph
        )
        # Compare the eager result against the *scripted* function (previously
        # the eager function was compared with itself, testing nothing).
        self.assertEqual(successful_remove(), fn())

        def intermediary_use():
            a = [1, 2]
            b = len(a)
            a.append(3)
            return a

        fn = torch.jit.script(intermediary_use)
        graph = fn.graph
        FileCheck().check("append").run(graph)
        self.run_pass("remove_mutation", graph)
        # the len(a) use before the append does not prevent removing the append
        FileCheck().check_not("append").run(graph)
        self.assertEqual(intermediary_use(), fn())

    def test_lists_insert(self):
        def successful_remove():
            a: List[int] = []
            a.insert(0, 1)
            a.insert(0, 2)
            a.insert(-10, 3)
            a.insert(-9, 4)
            a.insert(10, 5)
            return a

        fn = torch.jit.script(successful_remove)
        graph = fn.graph
        torch._C._jit_pass_remove_mutation(graph)
        torch._C._jit_pass_constant_propagation(graph)
        FileCheck().check("graph").check_next("Constant").check_next("return").run(
            graph
        )
        self.assertEqual(successful_remove(), fn())

    def test_list_indexing_removal(self):
        @torch.jit.script
        def out_of_bounds():
            x = [1, 2]
            x[4] = 3
            return x

        # Index 4 is out of range for a 2-element list: the store must stay.
        torch._C._jit_pass_remove_mutation(out_of_bounds.graph)
        FileCheck().check("set_item").run(out_of_bounds.graph)

        @torch.jit.script
        def unknown(y: int):
            x = [1, 2]
            x[y] = 3
            return x

        # The index is a runtime input: the store must stay. (This previously
        # re-checked out_of_bounds.graph, so this case was never exercised.)
        torch._C._jit_pass_remove_mutation(unknown.graph)
        FileCheck().check("set_item").run(unknown.graph)

        def successful():
            x = [1, 2, 3]
            x[0] = 4
            x[-1] = 0
            return x

        scripted_fn = torch.jit.script(successful)
        torch._C._jit_pass_remove_mutation(scripted_fn.graph)
        FileCheck().check_not("set_item").run(scripted_fn.graph)
        self.checkScript(successful, ())

        def successful():
            x = [1]
            x[-1] = 3
            return x

        scripted_fn = torch.jit.script(successful)
        torch._C._jit_pass_remove_mutation(scripted_fn.graph)
        FileCheck().check_not("set_item").run(scripted_fn.graph)
        self.checkScript(successful, ())

    def test_common_pytorch_list_ops(self):
        for op_name in ["cat", "stack", "vstack", "hstack", "dstack"]:

            class OpMod(torch.nn.Module):
                def __init__(self, op):
                    super().__init__()
                    # Use the constructor argument rather than capturing the
                    # loop's `torch_op` through the closure.
                    self.op = op

                def forward(self):
                    x = torch.tensor([1, 2, 3, 4])
                    x.add_(3)
                    y = [x, x]
                    return self.op(y) + 3

            torch_op = getattr(torch, op_name)
            mod = OpMod(torch_op)
            mod_script = torch.jit.script(mod)
            self.run_pass("remove_mutation", mod_script.forward.graph)
            FileCheck().check_not("aten::add_").run(mod_script.forward.graph)
            self.assertEqual(mod(), mod_script())

            # test that the output doesn't alias the input
            for inputs in [torch.rand(2, 2)], [torch.rand(2, 2) for _ in range(2)]:
                result = torch_op(inputs)
                sums = [ten.sum() for ten in result]

                for inp in inputs:
                    inp.fill_(10)

                self.assertEqual(sums, [ten.sum() for ten in result])

        @torch.jit.script
        def test_multiple_uses():
            x = torch.tensor([1, 2, 3, 4])
            x.add_(3)
            y = [x, x]
            return torch.cat(y), y

        # NOTE(review): the pass below runs on mod_script.forward.graph, not on
        # test_multiple_uses.graph, so the check that follows only confirms the
        # freshly scripted graph contains add_. Retargeting the pass may change
        # the expected outcome -- TODO confirm the intended behavior.
        self.run_pass("remove_mutation", mod_script.forward.graph)
        FileCheck().check("aten::add_").run(test_multiple_uses.graph)