# Owner(s): ["module: inductor"]

from typing import List

import torch
import torch._inductor.config as inductor_config
from functorch import make_fx
from torch import Tensor
from torch._dynamo.utils import counters
from torch._higher_order_ops.auto_functionalize import (
    auto_functionalized,
    auto_functionalized_v2,
)
from torch._inductor.fx_passes.reinplace import reinplace_inplaceable_ops_core
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_LINUX,
    parametrize,
    subtest,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
from torch.testing._internal.logging_utils import logs_to_string


aten = torch.ops.aten


const = torch.tensor(0.0)
device = GPU_TYPE


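# Inductor bumps this counter whenever the reinplacing pass leaves a mutation
# functionalized (a clone survives) even though it might have been reinplaceable.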
def num_reinplacing_failures():
    return counters["inductor"]["possibly_missed_reinplacing_opportunities"]


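# Custom ops that write into caller-provided output buffers. After
# auto-functionalization, the reinplacing pass should reuse those buffers
# instead of cloning them whenever it is safe to do so.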
@torch.library.custom_op("_reinplacing::sin", mutates_args={"result"})
def sin(x: torch.Tensor, result: torch.Tensor) -> None:
    result.copy_(x.sin())


@torch.library.custom_op("_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"})
def sin_cos(x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor) -> None:
    out_sin.copy_(x.sin())
    out_cos.copy_(x.cos())


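# Triton flavor of the in-place sin op. Without a GPU it degrades to a no-op
# stub, which is fine: the suite only runs when HAS_GPU is true (see the
# __main__ guard at the bottom of the file).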
if HAS_GPU:
    import triton
    import triton.language as tl

    @triton.jit
    def sin_kernel(
        in_ptr0,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        output = tl.sin(x)
        tl.store(out_ptr + offsets, output, mask=mask)

    def sin_triton(x, out):
        n_elements = x.numel()
        sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)

else:

    def sin_triton(x, out):
        return


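# Custom op that mutates its argument in place; the view tests below wrap it in
# auto_functionalized_v2 and check when the mutation can be written back to the base.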
@torch.library.custom_op("test_view::boo", mutates_args={"x"})
def boo(x: torch.Tensor) -> None:
    x.sin_()


class TestReinplacingPassCorrectness(InductorTestCase):
    def setUp(self):
        counters.clear()
        return super().setUp()

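    # Run `f` both eagerly and compiled, then check that the outputs match and
    # that both copies of the inputs end up in the same state (i.e. compilation
    # does not change what gets mutated).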
    def _test(self, f):
        nf = torch.compile(f)
        inp = (
            torch.randn(4, device=device),
            torch.ones(2, device=device, dtype=torch.int),
        )
        inp2 = (inp[0].clone(), inp[1].clone())
        self.assertEqual(f(*inp), nf(*inp2))
        self.assertEqual(inp, inp2)

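    # The index_put tests below check that reinplacing never mutates a tensor
    # (or a view of one) that is still live afterwards, and that it does mutate
    # when that is safe.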
    def test_dont_modify_live(self):
        def f(x, y):
            x = x.cos()
            x2 = x.index_put((y,), const)
            return x2, x

        self._test(f)

    def test_dont_modify_view_of_live(self):
        def f(x, y):
            x = x.cos()
            x2 = aten.alias(x)
            x2 = x2.index_put((y,), const)
            y = x2 + x.cos()
            return y

        self._test(f)

    def test_dont_modify_input(self):
        def f(x, y):
            return x.index_put((y,), const)

        self._test(f)

    def test_should_modify_inner(self):
        def f(x, y):
            x = x.cos()
            x = x.index_put((y,), const)
            return x

        self._test(f)

    def test_should_modify_input(self):
        def f(x, y):
            x = x.index_put_((y,), const)
            return x

        self._test(f)

    def test_counters_functionalize_old(self):
        counters.clear()

        def f(x):
            out = torch.empty_like(x)
            _, new_out = auto_functionalized(sin._opoverload, x=x, result=out)
            y = out * new_out
            return new_out, y

        x = torch.randn(3, device=device)
        gm = make_fx(f, tracing_mode="fake")(x)
        reinplace_inplaceable_ops_core(gm.graph)

        # We shouldn't have been able to reinplace `out` because it was used after
        # auto_functionalized. Note that this usually doesn't happen in practice;
        # we're artificially creating this example to test the counter.
        # IF THIS NUMBER GOES TO ZERO, PLEASE FIND ANOTHER EXAMPLE
        self.assertEqual(num_reinplacing_failures(), 1)

    def test_counters_functionalize_v2(self):
        counters.clear()

        def f(x):
            out = torch.empty_like(x)
            _, new_out = auto_functionalized_v2(
                sin._opoverload,
                x=x,
                _result_base_index=0,
                _result_size=(3,),
                _result_stride=(1,),
                _result_storage_offset=0,
                _all_bases=[out],
            )
            y = out * new_out
            return new_out, y

        x = torch.randn(3, device=device)
        gm = make_fx(f, tracing_mode="fake")(x)
        reinplace_inplaceable_ops_core(gm.graph)

        # We shouldn't have been able to reinplace `out` because it was used after
        # auto_functionalized. Note that this usually doesn't happen in practice;
        # we're artificially creating this example to test the counter.
        # IF THIS NUMBER GOES TO ZERO, PLEASE FIND ANOTHER EXAMPLE
        self.assertEqual(num_reinplacing_failures(), 1)

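    # Count how many tensors the auto_functionalized(_v2) nodes in `graph` still
    # have to clone after reinplacing (the size of each node's
    # "only_clone_these_tensors" metadata).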
    def get_not_inplaced_count(self, graph):
        counter = 0
        auto_functionalized_found = False
        for node in graph.nodes:
            if (node.target == torch.ops.higher_order.auto_functionalized) or (
                node.target == torch.ops.higher_order.auto_functionalized_v2
            ):
                auto_functionalized_found = True
                counter += len(node.meta["only_clone_these_tensors"])
        assert auto_functionalized_found
        return counter

    def test_view_inplaced_functionalize_v2(self):
        def f(arg0_1):
            select = torch.ops.aten.select.int(arg0_1, 0, 0)
            auto_functionalized = auto_functionalized_v2(
                torch.ops.test_view.boo.default,
                _x_base_index=0,
                _x_size=(3,),
                _x_stride=(1,),
                _x_storage_offset=0,
                _all_bases=[arg0_1],
            )
            getitem_1 = auto_functionalized[1]
            copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_1)
            return ()

        x1 = torch.randn(3, device=device)
        gm = make_fx(f, tracing_mode="fake")(x1)
        reinplace_inplaceable_ops_core(gm.graph)

        self.assertEqual(self.get_not_inplaced_count(gm.graph), 0)

    # Introduce a view `another_view` that is used after the copy node.
    def test_view_inplaced2_functionalize_v2(self):
        def f(arg0_1):
            select = torch.ops.aten.select.int(arg0_1, 0, 0)
            another_view = arg0_1[2]
            auto_functionalized = auto_functionalized_v2(
                torch.ops.test_view.boo.default,
                _x_base_index=0,
                _x_size=(3,),
                _x_stride=(1,),
                _x_storage_offset=0,
                _all_bases=[arg0_1],
            )
            getitem_1 = auto_functionalized[1]
            copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_1)
            return another_view

        x1 = torch.randn(3, device=device)
        gm = make_fx(f, tracing_mode="fake")(x1)
        reinplace_inplaceable_ops_core(gm.graph)

        self.assertEqual(self.get_not_inplaced_count(gm.graph), 0)

    # Introduce a view `another_view` that is used before the copy node.
    def test_views_not_inplaced_functionalize_v2(self):
        def f(arg0_1):
            select = torch.ops.aten.select.int(arg0_1, 0, 0)
            another_view = arg0_1[2]
            auto_functionalized = auto_functionalized_v2(
                torch.ops.test_view.boo.default,
                _x_base_index=0,
                _x_size=(3,),
                _x_stride=(1,),
                _x_storage_offset=0,
                _all_bases=[arg0_1],
            )
            getitem_1 = auto_functionalized[1]
            use_another_view = another_view * 10
            copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_1)
            return use_another_view

        x1 = torch.randn(3, device=device)
        gm = make_fx(f, tracing_mode="fake")(x1)
        reinplace_inplaceable_ops_core(gm.graph)

        self.assertEqual(self.get_not_inplaced_count(gm.graph), 1)

    # A view over the input without a copy node; inplacing is not allowed.
    def test_views_not_inplaced2_functionalize_v2(self):
        def f(arg0_1):
            select = torch.ops.aten.select.int(arg0_1, 0, 0)
            another_view = arg0_1[2]
            auto_functionalized = auto_functionalized_v2(
                torch.ops.test_view.boo.default,
                _x_base_index=0,
                _x_size=(3,),
                _x_stride=(1,),
                _x_storage_offset=0,
                _all_bases=[arg0_1],
            )
            getitem_1 = auto_functionalized[1]
            return

        x1 = torch.randn(3, device=device)
        gm = make_fx(f, tracing_mode="fake")(x1)
        reinplace_inplaceable_ops_core(gm.graph)

        self.assertEqual(self.get_not_inplaced_count(gm.graph), 1)

    # No copy nodes; a view over a local tensor, with a use of another view.
    def test_views_not_inplaced3_functionalize_v2(self):
        def f(arg0_1):
            a = torch.ones(10)
            another_view = a[2]
            auto_functionalized = auto_functionalized_v2(
                torch.ops.test_view.boo.default,
                _x_base_index=0,
                _x_size=(),
                _x_stride=(),
                _x_storage_offset=0,
                _all_bases=[a],
            )
            getitem_1 = auto_functionalized[1]
            return another_view

        x1 = torch.randn(3, device=device)
        gm = make_fx(f, tracing_mode="fake")(x1)
        reinplace_inplaceable_ops_core(gm.graph)

        self.assertEqual(self.get_not_inplaced_count(gm.graph), 1)

    def test_multi_output_intermediate(self):
        for requires_grad in [False, True]:
            for enable_v2 in [False, True]:
                with inductor_config.patch(
                    {"enable_auto_functionalized_v2": enable_v2}
                ):
                    counters.clear()

                    def f(x):
                        out1 = torch.empty_like(x)
                        out2 = torch.empty_like(x)
                        sin_cos(x, out1, out2)
                        return out1, out2, x**2

                    x = torch.randn(3, device=device, requires_grad=requires_grad)
                    res1, res2, _ = torch.compile(f)(x)
                    self.assertEqual(res1, x.sin())
                    self.assertEqual(res2, x.cos())
                    self.assertEqual(num_reinplacing_failures(), 0)

    def test_multiple_mutations(self):
        counters.clear()

        def f(x, out):
            sin(x, out)
            sin(out, out)
            sin(out, out)
            return out

        x = torch.randn(3, device=device)
        out = torch.randn(3, device=device)
        result = torch.compile(f)(x, out)
        self.assertEqual(result, x.sin().sin().sin())
        self.assertEqual(result, out)
        self.assertEqual(num_reinplacing_failures(), 0)

    def test_multiple_intermediate(self):
        counters.clear()

        def f(x):
            out = torch.empty_like(x)
            sin(x, out)
            sin(out, out)
            sin(out, out)
            return out

        x = torch.randn(3, device=device)
        result = torch.compile(f)(x)
        self.assertEqual(result, x.sin().sin().sin())
        self.assertEqual(num_reinplacing_failures(), 0)

    def test_lists_functionalize_v2(self):
        with inductor_config.patch({"enable_auto_functionalized_v2": True}):

            @torch.library.custom_op("mylib::mutate_op", mutates_args={"y"})
            def mutate_op(y: List[Tensor]) -> None:
                y[0].add_(2)
                y[1].add_(3)

            @torch.compile(fullgraph=True, dynamic=False, backend="inductor")
            def f(b):
                mutate_op([b[0], b[1]])

            x1 = torch.tensor([0.3, 0.4], device=device)
            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                torch.compile(f, backend="inductor", fullgraph=True)(x1)
            post_grad_graphs = "\n".join(
                log_stream.getvalue().strip().split("\n")[3:]
            ).strip()

            # We can inplace the base y; no clones are emitted.
            self.assertEqual(num_reinplacing_failures(), 0)
            self.assertEqual(post_grad_graphs.count("aten.clone"), 0)

    def test_lists_old_functionalize(self):
        with inductor_config.patch({"enable_auto_functionalized_v2": False}):

            @torch.library.custom_op("mylib::mutate_op", mutates_args={"y"})
            def mutate_op(y: List[Tensor]) -> None:
                y[0].add_(2)
                y[1].add_(3)

            @torch.compile(fullgraph=True, dynamic=False, backend="inductor")
            def f(b):
                mutate_op([b[0], b[1]])

            x1 = torch.tensor([0.3, 0.4], device=device)
            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                torch.compile(f, backend="inductor", fullgraph=True)(x1)
            post_grad_graphs = "\n".join(
                log_stream.getvalue().strip().split("\n")[3:]
            ).strip()

            # Can't reinplace on views yet (the single failure is the "entire list" failing to reinplace).
            self.assertEqual(num_reinplacing_failures(), 1)

            # Both list inputs failed to reinplace, so we should have emitted clones for them.
            self.assertEqual(post_grad_graphs.count("aten.clone"), 2)

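    # Per the test name, the partitioner is expected to recompute the factory op
    # rather than save it, so the custom op's mutation stays reinplaceable in
    # both forward and backward.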
    @parametrize(
        "factory_op",
        [
            subtest(torch.ones_like, name="ones_like"),
            subtest(torch.empty_like, name="empty_like"),
        ],
    )
    @parametrize(
        "sin_op",
        [
            subtest(sin, name="sin_op"),
            subtest(sin_triton, name="sin_triton"),
        ],
    )
    def test_partitioner_recomputes_factory(self, factory_op, sin_op):
        class MySin(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                out = factory_op(x)
                sin_op(x, out)
                ctx.save_for_backward(out)
                return out

            @staticmethod
            def backward(ctx, grad):
                (saved,) = ctx.saved_tensors
                out = factory_op(grad)
                sin_op(saved, out)
                return out

        @torch.compile(backend="inductor")
        def f(x):
            return MySin.apply(x)

        x = torch.randn(3, requires_grad=True, device=device)
        y = f(x)
        self.assertEqual(num_reinplacing_failures(), 0)


instantiate_parametrized_tests(TestReinplacingPassCorrectness)


if __name__ == "__main__":
    if IS_LINUX and HAS_GPU:
        run_tests(needs="filelock")