# xref: /aosp_15_r20/external/pytorch/test/jit/test_slice.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["oncall: jit"]
2
3import os
4import sys
5from typing import List
6
7import torch
8
9
# Make the helper files in test/ importable: this file sits one directory
# below test/, so the parent of its own directory is appended to sys.path.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
13from torch.testing._internal.jit_utils import JitTestCase
14
15
if __name__ == "__main__":
    # Direct execution is unsupported; the message points the user at the
    # test/test_jit.py entry point instead.
    usage_hint = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(usage_hint)
22
23
24# Tests that Python slice class is supported in TorchScript
25class TestSlice(JitTestCase):
26    def test_slice_kwarg(self):
27        def slice_kwarg(x: List[int]):
28            return x[slice(1, stop=2)]
29
30        with self.assertRaisesRegex(
31            RuntimeError, "Slice does not accept any keyword arguments"
32        ):
33            torch.jit.script(slice_kwarg)
34
35    def test_slice_three_nones(self):
36        def three_nones(x: List[int]):
37            return x[slice(None, None, None)]
38
39        self.checkScript(three_nones, (range(10),))
40
41    def test_slice_two_nones(self):
42        def two_nones(x: List[int]):
43            return x[slice(None, None)]
44
45        self.checkScript(two_nones, (range(10),))
46
47    def test_slice_one_none(self):
48        def one_none(x: List[int]):
49            return x[slice(None)]
50
51        self.checkScript(one_none, (range(10),))
52
53    def test_slice_stop_only(self):
54        def fn(x: List[int]):
55            return x[slice(5)]
56
57        self.checkScript(fn, (range(10),))
58
59    def test_slice_stop_only_with_nones(self):
60        def fn(x: List[int]):
61            return x[slice(None, 5, None)]
62
63        self.checkScript(fn, (range(10),))
64
65    def test_slice_start_stop(self):
66        def fn(x: List[int]):
67            return x[slice(1, 5)]
68
69        self.checkScript(fn, (range(10),))
70
71    def test_slice_start_stop_with_none(self):
72        def fn(x: List[int]):
73            return x[slice(1, 5, None)]
74
75        self.checkScript(fn, (range(10),))
76
77    def test_slice_start_stop_step(self):
78        def fn(x: List[int]):
79            return x[slice(0, 6, 2)]
80
81        self.checkScript(fn, (range(10),))
82
83    def test_slice_string(self):
84        def fn(x: str):
85            return x[slice(None, 3, 1)]
86
87        self.checkScript(fn, ("foo_bar",))
88
89    def test_slice_tensor(self):
90        def fn(x: torch.Tensor):
91            return x[slice(None, 3, 1)]
92
93        self.checkScript(fn, (torch.ones(10),))
94
95    def test_slice_tensor_multidim(self):
96        def fn(x: torch.Tensor):
97            return x[slice(None, 3, 1), 0]
98
99        self.checkScript(fn, (torch.ones((10, 10)),))
100
101    def test_slice_tensor_multidim_with_dots(self):
102        def fn(x: torch.Tensor):
103            return x[slice(None, 3, 1), ...]
104
105        self.checkScript(fn, (torch.ones((10, 10)),))
106
107    def test_slice_as_variable(self):
108        def fn(x: List[int]):
109            a = slice(1)
110            return x[a]
111
112        self.checkScript(fn, (range(10),))
113
114    def test_slice_stop_clipped(self):
115        def fn(x: List[int]):
116            return x[slice(1000)]
117
118        self.checkScript(fn, (range(10),))
119
120    def test_slice_dynamic_index(self):
121        def t(x):
122            slice1 = x[0:1]
123            zero = 0
124            one = zero + 1
125            slice2 = x[zero:one]
126            return slice1 + slice2
127
128        self.checkScript(t, (torch.zeros(3, 2, 3),))
129
130    def test_tuple_slicing(self):
131        def tuple_slice(a):
132            if bool(a):
133                b = (1, 2, 3, 4)
134            else:
135                b = (4, 3, 2, 1)
136            c = b[-4:4]
137            e = c[1:-1]
138            return e
139
140        self.checkScript(tuple_slice, (torch.tensor([1]),), optimize=True)
141        scripted_fn = torch.jit.script(tuple_slice)
142        self.assertEqual(scripted_fn(torch.tensor(1)), (2, 3))
143        tuple_graph = scripted_fn.graph
144        slices = tuple_graph.findAllNodes("prim::TupleConstruct")
145        num_outputs = {len(x.output().type().elements()) for x in slices}
146        # there should be only one tupleSlice with length of 2
147        self.assertTrue(num_outputs == {2})
148        self.run_pass("lower_all_tuples", tuple_graph)
149        self.assertTrue("Tuple" not in str(tuple_graph))
150
151    def test_module_list_slicing(self):
152        class Bar(torch.nn.Module):
153            def __init__(self, identifier: str):
154                super().__init__()
155                self.identifier = identifier
156
157            def forward(self):
158                return 0
159
160        class Foo(torch.nn.Module):
161            def __init__(self) -> None:
162                super().__init__()
163                module_list = [Bar("A"), Bar("B"), Bar("C"), Bar("D"), Bar("E")]
164                self.test = torch.nn.ModuleList(module_list)
165
166            def forward(self):
167                return self.test[::-2], self.test[1:4:]
168
169        scripted_foo = torch.jit.script(Foo())
170        result1, result2 = scripted_foo()
171
172        self.assertEqual(len(result1), 3)
173        self.assertEqual(result1[0].identifier, "E")
174        self.assertEqual(result1[1].identifier, "C")
175        self.assertEqual(result1[2].identifier, "A")
176
177        self.assertEqual(len(result2), 3)
178        self.assertEqual(result2[0].identifier, "B")
179        self.assertEqual(result2[1].identifier, "C")
180        self.assertEqual(result2[2].identifier, "D")
181