# xref: /aosp_15_r20/external/pytorch/test/jit/test_functional_blocks.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["oncall: jit"]
2
3import os
4import sys
5
6import torch
7from torch.testing import FileCheck
8
9
# Make the helper files in test/ importable
11pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
12sys.path.append(pytorch_test_dir)
13from torch.testing._internal.jit_utils import JitTestCase
14
15
# Direct execution of this module is unsupported: fail immediately and let the
# error message name the entry point that should be used instead.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
22
23
class TestFunctionalBlocks(JitTestCase):
    """FileCheck-based tests for the ``create_functional_graphs`` JIT pass."""

    def test_subgraph_creation(self):
        """Run the pass on a scripted function and check the resulting graph.

        The scripted function mixes plain arithmetic on ``x`` and ``y`` with
        an in-place ``add_`` on ``z``; the assertions below pin where the
        ``FunctionalGraph`` node is expected to appear relative to those ops.
        """

        # NOTE: the body of ``fn`` determines the graph text the FileCheck
        # patterns match against (value names ``%x`` / ``%y``, op order), so
        # it must not be reworded.
        def fn(x, y, z):
            x = x + 1
            y = y + 1
            z = z + 1
            z.add_(2)
            z = z * z
            y = y * z
            if y < 2:
                y = y + 5
            return x + y + z

        graph = torch.jit.script(fn).graph
        self.run_pass("create_functional_graphs", graph)

        # all uses of x and y should be sunk: after its first mention, each
        # value must not be matched again before "FunctionalGraph".
        for value in (r"%x", r"%y"):
            (
                FileCheck()
                .check(value)
                .check_not(value)
                .check("FunctionalGraph")
                .check(value)
                .run(graph)
            )

        # Don't allow any outputs which escape scope, so there is one final
        # addition in the graph
        FileCheck().check("Tensor = prim::Functional").check_next("aten::add").run(
            graph
        )

        # z + 1, z.add_(2) considered non functional, z = z * z should be
        # considered functional
        FileCheck().check("add").check("add_").check_not("mul").check(
            "FunctionalGraph"
        ).run(graph)
56        ).run(graph)
57