# Owner(s): ["oncall: jit"]

import unittest
from itertools import product

import torch
from torch.jit._passes._property_propagation import apply_input_props_using_example
from torch.testing._internal.common_utils import TEST_CUDA
from torch.testing._internal.jit_utils import JitTestCase


try:
    from torchvision import models
except ImportError:
    models = None

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestDeviceAnalysis(JitTestCase):
    @classmethod
    def setUpClass(cls):
        cls.cpu = torch.device("cpu")
        cls.cuda = torch.device("cuda")
        cls.vulkan = torch.device("vulkan")
        cls.mkldnn = torch.device(
            "mkldnn"
        )  # MKLDNN can't mix with other device types at all
        cls.device_types = [cls.cpu, cls.cuda, cls.vulkan]

    @staticmethod
    def node_output_device(graph):
        graph_out = list(graph.outputs())
        assert len(graph_out) == 1
        return graph_out[0].type().device()

    def prop_device_on_graph(self, graph, example_devices, in_shapes=None):
        graph_inputs = list(graph.inputs())
        torch._C._jit_pass_erase_shape_information(graph)

        self.assertEqual(len(graph_inputs), len(example_devices))
        for graph_i, device_i in zip(graph_inputs, example_devices):
            if device_i is not None:
                graph_i.setType(graph_i.type().with_device(device_i))

        if in_shapes:
            for graph_i, shapes_i in zip(graph_inputs, in_shapes):
                if shapes_i is not None:
                    graph_i.setType(graph_i.type().with_sizes(shapes_i))

            torch._C._jit_pass_propagate_shapes_on_graph(graph)

        torch._C._jit_pass_propagate_device(graph)

    def assert_device_equal(
        self, fn, in_devices, expected_device, in_shapes=None, subtest_str=""
    ):
        with self.subTest(
            f"In device: {in_devices}, expected: {expected_device}, \n {subtest_str}"
        ):
            graph = torch.jit.script(fn).graph
            self.prop_device_on_graph(graph, in_devices, in_shapes)
            actual_device = self.node_output_device(graph)

            if expected_device is None or actual_device is None:
                self.assertEqual(actual_device, expected_device)
            else:
                self.assertEqual(
                    actual_device.type, expected_device.type, "Failed Verification"
                )

    def test_device_apply(self):
        # Test that the device is properly applied to the input
        def add_self(x):
            return x + x

        graph = torch.jit.script(add_self).graph
        graph_input = next(graph.inputs())
        graph_input.setType(graph_input.type().with_device(self.cpu))
        # self.prop_device_on_graph(graph, [self.cpu])
        self.assertEqual(graph_input.type().device(), self.cpu)

    @unittest.skipIf(models is None, "Requires torchvision")
    def test_mobilenet(self):
        in_cpu = torch.randn(1, 3, 224, 224, device=self.cpu)
        in_example = in_cpu

        expected_device = self.cpu
        m = torch.jit.script(models.mobilenet_v3_small())
        m.eval()
        graph = torch.jit.freeze(m).graph
        # torch._C._jit_pass_erase_shape_information(graph)
        apply_input_props_using_example(graph, in_example)
        torch._C._jit_pass_propagate_shapes_on_graph(graph)
        torch._C._jit_pass_propagate_device(graph)

        actual_device = self.node_output_device(graph)

        if expected_device is None or actual_device is None:
            self.assertEqual(actual_device, expected_device)
        else:
            self.assertEqual(
                actual_device.type, expected_device.type, "Failed Verification"
            )

    def test_simple(self):
        def add_self(x):
            return x + x

        def relu_(x):
            return torch.nn.functional.relu_(x)

        functions = [add_self, relu_]

        for in_device, fn in product(self.device_types, functions):
            self.assert_device_equal(fn, [in_device], in_device)

    def test_set_device(self):
        def set_device(x):
            return x.to("cpu")

        for in_device in self.device_types:
            self.assert_device_equal(set_device, [in_device], self.cpu)

    def test_device_arg(self):
        # Test that no device gets propagated when the device is passed as an arg
        def set_device(x, device_name: torch.device):
            return x.to(device=device_name)

        for in_device in self.device_types:
            self.assert_device_equal(set_device, [in_device, None], None)

    def test_tensor_as_fns(self):
        def view_as_fn(x, y):
            return x.view_as(y)

        def expand_as_fn(x, y):
            return x.expand_as(y)

        def reshape_as_fn(x, y):
            return x.reshape_as(y)

        for test_fn in [view_as_fn, expand_as_fn, reshape_as_fn]:
            self.assert_device_equal(test_fn, [self.cpu, self.cpu], self.cpu)
            self.assert_device_equal(test_fn, [self.cuda, None], self.cuda)
            self.assert_device_equal(test_fn, [None, self.mkldnn], None)

        def type_as_fn(x, y):
            return x.type_as(y)

        self.assert_device_equal(type_as_fn, [self.cpu, self.cpu], self.cpu)
        self.assert_device_equal(type_as_fn, [self.cuda, None], None)
        self.assert_device_equal(type_as_fn, [None, self.mkldnn], self.mkldnn)

    def zerodim_test_core(self, device_pairs):
        # Test the support of zerodim tensors mixed with non-zerodim tensors
        def mul(x, y):
            return x * y

        def add(x, y):
            return x + y

        fns = [mul, add]

        input_shapes = [
            ((1, 2, 2), (2, 2)),  # Different dim, non-zerodim
            ((1, 2, 2), ()),  # one zerodim
            ((), ()),  # both zerodim
        ]

        for fn, shapes, devices in product(fns, input_shapes, device_pairs):
            subtest_str = f"{fn.__name__} \n shapes: {shapes}, \n devices: {devices}"
            in0 = torch.rand(shapes[0], device=devices[0])
            in1 = torch.rand(shapes[1], device=devices[1])

            try:
                out = fn(in0, in1)
            except Exception as e:
                # Don't expect eager failures for CPU zerodim tensors
                for i in range(len(devices)):
                    if shapes[i] == () and devices[i] == self.cpu:
                        raise e

                # Only expect eager failures on different devices
                if devices[0] == devices[1]:
                    raise e

                # Expect result device to be None for the failure cases.
                self.assert_device_equal(fn, devices, None, shapes, subtest_str)
                continue

            self.assert_device_equal(fn, devices, out.device, shapes, subtest_str)

            # Test that without shapes, we either get the same device or None,
            # i.e. that the pass is conservative when tensor shapes are unknown.
            graph = torch.jit.script(fn).graph
            self.prop_device_on_graph(graph, devices)
            actual_device = self.node_output_device(graph)
            self.assertTrue(
                (actual_device is None) or (actual_device.type == out.device.type)
            )

    def test_zerodim_cpu(self):
        # Allow for minimal testing locally
        self.zerodim_test_core([(self.cpu, self.cpu)])

    def test_zerodim_no_device(self):
        # If a device is missing, you should never be able to infer the device type.
        def mul(x, y):
            return x * y

        def add(x, y):
            return x + y

        fns = [mul, add]

        device_pairs = [
            (self.cpu, None),
            (None, self.cpu),
            (None, None),
        ]

        input_shapes = [
            ((1, 2, 2), (2, 2)),  # Different dim, non-zerodim
            ((1, 2, 2), ()),  # one zerodim
            ((), ()),  # both zerodim
        ]

        for fn, shapes, devices in product(fns, input_shapes, device_pairs):
            self.assert_device_equal(fn, devices, None, shapes)

    @unittest.skipIf(not TEST_CUDA, "No CUDA")
    def test_zerodim_gpu(self):
        device_pairs = [
            (self.cpu, self.cuda),
            (self.cuda, self.cpu),
            (self.cuda, self.cuda),
        ]
        self.zerodim_test_core(device_pairs)

    def test_custom_device_op(self):
        # Test each of the custom device ops and check that the device type is
        # correctly applied
        def set_cuda(x):
            return x.cuda()

        def set_cpu(x):
            return x.cpu()

        def set_mkldnn(x):
            return x.to_mkldnn()

        device_pairs = (
            (set_cuda, self.cuda),
            (set_cpu, self.cpu),
            (set_mkldnn, self.mkldnn),
        )

        for fn, out_device in device_pairs:
            for in_device in self.device_types:
                self.assert_device_equal(fn, [in_device], out_device)

    def test_device_if_propagation(self):
        def test_fn(x, y, z: bool):
            if z:
                return x + 3
            else:
                return y * 2

        self.assert_device_equal(test_fn, [self.cpu, self.cpu, None], self.cpu)
        self.assert_device_equal(test_fn, [self.mkldnn, self.mkldnn, None], self.mkldnn)
        self.assert_device_equal(test_fn, [self.cpu, self.cuda, None], None)

    def test_loop_simple(self):
        def test_fn(x, y, z: int):
            for _ in range(z):
                y = x
            return y

        self.assert_device_equal(test_fn, [self.cpu, self.cpu, None], self.cpu)
        self.assert_device_equal(test_fn, [self.cpu, self.cuda, None], None)
        self.assert_device_equal(test_fn, [self.cpu, None, None], None)

    def test_loop_device_change(self):
        def test_fn(x, z: int):
            for _ in range(z):
                x = x.cuda()
            return x

        self.assert_device_equal(test_fn, [self.cpu, None], None)
        self.assert_device_equal(test_fn, [self.cuda, None], self.cuda)
        self.assert_device_equal(test_fn, [None, None], None)

    def test_while_change(self):
        def test_fn(x, z: int):
            while z > 0:
                x = x.cuda()
                z = 0
            return x

        self.assert_device_equal(test_fn, [self.cpu, None], None)
        self.assert_device_equal(test_fn, [self.cuda, None], self.cuda)
        self.assert_device_equal(test_fn, [None, None], None)

    def test_nested_loops(self):
        def test_fn(x, z: int):
            for i in range(z):
                x = x.cpu()
                for _ in range(i):
                    x = x + 1

            return x

        self.assert_device_equal(test_fn, [self.cpu, None], self.cpu)
        self.assert_device_equal(test_fn, [self.cuda, None], None)
        self.assert_device_equal(test_fn, [None, None], None)

    def test_if_loop_mix(self):
        def test_fn(x, y, z: bool, a: bool):
            c = x
            while a:
                if z:
                    c = x + 3
                else:
                    c = y * 2
                a = False
            return c

        self.assert_device_equal(test_fn, [self.cpu, self.cpu, None, None], self.cpu)
        self.assert_device_equal(
            test_fn, [self.mkldnn, self.mkldnn, None, None], self.mkldnn
        )
        self.assert_device_equal(test_fn, [self.cpu, self.cuda, None, None], None)