# Owner(s): ["oncall: jit"]

import os
import sys
from typing import List

import torch
from torch.testing import FileCheck


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import freeze_rng_state, JitTestCase


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestRemoveMutation(JitTestCase):
    def test_aten_inplace(self):
        def test_not_new_alias(x):
            y = x[0]
            y.add_(2)
            return y

        fn = torch.jit.script(test_not_new_alias)
        graph = fn.graph
        self.run_pass("remove_mutation", graph)
        FileCheck().check("aten::add_").run(graph)
        self.assertEqual(fn(torch.ones([2, 2])), test_not_new_alias(torch.ones([2, 2])))

        def test_no_lowering():
            x = torch.tensor([2, 2])
            x[0] = 3
            return x

        # there is no functional equivalent of x[0] = ...; it lowers to an
        # aten::copy_ into a view, which the pass leaves in place (an
        # illustrative out-of-place analogue follows below)
        fn = torch.jit.script(test_no_lowering)
        graph = fn.graph
        self.run_pass("remove_mutation", graph)
        FileCheck().check("aten::copy_").run(graph)
        self.assertEqual(fn(), test_no_lowering())
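
        # An illustrative functional analogue of test_no_lowering (an added
        # sketch; index_put is the closest out-of-place equivalent, not what
        # the pass would emit):
        def no_lowering_functional():
            x = torch.tensor([2, 2])
            return x.index_put((torch.tensor([0]),), torch.tensor(3))

        self.assertEqual(test_no_lowering(), no_lowering_functional())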

        def test_move_before_not_valid():
            y = torch.tensor([2, 2])
            z = y + 2
            y.add_(2)
            return y, z

        fn = torch.jit.script(test_move_before_not_valid)
        graph = fn.graph
        self.run_pass("remove_mutation", graph)
        FileCheck().check("aten::add_").run(graph)
        self.assertEqual(fn(), test_move_before_not_valid())

        def test_successful():
            x = torch.tensor([2, 2])
            x.add_(1)
            x.add_(3)
            y = x + 4
            return x, y

        fn = torch.jit.script(test_successful)
        graph = fn.graph
        self.run_pass("remove_mutation", graph)
        FileCheck().check_not("aten::add_").run(graph)
        self.assertEqual(test_successful(), fn())
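
        # A sketch of the functional form the pass should produce here:
        # out-of-place aten::add in place of aten::add_ (an illustrative
        # eager equivalent, not the pass's literal IR output):
        def successful_functional():
            x = torch.tensor([2, 2])
            x = x.add(1)
            x = x.add(3)
            y = x + 4
            return x, y

        self.assertEqual(test_successful(), successful_functional())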

        def test_intermediary_use():
            x = torch.tensor([2, 2])
            x.add_(1)
            y = x + 4
            x.add_(3)
            return x, y

        fn = torch.jit.script(test_intermediary_use)
        graph = fn.graph
        FileCheck().check_count("aten::add_", 2).run(graph)
        self.run_pass("remove_mutation", graph)
        # Unable to remove the second add_ because of the y = x + 4 use.
        # In the future we could duplicate the value of x as a temporary and
        # replace its intermediary use (so long as aliasing is safe); a sketch
        # of that rewrite follows below.
        FileCheck().check_count("aten::add_", 1).run(graph)
        self.assertEqual(test_intermediary_use(), fn())
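
        # A sketch of that hypothetical rewrite in eager form: functionalize
        # both mutations and let the intermediary use read the first result
        # (illustrative only; the pass cannot do this today):
        def intermediary_use_functional():
            x = torch.tensor([2, 2])
            x = x.add(1)
            y = x + 4
            x = x.add(3)
            return x, y

        self.assertEqual(test_intermediary_use(), intermediary_use_functional())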

    def test_if_output(self):
        def foo(x, cond: bool):
            if cond:
                y = x + 5
            else:
                y = x + 2
            y.add_(4)
            return y

        out_eager = foo(torch.tensor(5), True)
        foo_script = torch.jit.script(foo)
        FileCheck().check("aten::add_").run(foo_script.graph)
        self.run_pass("remove_mutation", foo_script.graph)
        FileCheck().check_not("aten::add_").run(foo_script.graph)

        self.assertEqual(out_eager, foo_script(torch.tensor(5), True))
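
        # The rewrite succeeds because y is freshly produced by the if; an
        # illustrative eager equivalent of the functionalized graph:
        def foo_functional(x, cond: bool):
            y = x + 5 if cond else x + 2
            return y.add(4)

        self.assertEqual(out_eager, foo_functional(torch.tensor(5), True))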

    def test_if_output_fail(self):
        @torch.jit.script
        def foo(cond: bool):
            li = []
            if cond:
                x = torch.tensor(1)
                li.append(x)
            else:
                x = torch.tensor(2)
            y = x.add_(2)
            return y, li

        self.run_pass("inline", foo.graph)
        self.run_pass("remove_mutation", foo.graph)
        FileCheck().check("aten::add_").run(foo.graph)

        @torch.jit.script
        def foo(cond: bool, y):
            if cond:
                x = y
            else:
                x = torch.tensor(2)
            z = x.add_(2)
            return z

        self.run_pass("inline", foo.graph)
        self.run_pass("remove_mutation", foo.graph)
        FileCheck().check("aten::add_").run(foo.graph)
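
        # The second case must keep aten::add_ because x may alias the graph
        # input y, and the in-place update is observable by the caller; a
        # quick eager check of that side effect (an added sanity check):
        inp = torch.tensor(1)
        foo(True, inp)
        self.assertEqual(inp, torch.tensor(3))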

    def test_special_mapped_op(self):
        def test_successful():
            x = torch.tensor([2, 2])
            y = torch.tensor([2, 4])
            x.zero_()
            y.fill_(3)
            return x, y

        fn = torch.jit.script(test_successful)
        graph = fn.graph
        self.run_pass("remove_mutation", graph)
        FileCheck().check_not("aten::zero_").check_not("aten::fill_").run(graph)
        self.assertEqual(test_successful(), fn())
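
        # The mappings exercised above are believed to be aten::zero_ ->
        # zeros_like and scalar aten::fill_ -> full_like; an illustrative
        # eager equivalent of the rewritten function:
        def successful_functional():
            x = torch.zeros_like(torch.tensor([2, 2]))
            y = torch.full_like(torch.tensor([2, 4]), 3)
            return x, y

        self.assertEqual(test_successful(), successful_functional())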

        # full_like is not implemented for a tensor fill value, so a
        # Tensor-valued fill_ needs a different functional rewrite; the
        # rewritten graph below is only checked, never executed

        def test_successful():
            x = torch.tensor([2, 2])
            y = torch.tensor([2, 4])
            x.fill_(y)
            return x + x

        fn = torch.jit.script(test_successful)
        graph = fn.graph
        self.run_pass("remove_mutation", graph)
        FileCheck().check_not("aten::fill_").run(graph)

        def normal():
            # NOTE: For some unknown reason, the
            # `torch._C._jit_pass_remove_mutation` call within `self.run_pass`
            # replaces `torch.rand(..., dtype=None).normal_()` with an
            # `aten::normal` call with dtype double, even if the default dtype
            # is float. So we must explicitly set the dtype here.
            return torch.rand(2, 1, 3, 4, dtype=torch.float).normal_()

        fn = torch.jit.script(normal)
        graph = fn.graph
        self.run_pass("remove_mutation", graph)
        FileCheck().check_not("normal_").run(graph)
        with freeze_rng_state():
            out_eager = normal()
        with freeze_rng_state():
            out_script = fn()
        self.assertEqual(out_eager, out_script)

    def test_lists_append(self):
        def successful_remove():
            return [i for i in range(5)]  # noqa: C416

        fn = torch.jit.script(successful_remove)
        graph = fn.graph
        self.run_pass("loop_unrolling", graph)
        self.run_pass("remove_mutation", graph)
        self.run_pass("constant_propagation", graph)
        FileCheck().check("graph").check_next("Constant").check_next("return").run(
            graph
        )
        self.assertEqual(successful_remove(), fn())
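
        # The folded constant should be exactly range(5) materialized as a
        # list; check the concrete value as well (an added sanity check):
        self.assertEqual(fn(), list(range(5)))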

        def intermediary_use():
            a = [1, 2]
            b = len(a)
            a.append(3)
            return a

        fn = torch.jit.script(intermediary_use)
        graph = fn.graph
        FileCheck().check("append").run(graph)
        self.run_pass("remove_mutation", graph)
        # it is possible to remove the append here, but the pass doesn't
        # currently have the logic for it, so the append must survive
        FileCheck().check("append").run(graph)
        self.assertEqual(intermediary_use(), fn())

    def test_lists_insert(self):
        def successful_remove():
            a: List[int] = []
            a.insert(0, 1)
            a.insert(0, 2)
            a.insert(-10, 3)
            a.insert(-9, 4)
            a.insert(10, 5)
            return a

        fn = torch.jit.script(successful_remove)
        graph = fn.graph
        torch._C._jit_pass_remove_mutation(graph)
        torch._C._jit_pass_constant_propagation(graph)
        FileCheck().check("graph").check_next("Constant").check_next("return").run(
            graph
        )
        self.assertEqual(successful_remove(), fn())
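
        # insert clamps out-of-range indices just like Python's list.insert,
        # so the folded constant should be [4, 3, 2, 1, 5]; check the concrete
        # value as well (an added sanity check):
        self.assertEqual(fn(), [4, 3, 2, 1, 5])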

    def test_list_indexing_removal(self):
        @torch.jit.script
        def out_of_bounds():
            x = [1, 2]
            x[4] = 3
            return x

        torch._C._jit_pass_remove_mutation(out_of_bounds.graph)
        FileCheck().check("set_item").run(out_of_bounds.graph)

        @torch.jit.script
        def unknown(y: int):
            x = [1, 2]
            x[y] = 3
            return x

        torch._C._jit_pass_remove_mutation(unknown.graph)
        FileCheck().check("set_item").run(unknown.graph)

        def successful():
            x = [1, 2, 3]
            x[0] = 4
            x[-1] = 0
            return x

        scripted_fn = torch.jit.script(successful)
        torch._C._jit_pass_remove_mutation(scripted_fn.graph)
        FileCheck().check_not("set_item").run(scripted_fn.graph)
        self.checkScript(successful, ())
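
        # Concrete value check (added): x[0] = 4 and x[-1] = 0 on [1, 2, 3]
        # give [4, 2, 0].
        self.assertEqual(scripted_fn(), [4, 2, 0])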

        def successful():
            x = [1]
            x[-1] = 3
            return x

        scripted_fn = torch.jit.script(successful)
        torch._C._jit_pass_remove_mutation(scripted_fn.graph)
        FileCheck().check_not("set_item").run(scripted_fn.graph)
        self.checkScript(successful, ())

    def test_common_pytorch_list_ops(self):
        for op in ["cat", "stack", "vstack", "hstack", "dstack"]:

            class OpMod(torch.nn.Module):
                def __init__(self, op):
                    super().__init__()
                    self.op = op

                def forward(self):
                    x = torch.tensor([1, 2, 3, 4])
                    x.add_(3)
                    y = [x, x]
                    return self.op(y) + 3

            torch_op = getattr(torch, op)
            mod = OpMod(torch_op)
            mod_script = torch.jit.script(mod)
            self.run_pass("remove_mutation", mod_script.forward.graph)
            FileCheck().check_not("aten::add_").run(mod_script.forward.graph)
            self.assertEqual(mod(), mod_script())

            # test that the output doesn't alias the input
            for inputs in [torch.rand(2, 2)], [torch.rand(2, 2) for _ in range(2)]:
                result = torch_op(inputs)
                sums = [ten.sum() for ten in result]

                for inp in inputs:
                    inp.fill_(10)

                self.assertEqual(sums, [ten.sum() for ten in result])
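
                # Storage-level version of the same property (an added sanity
                # check): the op's output is freshly allocated, so it cannot
                # share memory with any input.
                for inp in inputs:
                    self.assertNotEqual(result.data_ptr(), inp.data_ptr())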

        @torch.jit.script
        def test_multiple_uses():
            x = torch.tensor([1, 2, 3, 4])
            x.add_(3)
            y = [x, x]
            return torch.cat(y), y

        self.run_pass("remove_mutation", test_multiple_uses.graph)
        FileCheck().check("aten::add_").run(test_multiple_uses.graph)
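
        # aten::add_ must remain: the list y is both consumed by cat and
        # returned, so the pass conservatively keeps the in-place op. As an
        # added sanity check, compare against an eager re-implementation:
        def multiple_uses_eager():
            x = torch.tensor([1, 2, 3, 4])
            x.add_(3)
            y = [x, x]
            return torch.cat(y), y

        self.assertEqual(multiple_uses_eager(), test_multiple_uses())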
321