1# Owner(s): ["oncall: jit"] 2 3import os 4import sys 5from typing import List 6 7import torch 8 9 10# Make the helper files in test/ importable 11pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 12sys.path.append(pytorch_test_dir) 13from torch.testing._internal.jit_utils import JitTestCase 14 15 16if __name__ == "__main__": 17 raise RuntimeError( 18 "This test file is not meant to be run directly, use:\n\n" 19 "\tpython test/test_jit.py TESTNAME\n\n" 20 "instead." 21 ) 22 23 24# Tests that Python slice class is supported in TorchScript 25class TestSlice(JitTestCase): 26 def test_slice_kwarg(self): 27 def slice_kwarg(x: List[int]): 28 return x[slice(1, stop=2)] 29 30 with self.assertRaisesRegex( 31 RuntimeError, "Slice does not accept any keyword arguments" 32 ): 33 torch.jit.script(slice_kwarg) 34 35 def test_slice_three_nones(self): 36 def three_nones(x: List[int]): 37 return x[slice(None, None, None)] 38 39 self.checkScript(three_nones, (range(10),)) 40 41 def test_slice_two_nones(self): 42 def two_nones(x: List[int]): 43 return x[slice(None, None)] 44 45 self.checkScript(two_nones, (range(10),)) 46 47 def test_slice_one_none(self): 48 def one_none(x: List[int]): 49 return x[slice(None)] 50 51 self.checkScript(one_none, (range(10),)) 52 53 def test_slice_stop_only(self): 54 def fn(x: List[int]): 55 return x[slice(5)] 56 57 self.checkScript(fn, (range(10),)) 58 59 def test_slice_stop_only_with_nones(self): 60 def fn(x: List[int]): 61 return x[slice(None, 5, None)] 62 63 self.checkScript(fn, (range(10),)) 64 65 def test_slice_start_stop(self): 66 def fn(x: List[int]): 67 return x[slice(1, 5)] 68 69 self.checkScript(fn, (range(10),)) 70 71 def test_slice_start_stop_with_none(self): 72 def fn(x: List[int]): 73 return x[slice(1, 5, None)] 74 75 self.checkScript(fn, (range(10),)) 76 77 def test_slice_start_stop_step(self): 78 def fn(x: List[int]): 79 return x[slice(0, 6, 2)] 80 81 self.checkScript(fn, (range(10),)) 82 83 def test_slice_string(self): 84 def fn(x: str): 85 return x[slice(None, 3, 1)] 86 87 self.checkScript(fn, ("foo_bar",)) 88 89 def test_slice_tensor(self): 90 def fn(x: torch.Tensor): 91 return x[slice(None, 3, 1)] 92 93 self.checkScript(fn, (torch.ones(10),)) 94 95 def test_slice_tensor_multidim(self): 96 def fn(x: torch.Tensor): 97 return x[slice(None, 3, 1), 0] 98 99 self.checkScript(fn, (torch.ones((10, 10)),)) 100 101 def test_slice_tensor_multidim_with_dots(self): 102 def fn(x: torch.Tensor): 103 return x[slice(None, 3, 1), ...] 104 105 self.checkScript(fn, (torch.ones((10, 10)),)) 106 107 def test_slice_as_variable(self): 108 def fn(x: List[int]): 109 a = slice(1) 110 return x[a] 111 112 self.checkScript(fn, (range(10),)) 113 114 def test_slice_stop_clipped(self): 115 def fn(x: List[int]): 116 return x[slice(1000)] 117 118 self.checkScript(fn, (range(10),)) 119 120 def test_slice_dynamic_index(self): 121 def t(x): 122 slice1 = x[0:1] 123 zero = 0 124 one = zero + 1 125 slice2 = x[zero:one] 126 return slice1 + slice2 127 128 self.checkScript(t, (torch.zeros(3, 2, 3),)) 129 130 def test_tuple_slicing(self): 131 def tuple_slice(a): 132 if bool(a): 133 b = (1, 2, 3, 4) 134 else: 135 b = (4, 3, 2, 1) 136 c = b[-4:4] 137 e = c[1:-1] 138 return e 139 140 self.checkScript(tuple_slice, (torch.tensor([1]),), optimize=True) 141 scripted_fn = torch.jit.script(tuple_slice) 142 self.assertEqual(scripted_fn(torch.tensor(1)), (2, 3)) 143 tuple_graph = scripted_fn.graph 144 slices = tuple_graph.findAllNodes("prim::TupleConstruct") 145 num_outputs = {len(x.output().type().elements()) for x in slices} 146 # there should be only one tupleSlice with length of 2 147 self.assertTrue(num_outputs == {2}) 148 self.run_pass("lower_all_tuples", tuple_graph) 149 self.assertTrue("Tuple" not in str(tuple_graph)) 150 151 def test_module_list_slicing(self): 152 class Bar(torch.nn.Module): 153 def __init__(self, identifier: str): 154 super().__init__() 155 self.identifier = identifier 156 157 def forward(self): 158 return 0 159 160 class Foo(torch.nn.Module): 161 def __init__(self) -> None: 162 super().__init__() 163 module_list = [Bar("A"), Bar("B"), Bar("C"), Bar("D"), Bar("E")] 164 self.test = torch.nn.ModuleList(module_list) 165 166 def forward(self): 167 return self.test[::-2], self.test[1:4:] 168 169 scripted_foo = torch.jit.script(Foo()) 170 result1, result2 = scripted_foo() 171 172 self.assertEqual(len(result1), 3) 173 self.assertEqual(result1[0].identifier, "E") 174 self.assertEqual(result1[1].identifier, "C") 175 self.assertEqual(result1[2].identifier, "A") 176 177 self.assertEqual(len(result2), 3) 178 self.assertEqual(result2[0].identifier, "B") 179 self.assertEqual(result2[1].identifier, "C") 180 self.assertEqual(result2[2].identifier, "D") 181