# Owner(s): ["oncall: jit"]

import os
import sys
import unittest

import torch
from torch.testing._internal.common_jit import check_against_reference
from torch.testing._internal.common_utils import (
    enable_profiling_mode_for_profiling_tests,
    GRAPH_EXECUTOR,
    num_profiled_runs,
    ProfilingMode,
)


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from typing import List, Optional, Tuple

from torch.testing import FileCheck
from torch.testing._internal.jit_utils import (
    disable_autodiff_subgraph_inlining,
    JitTestCase,
)


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


@unittest.skipIf(
    GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients"
)
class TestAutodiffSubgraphSlicing(JitTestCase):
    # TODO: It would be better to test directly on graphs instead of in the
    # current end-to-end fashion.
    def _perform_ad_subgraph_slicing(self, fn, *input_sizes):
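        # Script `fn`, run it once (with profiling and replay) on freshly
        # created inputs of the given sizes that require grad, and return the
        # optimized graph; autodiff subgraph inlining is disabled so the
        # prim::DifferentiableGraph nodes stay visible.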
        with disable_autodiff_subgraph_inlining():
            with enable_profiling_mode_for_profiling_tests():
                ge = torch.jit.script(fn)
                inputs = [torch.randn(size, requires_grad=True) for size in input_sizes]
                ge(*inputs, profile_and_replay=True)
                return ge.graph_for(*inputs)

    def assertGraphSize(self, graph, size):
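        # Count the graph's top-level nodes, ignoring the bookkeeping nodes the
        # profiling executor inserts (bailouts, type and requires_grad checks).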
        nodes = list(
            filter(
                lambda n: (
                    n.kind() != "prim::BailOut"
                    and n.kind() != "prim::BailoutTemplate"
                    and n.kind() != "prim::TypeCheck"
                    and n.kind() != "prim::RequiresGradCheck"
                ),
                graph.nodes(),
            )
        )
        self.assertEqual(len(list(nodes)), size)

    def test_chunk_constant_script_ad(self):
        @torch.jit.script
        def func(x):
            x1, x2 = torch.chunk(x, 2)
            return (x1, x2)

        input = torch.rand(6, 10).requires_grad_()
        with disable_autodiff_subgraph_inlining():
            with enable_profiling_mode_for_profiling_tests():
                output = func(input, profile_and_replay=True)
                FileCheck().check_not("prim::DifferentiableGraph").run(
                    func.graph_for(input)
                )

    @unittest.skipIf(
        GRAPH_EXECUTOR != ProfilingMode.PROFILING,
        "This threshold is only valid for Profiling Executor",
    )
    def test_diff_graph_inline_threshold(self):
        with enable_profiling_mode_for_profiling_tests():
            NUM_RUNS = 1
            with num_profiled_runs(NUM_RUNS):

                @torch.jit.script
                def foo(x):
                    #  two nodes should be fused
                    #  see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/runtime/graph_executor_impl.h#L49
                    return torch.sigmoid(torch.sigmoid(x))

                @torch.jit.script
                def bar(x):
                    #  a single node (below the threshold) should NOT be fused
                    return torch.sigmoid(x)

                input = torch.rand([4, 4], requires_grad=True)
                foo(input)
                foo(input)

                bar(input)
                bar(input)

                self.assertGraphContainsExactly(
                    foo.graph_for(input), "prim::DifferentiableGraph", 1
                )
                self.assertGraphContainsExactly(
                    bar.graph_for(input), "prim::DifferentiableGraph", 0
                )

    def test_bias_as_module_attr(self):
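        # Check scripted vs. eager gradients for a Linear module constructed
        # both with and without a bias attribute.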
        with enable_profiling_mode_for_profiling_tests():

            class M(torch.nn.Module):
                def __init__(self, has_bias):
                    super().__init__()
                    self.ll = torch.nn.Linear(10, 10, has_bias)

                def forward(self, x, y):
                    return self.ll(x + y) * x + y

            x = torch.rand(10, 10, requires_grad=True)
            no_bias = M(False)
            scripted_no_bias = torch.jit.script(no_bias)
            scripted_no_bias(x, x)
            scripted_no_bias(x, x)
            scripted_no_bias(x, x)
            has_bias = M(True)
            check_against_reference(
                self,
                scripted_no_bias,
                no_bias,
                lambda x: x,
                (
                    x,
                    x,
                ),
                check_types=False,
            )
            scripted_has_bias = torch.jit.script(has_bias)
            scripted_has_bias(x, x)
            scripted_has_bias(x, x)
            scripted_has_bias(x, x)
            check_against_reference(
                self,
                scripted_has_bias,
                has_bias,
                lambda x: x,
                (
                    x,
                    x,
                ),
                check_types=False,
            )

    def test_constructed_bias(self):
        with enable_profiling_mode_for_profiling_tests():

            def method1(x, weight, b1, b2):
                bias = b1 * b2
                return torch.nn.functional.linear(x, weight, bias)

            N = 10
            x = torch.rand(N, N, requires_grad=True)
            weight = torch.rand(N, N, requires_grad=True)
            b1 = torch.rand(N, N, requires_grad=True)
            b2 = torch.rand(N, N, requires_grad=True)
            scripted = self.checkScript(method1, (x, weight, b1, b2))
            # check_types requires last_graph on scripted to be set, so we just skip it
            check_against_reference(
                self,
                scripted,
                method1,
                lambda x: x,
                (x, weight, b1, b2),
                check_types=False,
            )

    def test_bias_as_arg(self):
        with enable_profiling_mode_for_profiling_tests():

            def method1(x, weight, bias: Optional[torch.Tensor]):
                return torch.nn.functional.linear(x, weight, bias).relu() + 2

            N = 10
            x = torch.rand(N, N, requires_grad=True)
            weight = torch.rand(N, N, requires_grad=True)
            bias = None
            scripted = self.checkScript(method1, (x, weight, bias))
            # check_types requires last_graph on scripted to be set, so we just skip it
            check_against_reference(
                self,
                scripted,
                method1,
                lambda x: x,
                (x, weight, bias),
                check_types=False,
            )
            bias = torch.rand(N, N, requires_grad=True)
            scripted = self.checkScript(method1, (x, weight, bias))
            # check_types requires last_graph on scripted to be set, so we just skip it
            check_against_reference(
                self,
                scripted,
                method1,
                lambda x: x,
                (x, weight, bias),
                check_types=False,
            )

    def test_requires_grad_for_tensor_list(self):
        with enable_profiling_mode_for_profiling_tests():
            # output & var_list[0] should have requires_grad set to True
            def func(
                input0: torch.Tensor, input1: torch.Tensor
            ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
                var_list = [input0, input1]
                var = torch.cat(var_list)
                output = var + 1.0
                return output, var_list

            jit_f = torch.jit.script(func)
            input0 = torch.randn((2,), requires_grad=True)
            input1 = torch.randn((2,))
            output_ref = func(input0, input1)
            for i in range(2):
                output = jit_f(input0, input1)
                assert output_ref[0].requires_grad == output[0].requires_grad
                assert output_ref[1][0].requires_grad == output[1][0].requires_grad
                assert output_ref[1][1].requires_grad == output[1][1].requires_grad

    @unittest.skip(
        "disabled until we properly handle tensor lists with undefined gradients"
    )
    def test_differentiable_graph_ops_requires_grad(self):
        x = torch.randn(8, 2, dtype=torch.float).requires_grad_()
        y = torch.randn(8, 2, dtype=torch.float)

        def t(x: torch.Tensor, y: torch.Tensor, flag: bool):
            o = x + 1.0
            o1 = torch.relu(o)
            o = y + 1.5
            o2 = torch.relu(o)
            o3 = o1 + o2

            if flag:
                o = o1 + 1.0
                oo1 = torch.relu(o)
                o = o2 + 2.5
                oo2 = torch.relu(o)
                oo3 = oo1 + oo2
            else:
                o = o1 * 1.0
                oo1 = torch.relu(o)
                o = o2 * 2.0
                oo2 = torch.relu(o)
                oo3 = oo1 + oo2

            return o1, o2, o3, oo1, oo2, oo3

        with enable_profiling_mode_for_profiling_tests():
            t_jit = torch.jit.script(t)
            jit_o = t_jit(x, y, False)
            jit_o = t_jit(x, y, False)
            o = t(x, y, False)

            FileCheck().check("prim::DifferentiableGraph").run(
                t_jit.graph_for(x, y, False)
            )
            # validate that the DifferentiableGraph ops mark requires_grad properly
            for oo, jit_oo in zip(o, jit_o):
                self.assertEqual(oo.requires_grad, jit_oo.requires_grad)
                self.assertEqual(oo, jit_oo)
            # one more run to trigger fusion
            jit_o = t_jit(x, y, False)
            for oo, jit_oo in zip(o, jit_o):
                self.assertEqual(oo.dtype, jit_oo.dtype)
                self.assertEqual(oo.requires_grad, jit_oo.requires_grad)
                self.assertEqual(oo, jit_oo)

    @unittest.skipIf(
        GRAPH_EXECUTOR == ProfilingMode.PROFILING,
        "Simple Executor doesn't support gradients",
    )
    def test_prune_grad(self):
        @torch.jit.script
        def t(input, bias):
            return torch.nn.functional.relu(input + bias)

        input = torch.randn(2, 8, requires_grad=True)
        bias = torch.randn(8, requires_grad=False)  # bias does NOT require grad
        NUM_PROFILED_RUNS = 1
        with num_profiled_runs(NUM_PROFILED_RUNS):
            WARMUP = 3  # 2 runs to reach backward + 1 to optimize it
            for x in range(WARMUP):
                o = t(input, bias)
                o.sum().backward()

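            # Dig the backward graph out of the gradient executor's execution
            # plan; since `bias` does not require grad, its gradient should be
            # pruned, leaving a single input to the returned tuple.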
            fwd_plan = list(t.get_debug_state().execution_plans.values())[0]
            bwd_graph = list(
                fwd_plan.code.grad_executor_states()[0].execution_plans.values()
            )[0].graph
            tup = next(bwd_graph.outputs())
            self.assertEqual(len(list(tup.node().inputs())), 1)

    def test_simple_merge(self):
        # o --> o
        def fn(x, y, z):
            a = x * y
            b = a * z
            return b

        graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)

        self.assertGraphSize(graph, 1)
        self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)

    def test_simple_no_merge(self):
        # o: autodiff supported. x: not autodiff supported.
        # o --> x
        def fn(x, y, z):
            a = x * y
            b = torch.zeros([abs(int(y))])
            return a, b

        graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
        g_str = str(graph)
        FileCheck().check("aten::Int").check("aten::zeros").check_not("aten::mul").run(
            g_str[0 : g_str.find("return")]
        )
        self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)

    def test_does_not_merge_unrelated(self):
        # o  o
        def fn(w, x, y, z):
            a = x * y
            b = w * z
            return a, b

        graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)

        self.assertGraphSize(graph, 3)
        self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 2)

    def test_merges_without_cycles(self):
        # o --> o --> o
        # |           ^
        #  \_________/
        def fn(w, x, y):
            a = w * x
            b = a * y
            c = a * b
            return c

        graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)

        self.assertGraphSize(graph, 1)
        self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)

    def test_merges_dense(self):
        #   o      o
        #   |\    /|
        #   | \  / |
        #   |  /\  |
        #   vv    vv
        #   o      o
        def fn(x, y):
            a, b = x.chunk(2)
            c, d = y.chunk(2)
            return a + c, b + d

        graph = self._perform_ad_subgraph_slicing(fn, 2, 2)

        self.assertGraphSize(graph, 2)
        self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)

    def test_does_not_create_cycles(self):
        # o --> x --> o
        # |           ^
        #  \_________/
        def fn(w, x, y):
            a = w * x
            b = torch.zeros(abs(int(a)))
            c = a * b
            return c

        graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
        self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 2)

    def test_merges_up(self):
        # o --> x     o
        # |           ^
        #  \_________/
        def fn(w, x, y, z):
            a = w * x
            b = torch.zeros(abs(int(y)))
            c = a * z
            return b, c

        graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)
        g_str = str(graph)
        FileCheck().check_not("aten::add").run(g_str[0 : g_str.find("return")])
        self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)

    def test_merges_down(self):
        # o     x --> o
        # |           ^
        #  \_________/
        def fn(v, w, x, y):
            a = v * w
            b = torch.ones(int(y))
            c = b * a
            return a, c

        graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)

        num_nodes = 4 if GRAPH_EXECUTOR == ProfilingMode.PROFILING else 3
        # add moved down
        g_str = str(graph)
        FileCheck().check_not("aten::add").run(g_str[0 : g_str.find("return")])
        self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)

    def test_respects_lexical_scoping(self):
        def fn(x, k):
            y = x * 1.1
            if bool(k):
                k = k + y
            z = y * k
            return z, k

        graph = self._perform_ad_subgraph_slicing(fn, 1, 1)
        # We should not have combined the two multiplications into
        # the same group; they should each be a separate DiffGraph
        self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 3)

    def test_merge_respects_aliasing(self):
        def fn(x, k, cond):
            y = x * 1.1
            y = y * k
            y = y * 2.2
            if bool(cond):
                z1 = y[0]
                z2 = y[1]
                z1.add_(3)
                out = z2 + k + 3.3
                out = out * out
                return out

        graph = self._perform_ad_subgraph_slicing(fn, [2, 2], [2, 2], 1)
        # z2 did not get merged into the subgraph
        FileCheck().check("prim::If").check("aten::select").check_next(
            "aten::select"
        ).check_next("aten::add_").check("Differentiable").run(graph)
        self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 2)

    def test_aliased_outputs(self):
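        # Aliasing that stays inside a single DifferentiableGraph is fine
        # (case 1), but aliased outputs of the subgraph are not; the remaining
        # cases check that create_autodiff_subgraphs unfuses such nodes.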
        with enable_profiling_mode_for_profiling_tests():
            # Case 1: the aliasing between relu and t
            # stays within a DifferentiableGraph. It should be valid
            # to merge relu and t into one graph
            input_str = """
    graph(%a : Tensor):
        %b : Tensor = aten::relu(%a)
        %2 : Tensor = aten::t(%b)
        return (%2)
    """

            graph = torch._C.parse_ir(input_str)
            torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
            FileCheck().check("with prim::DifferentiableGraph").check(
                "aten::relu"
            ).check("aten::t").run(graph)

            # Case 2: relu and split_with_sizes alias each other and
            # are both outputs of a Diff graph. It should be invalid
            # to merge split_with_sizes and relu into one graph,
            # i.e. relu and split_with_sizes should be in different
            # differentiable graphs
            input_str = """
    graph(%a : Tensor):
        %b : Tensor = aten::relu(%a)
        %0 : int[] = prim::Constant[value=[2, 2, 1]]()
        %1 : int = prim::Constant[value=0]()
        %2 : Tensor[] = aten::split_with_sizes(%b, %0, %1)
        %3 : (Tensor[], Tensor[]) = prim::TupleConstruct(%b, %2)
        return (%3)
"""

            graph = torch._C.parse_ir(input_str)
            torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
            FileCheck().check("Tensor = prim::DifferentiableGraph").check(
                "with prim::DifferentiableGraph"
            ).check("Tensor = aten::relu").check_not("aten::split_with_sizes").run(
                graph
            )

            # Case 3: two aliased nodes in a graph.
            # Both `split_with_sizes` should be unfused
            input_str = """
    graph(%a : Tensor):
        %b : Tensor = aten::relu(%a)
        %s1 : int[] = prim::Constant[value=[2, 2, 1]]()
        %s2 : int[] = prim::Constant[value=[3, 1]]()
        %1 : int = prim::Constant[value=0]()
        %2 : Tensor[] = aten::split_with_sizes(%b, %s1, %1)
        %3 : Tensor[] = aten::split_with_sizes(%b, %s2, %1)
        %4 : (Tensor, Tensor[]) = prim::TupleConstruct(%b, %2, %3)
        return (%4)
"""

            graph = torch._C.parse_ir(input_str)
            torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
            FileCheck().check("Tensor = prim::DifferentiableGraph").check(
                "with prim::DifferentiableGraph"
            ).check("Tensor = aten::relu").check_not("aten::split_with_sizes").run(
                graph
            )

            # Case 4: the aliased output has a descendant
            # Both should be unfused. Note, %3 comes before %2
            # to test that we unfuse in the reverse topo order
            input_str = """
    graph(%a : Tensor):
        %b : Tensor = aten::relu(%a)
        %0 : int[] = prim::Constant[value=[2, 2, 1]]()
        %1 : int = prim::Constant[value=0]()
        %2 : Tensor = aten::t(%b)
        %3 : Tensor = aten::relu(%2)
        %4 : (Tensor, Tensor, Tensor[]) = prim::TupleConstruct(%b, %3, %2)
        return (%4)
"""

            graph = torch._C.parse_ir(input_str)
            torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
            FileCheck().check("Tensor = prim::DifferentiableGraph").check(
                "with prim::DifferentiableGraph"
            ).check("Tensor = aten::relu").check_not("aten::t").run(graph)

            # Case 5: multiple aliased groups
            # Both should be unfused. Note, %3 comes before %2
            # to test that we unfuse in the reverse topo order
            input_str = """
    graph(%a : Tensor):
        %b : Tensor = aten::relu(%a)
        %c : Tensor = aten::abs(%a)
        %0 : int[] = prim::Constant[value=[2, 2, 1]]()
        %1 : int = prim::Constant[value=0]()
        %d : Tensor = aten::t(%c)
        %2 : Tensor = aten::t(%b)
        %3 : Tensor = aten::relu(%2)
        %4 : (Tensor, Tensor, Tensor[]) = prim::TupleConstruct(%3, %2, %d, %b, %c, %b)
        return (%4)
"""

            graph = torch._C.parse_ir(input_str)
            torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
            FileCheck().check("Tensor = prim::DifferentiableGraph").check(
                "with prim::DifferentiableGraph"
            ).check("Tensor = aten::relu").check_not("aten::t").run(graph)

    def test_has_profiled_info_aliasing_outputs(self):
        # The expectation is that CallFunction will prevent the final profile node from
        # getting merged into the DifferentiableGraph, and that create_autodiff_subgraphs
        # will instead add this to the type for %4.
        ir = """
        graph(%a : Tensor):
            %1 : Tensor = prim::profile[profiled_type=Float(requires_grad=0)](%a)
            %2 : Tensor = aten::relu(%1)
            %3 : Tensor = prim::profile[profiled_type=Float(requires_grad=0)](%2)
            %4 : Tensor = aten::relu(%3)
            %5 : Tensor = prim::CallFunction(%4)
            %6 : Tensor = prim::profile[profiled_type=Float(requires_grad=0)](%4)
            return (%6)
        """

        graph = torch._C.parse_ir(ir)
        torch._C._jit_pass_create_autodiff_subgraphs(graph)

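        # Pull the subgraph out of the DifferentiableGraph node and check that
        # its single output kept requires_grad=0 from the profiled type.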
        for n in graph.nodes():
            if n.kind() == "prim::DifferentiableGraph":
                diff_graph = n.g("Subgraph")

        outputs = list(diff_graph.outputs())
        self.assertEqual(1, len(outputs))
        output = outputs[0]
        self.assertEqual(False, output.requiresGrad())

        FileCheck().check("= prim::DifferentiableGraph").check(
            "with prim::DifferentiableGraph"
        ).check(" = aten::relu").check("requires_grad=0").check("aten::relu").run(graph)
593