1# Owner(s): ["oncall: jit"] 2 3import os 4import sys 5 6import torch 7from torch._C import parse_ir 8from torch.testing import FileCheck 9 10 11# Make the helper files in test/ importable 12pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 13sys.path.append(pytorch_test_dir) 14from torch.testing._internal.jit_utils import JitTestCase 15 16 17if __name__ == "__main__": 18 raise RuntimeError( 19 "This test file is not meant to be run directly, use:\n\n" 20 "\tpython test/test_jit.py TESTNAME\n\n" 21 "instead." 22 ) 23 24 25# Tests that Python slice class is supported in TorchScript 26class TestIgnorableArgs(JitTestCase): 27 def test_slice_ignorable_args_for_slice(self): 28 graph_str = """graph(): 29 %13 : int = prim::Constant[value=0]() 30 %10 : bool = prim::Constant[value=0]() 31 %8 : NoneType = prim::Constant() 32 %0 : int = prim::Constant[value=1]() 33 %1 : int = prim::Constant[value=2]() 34 %2 : int = prim::Constant[value=3]() 35 %3 : int = prim::Constant[value=4]() 36 %4 : int = prim::Constant[value=9]() 37 %5 : int[] = prim::ListConstruct(%0, %1, %2, %3, %4, %4) 38 %6 : int[] = prim::ListConstruct(%0, %1, %2, %3, %4, %4) 39 %7 : int[][] = prim::ListConstruct(%5, %6) 40 %val.1 : Tensor = aten::tensor(%7, %8, %8, %10) 41 %16 : Tensor = aten::slice(%val.1, %13, %1, %8, %0) 42 %20 : Tensor = aten::slice(%16, %0, %8, %0, %0) 43 return (%20)""" 44 graph = parse_ir(graph_str) 45 function = self.createFunctionFromGraph(graph) 46 function_copy = self.getExportImportCopy(function) 47 src = str(function.code) 48 # For a signature: 49 # aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor 50 # We ignore trailing arguments after start=2 for dim 0 51 # and after end=1 for dim 1 52 # because in %16, %15 and %0 are default values for the schema. 53 FileCheck().check( 54 "torch.slice(torch.slice(torch.tensor(_0), 0, 2), 1, None, 1)" 55 ).run(src) 56 self.assertEqual(function(), function_copy()) 57 58 def test_add_out_ignorable_args(self): 59 @torch.jit.script 60 def fn(x: torch.Tensor, y: torch.Tensor): 61 torch.add(x, y, out=y) 62 63 FileCheck().check("torch.add(x, y, out=y)").run(fn.code) 64