# Owner(s): ["oncall: jit"]

import unittest
from itertools import product

import torch
from torch.jit._passes._property_propagation import apply_input_props_using_example
from torch.testing._internal.common_utils import TEST_CUDA
from torch.testing._internal.jit_utils import JitTestCase


try:
    from torchvision import models
except ImportError:
    models = None

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestDeviceAnalysis(JitTestCase):
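    """Tests for the JIT device-propagation pass.

    Each test scripts a function, annotates the graph inputs with example
    devices (and optionally shapes), runs
    torch._C._jit_pass_propagate_device, and checks the device inferred for
    the graph output. A minimal sketch of the flow the helpers below
    automate:

        @torch.jit.script
        def fn(x):
            return x + x

        graph = fn.graph
        inp = next(graph.inputs())
        inp.setType(inp.type().with_device(torch.device("cpu")))
        torch._C._jit_pass_propagate_device(graph)
        out_device = next(graph.outputs()).type().device()  # cpu
    """
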
    @classmethod
    def setUpClass(cls):
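        # Constructing a torch.device does not require that backend to be
        # available, so the cuda/vulkan/mkldnn devices here only feed the
        # static analysis; no tensors are allocated on them.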
        cls.cpu = torch.device("cpu")
        cls.cuda = torch.device("cuda")
        cls.vulkan = torch.device("vulkan")
        cls.mkldnn = torch.device(
            "mkldnn"
        )  # MKLDNN can't mix with other device types at all
        cls.device_types = [cls.cpu, cls.cuda, cls.vulkan]

    @staticmethod
    def node_output_device(graph):
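        """Return the device annotated on the graph's single output."""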
        graph_out = list(graph.outputs())
        assert len(graph_out) == 1
        return graph_out[0].type().device()

    def prop_device_on_graph(self, graph, example_devices, in_shapes=None):
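        """Annotate the graph inputs with example_devices (and optionally
        in_shapes), then run shape and device propagation on the graph."""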
        graph_inputs = list(graph.inputs())
        torch._C._jit_pass_erase_shape_information(graph)

        self.assertEqual(len(graph_inputs), len(example_devices))
        for graph_i, device_i in zip(graph_inputs, example_devices):
            if device_i is not None:
                graph_i.setType(graph_i.type().with_device(device_i))

        if in_shapes:
            for graph_i, shapes_i in zip(graph_inputs, in_shapes):
                if shapes_i is not None:
                    graph_i.setType(graph_i.type().with_sizes(shapes_i))

            torch._C._jit_pass_propagate_shapes_on_graph(graph)

        torch._C._jit_pass_propagate_device(graph)

    def assert_device_equal(
        self, fn, in_devices, expected_device, in_shapes=None, subtest_str=""
    ):
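        """Script fn, propagate in_devices (and optional in_shapes) through
        its graph, and check the output device against expected_device,
        comparing by device type."""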
        with self.subTest(
            f"In device: {in_devices}, expected: {expected_device}, \n {subtest_str}"
        ):
            graph = torch.jit.script(fn).graph
            self.prop_device_on_graph(graph, in_devices, in_shapes)
            actual_device = self.node_output_device(graph)

            if expected_device is None or actual_device is None:
                self.assertEqual(actual_device, expected_device)
            else:
                self.assertEqual(
                    actual_device.type, expected_device.type, "Failed Verification"
                )

    def test_device_apply(self):
        # Test that a device annotation is properly applied to a graph input
        def add_self(x):
            return x + x

        graph = torch.jit.script(add_self).graph
        graph_input = next(graph.inputs())
        graph_input.setType(graph_input.type().with_device(self.cpu))
        # self.prop_device_on_graph(graph, [self.cpu])
        self.assertEqual(graph_input.type().device(), self.cpu)

    @unittest.skipIf(models is None, "Requires torchvision")
    def test_mobilenet(self):
        in_cpu = torch.randn(1, 3, 224, 224, device=self.cpu)
        in_example = in_cpu

        expected_device = self.cpu
        m = torch.jit.script(models.mobilenet_v3_small())
        m.eval()
        graph = torch.jit.freeze(m).graph
        # torch._C._jit_pass_erase_shape_information(graph)
        apply_input_props_using_example(graph, in_example)
        torch._C._jit_pass_propagate_shapes_on_graph(graph)
        torch._C._jit_pass_propagate_device(graph)

        actual_device = self.node_output_device(graph)

        if expected_device is None or actual_device is None:
            self.assertEqual(actual_device, expected_device)
        else:
            self.assertEqual(
                actual_device.type, expected_device.type, "Failed Verification"
            )

    def test_simple(self):
        def add_self(x):
            return x + x

        def relu_(x):
            return torch.nn.functional.relu_(x)

        functions = [add_self, relu_]

        for in_device, fn in product(self.device_types, functions):
            self.assert_device_equal(fn, [in_device], in_device)

    def test_set_device(self):
        def set_device(x):
            return x.to("cpu")

        for in_device in self.device_types:
            self.assert_device_equal(set_device, [in_device], self.cpu)

    def test_device_arg(self):
        # Test that no device is inferred when the target device is only
        # known at runtime
        def set_device(x, device_name: torch.device):
            return x.to(device=device_name)

        for in_device in self.device_types:
            self.assert_device_equal(set_device, [in_device, None], None)

    def test_tensor_as_fns(self):
        def view_as_fn(x, y):
            return x.view_as(y)

        def expand_as_fn(x, y):
            return x.expand_as(y)

        def reshape_as_fn(x, y):
            return x.reshape_as(y)

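        # view_as/expand_as/reshape_as only borrow y's metadata, so the
        # output stays on x's device and y's device is irrelevant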
        for test_fn in [view_as_fn, expand_as_fn, reshape_as_fn]:
            self.assert_device_equal(test_fn, [self.cpu, self.cpu], self.cpu)
            self.assert_device_equal(test_fn, [self.cuda, None], self.cuda)
            self.assert_device_equal(test_fn, [None, self.mkldnn], None)

        def type_as_fn(x, y):
            return x.type_as(y)

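        # type_as copies y's dtype and device, so here the device propagates
        # from y rather than from x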
        self.assert_device_equal(type_as_fn, [self.cpu, self.cpu], self.cpu)
        self.assert_device_equal(type_as_fn, [self.cuda, None], None)
        self.assert_device_equal(type_as_fn, [None, self.mkldnn], self.mkldnn)

    def zerodim_test_core(self, device_pairs):
        # Test device propagation when zerodim tensors are mixed with
        # non-zerodim tensors
        def mul(x, y):
            return x * y

        def add(x, y):
            return x + y

        fns = [mul, add]

        input_shapes = [
            ((1, 2, 2), (2, 2)),  # Different dim, non-zerodim
            ((1, 2, 2), ()),  # one zerodim
            ((), ()),  # both zerodim
        ]

        for fn, shapes, devices in product(fns, input_shapes, device_pairs):
            subtest_str = f"{fn.__name__} \n shapes: {shapes}, \n devices: {devices}"
            in0 = torch.rand(shapes[0], device=devices[0])
            in1 = torch.rand(shapes[1], device=devices[1])

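            # Run the op eagerly first to get the ground-truth device; some
            # cross-device combinations are expected to fail here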
            try:
                out = fn(in0, in1)
            except Exception:
                # CPU zerodim tensors should never fail in eager mode
                for i in range(len(devices)):
                    if shapes[i] == () and devices[i] == self.cpu:
                        raise

                # Eager failures are only expected across different devices
                if devices[0] == devices[1]:
                    raise

                # Expect the propagated device to be None for the failure cases
                self.assert_device_equal(fn, devices, None, shapes, subtest_str)
                continue

            self.assert_device_equal(fn, devices, out.device, shapes, subtest_str)

            # Test that without shapes we either infer the same device or None,
            # i.e. the analysis is conservative when tensor shapes are unknown.
            graph = torch.jit.script(fn).graph
            self.prop_device_on_graph(graph, devices)
            actual_device = self.node_output_device(graph)
            self.assertTrue(
                (actual_device is None) or (actual_device.type == out.device.type)
            )

    def test_zerodim_cpu(self):
        # CPU-only pairs, so the core test can run locally without accelerators
        self.zerodim_test_core([(self.cpu, self.cpu)])

    def test_zerodim_no_device(self):
        # If an input device is missing, the output device should never be
        # inferred
        def mul(x, y):
            return x * y

        def add(x, y):
            return x + y

        fns = [mul, add]

        device_pairs = [
            (self.cpu, None),
            (None, self.cpu),
            (None, None),
        ]

        input_shapes = [
            ((1, 2, 2), (2, 2)),  # Different dim, non-zerodim
            ((1, 2, 2), ()),  # one zerodim
            ((), ()),  # both zerodim
        ]

        for fn, shapes, devices in product(fns, input_shapes, device_pairs):
            self.assert_device_equal(fn, devices, None, shapes)

    @unittest.skipIf(not TEST_CUDA, "No CUDA")
    def test_zerodim_gpu(self):
        device_pairs = [
            (self.cpu, self.cuda),
            (self.cuda, self.cpu),
            (self.cuda, self.cuda),
        ]
        self.zerodim_test_core(device_pairs)

    def test_custom_device_op(self):
        # Test the device-conversion ops and check that the device type is
        # correctly applied
        def set_cuda(x):
            return x.cuda()

        def set_cpu(x):
            return x.cpu()

        def set_mkldnn(x):
            return x.to_mkldnn()

        device_pairs = (
            (set_cuda, self.cuda),
            (set_cpu, self.cpu),
            (set_mkldnn, self.mkldnn),
        )

        for fn, out_device in device_pairs:
            for in_device in self.device_types:
                self.assert_device_equal(fn, [in_device], out_device)

    def test_device_if_propagation(self):
        def test_fn(x, y, z: bool):
            if z:
                return x + 3
            else:
                return y * 2

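        # The device is only propagated when both branches agree; a cpu/cuda
        # mix must come out as None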
        self.assert_device_equal(test_fn, [self.cpu, self.cpu, None], self.cpu)
        self.assert_device_equal(test_fn, [self.mkldnn, self.mkldnn, None], self.mkldnn)
        self.assert_device_equal(test_fn, [self.cpu, self.cuda, None], None)

    def test_loop_simple(self):
        def test_fn(x, y, z: int):
            for _ in range(z):
                y = x
            return y

        self.assert_device_equal(test_fn, [self.cpu, self.cpu, None], self.cpu)
        self.assert_device_equal(test_fn, [self.cpu, self.cuda, None], None)
        self.assert_device_equal(test_fn, [self.cpu, None, None], None)

    def test_loop_device_change(self):
        def test_fn(x, z: int):
            for _ in range(z):
                x = x.cuda()
            return x

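        # The loop may run zero times, so a cpu input can come out as either
        # cpu or cuda; only a cuda input has a fixed output device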
        self.assert_device_equal(test_fn, [self.cpu, None], None)
        self.assert_device_equal(test_fn, [self.cuda, None], self.cuda)
        self.assert_device_equal(test_fn, [None, None], None)

    def test_while_change(self):
        def test_fn(x, z: int):
            while z > 0:
                x = x.cuda()
                z = 0
            return x

        self.assert_device_equal(test_fn, [self.cpu, None], None)
        self.assert_device_equal(test_fn, [self.cuda, None], self.cuda)
        self.assert_device_equal(test_fn, [None, None], None)

    def test_nested_loops(self):
        def test_fn(x, z: int):
            for i in range(z):
                x = x.cpu()
                for _ in range(i):
                    x = x + 1

            return x

        self.assert_device_equal(test_fn, [self.cpu, None], self.cpu)
        self.assert_device_equal(test_fn, [self.cuda, None], None)
        self.assert_device_equal(test_fn, [None, None], None)

    def test_if_loop_mix(self):
        def test_fn(x, y, z: bool, a: bool):
            c = x
            while a:
                if z:
                    c = x + 3
                else:
                    c = y * 2
                a = False
            return c

        self.assert_device_equal(test_fn, [self.cpu, self.cpu, None, None], self.cpu)
        self.assert_device_equal(
            test_fn, [self.mkldnn, self.mkldnn, None, None], self.mkldnn
        )
        self.assert_device_equal(test_fn, [self.cpu, self.cuda, None, None], None)
339