1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport os 4*da0073e9SAndroid Build Coastguard Workerimport sys 5*da0073e9SAndroid Build Coastguard Workerfrom typing import List 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Workerimport torch 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Worker# Make the helper files in test/ importable 11*da0073e9SAndroid Build Coastguard Workerpytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 12*da0073e9SAndroid Build Coastguard Workersys.path.append(pytorch_test_dir) 13*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard Worker 16*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 17*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 18*da0073e9SAndroid Build Coastguard Worker "This test file is not meant to be run directly, use:\n\n" 19*da0073e9SAndroid Build Coastguard Worker "\tpython test/test_jit.py TESTNAME\n\n" 20*da0073e9SAndroid Build Coastguard Worker "instead." 21*da0073e9SAndroid Build Coastguard Worker ) 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Worker# Tests that Python slice class is supported in TorchScript 25*da0073e9SAndroid Build Coastguard Workerclass TestSlice(JitTestCase): 26*da0073e9SAndroid Build Coastguard Worker def test_slice_kwarg(self): 27*da0073e9SAndroid Build Coastguard Worker def slice_kwarg(x: List[int]): 28*da0073e9SAndroid Build Coastguard Worker return x[slice(1, stop=2)] 29*da0073e9SAndroid Build Coastguard Worker 30*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 31*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Slice does not accept any keyword arguments" 32*da0073e9SAndroid Build Coastguard Worker ): 33*da0073e9SAndroid Build Coastguard Worker torch.jit.script(slice_kwarg) 34*da0073e9SAndroid Build Coastguard Worker 35*da0073e9SAndroid Build Coastguard Worker def test_slice_three_nones(self): 36*da0073e9SAndroid Build Coastguard Worker def three_nones(x: List[int]): 37*da0073e9SAndroid Build Coastguard Worker return x[slice(None, None, None)] 38*da0073e9SAndroid Build Coastguard Worker 39*da0073e9SAndroid Build Coastguard Worker self.checkScript(three_nones, (range(10),)) 40*da0073e9SAndroid Build Coastguard Worker 41*da0073e9SAndroid Build Coastguard Worker def test_slice_two_nones(self): 42*da0073e9SAndroid Build Coastguard Worker def two_nones(x: List[int]): 43*da0073e9SAndroid Build Coastguard Worker return x[slice(None, None)] 44*da0073e9SAndroid Build Coastguard Worker 45*da0073e9SAndroid Build Coastguard Worker self.checkScript(two_nones, (range(10),)) 46*da0073e9SAndroid Build Coastguard Worker 47*da0073e9SAndroid Build Coastguard Worker def test_slice_one_none(self): 48*da0073e9SAndroid Build Coastguard Worker def one_none(x: List[int]): 49*da0073e9SAndroid Build Coastguard Worker return x[slice(None)] 50*da0073e9SAndroid Build Coastguard Worker 51*da0073e9SAndroid Build Coastguard Worker self.checkScript(one_none, (range(10),)) 52*da0073e9SAndroid Build Coastguard Worker 53*da0073e9SAndroid Build Coastguard Worker def test_slice_stop_only(self): 54*da0073e9SAndroid Build Coastguard Worker def fn(x: List[int]): 55*da0073e9SAndroid Build Coastguard Worker return x[slice(5)] 56*da0073e9SAndroid Build Coastguard Worker 57*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (range(10),)) 58*da0073e9SAndroid Build Coastguard Worker 59*da0073e9SAndroid Build Coastguard Worker def test_slice_stop_only_with_nones(self): 60*da0073e9SAndroid Build Coastguard Worker def fn(x: List[int]): 61*da0073e9SAndroid Build Coastguard Worker return x[slice(None, 5, None)] 62*da0073e9SAndroid Build Coastguard Worker 63*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (range(10),)) 64*da0073e9SAndroid Build Coastguard Worker 65*da0073e9SAndroid Build Coastguard Worker def test_slice_start_stop(self): 66*da0073e9SAndroid Build Coastguard Worker def fn(x: List[int]): 67*da0073e9SAndroid Build Coastguard Worker return x[slice(1, 5)] 68*da0073e9SAndroid Build Coastguard Worker 69*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (range(10),)) 70*da0073e9SAndroid Build Coastguard Worker 71*da0073e9SAndroid Build Coastguard Worker def test_slice_start_stop_with_none(self): 72*da0073e9SAndroid Build Coastguard Worker def fn(x: List[int]): 73*da0073e9SAndroid Build Coastguard Worker return x[slice(1, 5, None)] 74*da0073e9SAndroid Build Coastguard Worker 75*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (range(10),)) 76*da0073e9SAndroid Build Coastguard Worker 77*da0073e9SAndroid Build Coastguard Worker def test_slice_start_stop_step(self): 78*da0073e9SAndroid Build Coastguard Worker def fn(x: List[int]): 79*da0073e9SAndroid Build Coastguard Worker return x[slice(0, 6, 2)] 80*da0073e9SAndroid Build Coastguard Worker 81*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (range(10),)) 82*da0073e9SAndroid Build Coastguard Worker 83*da0073e9SAndroid Build Coastguard Worker def test_slice_string(self): 84*da0073e9SAndroid Build Coastguard Worker def fn(x: str): 85*da0073e9SAndroid Build Coastguard Worker return x[slice(None, 3, 1)] 86*da0073e9SAndroid Build Coastguard Worker 87*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ("foo_bar",)) 88*da0073e9SAndroid Build Coastguard Worker 89*da0073e9SAndroid Build Coastguard Worker def test_slice_tensor(self): 90*da0073e9SAndroid Build Coastguard Worker def fn(x: torch.Tensor): 91*da0073e9SAndroid Build Coastguard Worker return x[slice(None, 3, 1)] 92*da0073e9SAndroid Build Coastguard Worker 93*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (torch.ones(10),)) 94*da0073e9SAndroid Build Coastguard Worker 95*da0073e9SAndroid Build Coastguard Worker def test_slice_tensor_multidim(self): 96*da0073e9SAndroid Build Coastguard Worker def fn(x: torch.Tensor): 97*da0073e9SAndroid Build Coastguard Worker return x[slice(None, 3, 1), 0] 98*da0073e9SAndroid Build Coastguard Worker 99*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (torch.ones((10, 10)),)) 100*da0073e9SAndroid Build Coastguard Worker 101*da0073e9SAndroid Build Coastguard Worker def test_slice_tensor_multidim_with_dots(self): 102*da0073e9SAndroid Build Coastguard Worker def fn(x: torch.Tensor): 103*da0073e9SAndroid Build Coastguard Worker return x[slice(None, 3, 1), ...] 104*da0073e9SAndroid Build Coastguard Worker 105*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (torch.ones((10, 10)),)) 106*da0073e9SAndroid Build Coastguard Worker 107*da0073e9SAndroid Build Coastguard Worker def test_slice_as_variable(self): 108*da0073e9SAndroid Build Coastguard Worker def fn(x: List[int]): 109*da0073e9SAndroid Build Coastguard Worker a = slice(1) 110*da0073e9SAndroid Build Coastguard Worker return x[a] 111*da0073e9SAndroid Build Coastguard Worker 112*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (range(10),)) 113*da0073e9SAndroid Build Coastguard Worker 114*da0073e9SAndroid Build Coastguard Worker def test_slice_stop_clipped(self): 115*da0073e9SAndroid Build Coastguard Worker def fn(x: List[int]): 116*da0073e9SAndroid Build Coastguard Worker return x[slice(1000)] 117*da0073e9SAndroid Build Coastguard Worker 118*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (range(10),)) 119*da0073e9SAndroid Build Coastguard Worker 120*da0073e9SAndroid Build Coastguard Worker def test_slice_dynamic_index(self): 121*da0073e9SAndroid Build Coastguard Worker def t(x): 122*da0073e9SAndroid Build Coastguard Worker slice1 = x[0:1] 123*da0073e9SAndroid Build Coastguard Worker zero = 0 124*da0073e9SAndroid Build Coastguard Worker one = zero + 1 125*da0073e9SAndroid Build Coastguard Worker slice2 = x[zero:one] 126*da0073e9SAndroid Build Coastguard Worker return slice1 + slice2 127*da0073e9SAndroid Build Coastguard Worker 128*da0073e9SAndroid Build Coastguard Worker self.checkScript(t, (torch.zeros(3, 2, 3),)) 129*da0073e9SAndroid Build Coastguard Worker 130*da0073e9SAndroid Build Coastguard Worker def test_tuple_slicing(self): 131*da0073e9SAndroid Build Coastguard Worker def tuple_slice(a): 132*da0073e9SAndroid Build Coastguard Worker if bool(a): 133*da0073e9SAndroid Build Coastguard Worker b = (1, 2, 3, 4) 134*da0073e9SAndroid Build Coastguard Worker else: 135*da0073e9SAndroid Build Coastguard Worker b = (4, 3, 2, 1) 136*da0073e9SAndroid Build Coastguard Worker c = b[-4:4] 137*da0073e9SAndroid Build Coastguard Worker e = c[1:-1] 138*da0073e9SAndroid Build Coastguard Worker return e 139*da0073e9SAndroid Build Coastguard Worker 140*da0073e9SAndroid Build Coastguard Worker self.checkScript(tuple_slice, (torch.tensor([1]),), optimize=True) 141*da0073e9SAndroid Build Coastguard Worker scripted_fn = torch.jit.script(tuple_slice) 142*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_fn(torch.tensor(1)), (2, 3)) 143*da0073e9SAndroid Build Coastguard Worker tuple_graph = scripted_fn.graph 144*da0073e9SAndroid Build Coastguard Worker slices = tuple_graph.findAllNodes("prim::TupleConstruct") 145*da0073e9SAndroid Build Coastguard Worker num_outputs = {len(x.output().type().elements()) for x in slices} 146*da0073e9SAndroid Build Coastguard Worker # there should be only one tupleSlice with length of 2 147*da0073e9SAndroid Build Coastguard Worker self.assertTrue(num_outputs == {2}) 148*da0073e9SAndroid Build Coastguard Worker self.run_pass("lower_all_tuples", tuple_graph) 149*da0073e9SAndroid Build Coastguard Worker self.assertTrue("Tuple" not in str(tuple_graph)) 150*da0073e9SAndroid Build Coastguard Worker 151*da0073e9SAndroid Build Coastguard Worker def test_module_list_slicing(self): 152*da0073e9SAndroid Build Coastguard Worker class Bar(torch.nn.Module): 153*da0073e9SAndroid Build Coastguard Worker def __init__(self, identifier: str): 154*da0073e9SAndroid Build Coastguard Worker super().__init__() 155*da0073e9SAndroid Build Coastguard Worker self.identifier = identifier 156*da0073e9SAndroid Build Coastguard Worker 157*da0073e9SAndroid Build Coastguard Worker def forward(self): 158*da0073e9SAndroid Build Coastguard Worker return 0 159*da0073e9SAndroid Build Coastguard Worker 160*da0073e9SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 161*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 162*da0073e9SAndroid Build Coastguard Worker super().__init__() 163*da0073e9SAndroid Build Coastguard Worker module_list = [Bar("A"), Bar("B"), Bar("C"), Bar("D"), Bar("E")] 164*da0073e9SAndroid Build Coastguard Worker self.test = torch.nn.ModuleList(module_list) 165*da0073e9SAndroid Build Coastguard Worker 166*da0073e9SAndroid Build Coastguard Worker def forward(self): 167*da0073e9SAndroid Build Coastguard Worker return self.test[::-2], self.test[1:4:] 168*da0073e9SAndroid Build Coastguard Worker 169*da0073e9SAndroid Build Coastguard Worker scripted_foo = torch.jit.script(Foo()) 170*da0073e9SAndroid Build Coastguard Worker result1, result2 = scripted_foo() 171*da0073e9SAndroid Build Coastguard Worker 172*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(result1), 3) 173*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result1[0].identifier, "E") 174*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result1[1].identifier, "C") 175*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result1[2].identifier, "A") 176*da0073e9SAndroid Build Coastguard Worker 177*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(result2), 3) 178*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result2[0].identifier, "B") 179*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result2[1].identifier, "C") 180*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result2[2].identifier, "D") 181