1# Owner(s): ["oncall: jit"] 2 3import os 4import sys 5 6import torch 7from torch.testing import FileCheck 8 9 10# Make the helper files in test/ importable 11pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 12sys.path.append(pytorch_test_dir) 13from torch.testing._internal.jit_utils import JitTestCase 14 15 16if __name__ == "__main__": 17 raise RuntimeError( 18 "This test file is not meant to be run directly, use:\n\n" 19 "\tpython test/test_jit.py TESTNAME\n\n" 20 "instead." 21 ) 22 23 24class TestFunctionalBlocks(JitTestCase): 25 def test_subgraph_creation(self): 26 def fn(x, y, z): 27 x = x + 1 28 y = y + 1 29 z = z + 1 30 z.add_(2) 31 z = z * z 32 y = y * z 33 if y < 2: 34 y = y + 5 35 return x + y + z 36 37 graph = torch.jit.script(fn).graph 38 self.run_pass("create_functional_graphs", graph) 39 40 # all uses of x and y should be sunk 41 FileCheck().check(r"%x").check_not(r"%x").check("FunctionalGraph").check( 42 r"%x" 43 ).run(graph) 44 FileCheck().check(r"%y").check_not(r"%y").check("FunctionalGraph").check( 45 r"%y" 46 ).run(graph) 47 48 # Don't allow any outputs which escape scope, so there is one final addition in the graph 49 FileCheck().check("Tensor = prim::Functional").check_next("aten::add").run( 50 graph 51 ) 52 53 # z + 1, z.add_(2) considered non functional, z = z * z should be considered functional 54 FileCheck().check("add").check("add_").check_not("mul").check( 55 "FunctionalGraph" 56 ).run(graph) 57