# xref: /aosp_15_r20/external/pytorch/test/jit/test_slice.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["oncall: jit"]

import os
import sys
from typing import List

import torch


# Make the helper files in test/ importable: this is the parent of the
# directory containing this file (test/ when this file lives in test/jit/).
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase


# This module only defines test cases; they are meant to be collected and run
# through test/test_jit.py, so refuse to be executed as a script.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


# Tests that Python's `slice` builtin and slicing syntax are supported in TorchScript
class TestSlice(JitTestCase):
    """TorchScript slicing tests for lists, strings, tensors, tuples and ModuleLists.

    Most cases use ``checkScript``, which compares the scripted function's
    output against eager Python execution on the same inputs.
    """

    def test_slice_kwarg(self):
        # TorchScript's slice() rejects keyword arguments such as `stop=`.
        def slice_kwarg(x: List[int]):
            return x[slice(1, stop=2)]

        with self.assertRaisesRegex(
            RuntimeError, "Slice does not accept any keyword arguments"
        ):
            torch.jit.script(slice_kwarg)

    def test_slice_three_nones(self):
        def three_nones(x: List[int]):
            return x[slice(None, None, None)]

        self.checkScript(three_nones, (range(10),))

    def test_slice_two_nones(self):
        def two_nones(x: List[int]):
            return x[slice(None, None)]

        self.checkScript(two_nones, (range(10),))

    def test_slice_one_none(self):
        def one_none(x: List[int]):
            return x[slice(None)]

        self.checkScript(one_none, (range(10),))

    def test_slice_stop_only(self):
        # A single positional argument to slice() is the stop index.
        def fn(x: List[int]):
            return x[slice(5)]

        self.checkScript(fn, (range(10),))

    def test_slice_stop_only_with_nones(self):
        def fn(x: List[int]):
            return x[slice(None, 5, None)]

        self.checkScript(fn, (range(10),))

    def test_slice_start_stop(self):
        def fn(x: List[int]):
            return x[slice(1, 5)]

        self.checkScript(fn, (range(10),))

    def test_slice_start_stop_with_none(self):
        def fn(x: List[int]):
            return x[slice(1, 5, None)]

        self.checkScript(fn, (range(10),))

    def test_slice_start_stop_step(self):
        def fn(x: List[int]):
            return x[slice(0, 6, 2)]

        self.checkScript(fn, (range(10),))

    def test_slice_string(self):
        def fn(x: str):
            return x[slice(None, 3, 1)]

        self.checkScript(fn, ("foo_bar",))

    def test_slice_tensor(self):
        def fn(x: torch.Tensor):
            return x[slice(None, 3, 1)]

        self.checkScript(fn, (torch.ones(10),))

    def test_slice_tensor_multidim(self):
        # A slice object mixed with an integer index in a multi-dim subscript.
        def fn(x: torch.Tensor):
            return x[slice(None, 3, 1), 0]

        self.checkScript(fn, (torch.ones((10, 10)),))

    def test_slice_tensor_multidim_with_dots(self):
        # A slice object mixed with Ellipsis in a multi-dim subscript.
        def fn(x: torch.Tensor):
            return x[slice(None, 3, 1), ...]

        self.checkScript(fn, (torch.ones((10, 10)),))

    def test_slice_as_variable(self):
        # The slice object is bound to a local before being used as an index.
        def fn(x: List[int]):
            a = slice(1)
            return x[a]

        self.checkScript(fn, (range(10),))

    def test_slice_stop_clipped(self):
        # A stop index past the end must behave as in Python (clamped to len).
        def fn(x: List[int]):
            return x[slice(1000)]

        self.checkScript(fn, (range(10),))

    def test_slice_dynamic_index(self):
        # Compare a literal-bound slice with one whose bounds come from locals.
        def t(x):
            slice1 = x[0:1]
            zero = 0
            one = zero + 1
            slice2 = x[zero:one]
            return slice1 + slice2

        self.checkScript(t, (torch.zeros(3, 2, 3),))

    def test_tuple_slicing(self):
        def tuple_slice(a):
            if bool(a):
                b = (1, 2, 3, 4)
            else:
                b = (4, 3, 2, 1)
            c = b[-4:4]
            e = c[1:-1]
            return e

        self.checkScript(tuple_slice, (torch.tensor([1]),), optimize=True)
        scripted_fn = torch.jit.script(tuple_slice)
        self.assertEqual(scripted_fn(torch.tensor(1)), (2, 3))
        tuple_graph = scripted_fn.graph
        tuple_constructs = tuple_graph.findAllNodes("prim::TupleConstruct")
        num_outputs = {
            len(node.output().type().elements()) for node in tuple_constructs
        }
        # Every tuple constructed in the graph should be a 2-element tuple.
        self.assertEqual(num_outputs, {2})
        self.run_pass("lower_all_tuples", tuple_graph)
        # After lowering, no Tuple values should remain in the graph.
        self.assertNotIn("Tuple", str(tuple_graph))

    def test_module_list_slicing(self):
        class Bar(torch.nn.Module):
            def __init__(self, identifier: str):
                super().__init__()
                self.identifier = identifier

            def forward(self):
                return 0

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                module_list = [Bar("A"), Bar("B"), Bar("C"), Bar("D"), Bar("E")]
                self.test = torch.nn.ModuleList(module_list)

            def forward(self):
                # [::-2] walks backwards in steps of 2 (E, C, A);
                # [1:4:] takes indices 1..3 (B, C, D) with an empty step.
                return self.test[::-2], self.test[1:4:]

        scripted_foo = torch.jit.script(Foo())
        result1, result2 = scripted_foo()

        self.assertEqual(len(result1), 3)
        self.assertEqual(result1[0].identifier, "E")
        self.assertEqual(result1[1].identifier, "C")
        self.assertEqual(result1[2].identifier, "A")

        self.assertEqual(len(result2), 3)
        self.assertEqual(result2[0].identifier, "B")
        self.assertEqual(result2[1].identifier, "C")
        self.assertEqual(result2[2].identifier, "D")
181