# Owner(s): ["oncall: jit"]

from itertools import product
from typing import Tuple
from unittest.case import expectedFailure

import torch
from torch import complex32, float32, float64, int32, int64
from torch.jit._passes import _property_propagation
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    ops,
)
from torch.testing._internal.common_methods_invocations import (
    op_db,
    sample_inputs_adaptive_avg_pool2d,
    sample_inputs_conv2d,
    SampleInput,
)
from torch.testing._internal.common_utils import first_sample, set_default_dtype
from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn
from torch.testing._internal.jit_utils import JitTestCase


"""
Dtype Analysis relies on symbolic shape analysis, which is still in beta
"""


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


custom_rules_works_list = {
    "nn.functional.adaptive_avg_pool1d",
    "nn.functional.adaptive_avg_pool2d",
    "nn.functional.adaptive_avg_pool3d",
    "nn.functional.adaptive_max_pool1d",
    "nn.functional.adaptive_max_pool2d",
    "avg_pool1d",
    "avg_pool3d",
    "conv_transpose2d",
    "conv1d",
    "conv2d",
    "hardswish",
    "avg_pool2d",
    "max_pool1d",
    "max_pool2d",
    "max_pool3d",
    "nn.functional.prelu",
    "batch_norm",
}

custom_rules_expected_failure_list = {
    # create_traced_fn generates prim::NumToTensor nodes in graph (not supported yet)
    "nn.functional.adaptive_max_pool3d",
}

# These ops seem to not be in opinfos
custom_rules_not_tested_list = [
    "conv3d",
    "conv_tbc",
    "conv_transpose1d",
    "conv_transpose3d",
    "convolution",
    "_convolution",
    "max_unpool2d",
    "max_unpool3d",
    "reflection_pad1d",
    "reflection_pad2d",
    "reflection_pad3d",
    "replication_pad1d",
    "replication_pad2d",
    "replication_pad3d",
    "upsample_bilinear2d",
    "upsample_linear1d",
    "upsample_nearest1d",
    "upsample_nearest2d",
    "upsample_nearest3d",
    "upsample_trilinear3d",
    "flatten",
]


class TestDtypeBase(JitTestCase):
    SCALAR = "SCALAR"  # To mark unary vs 0 dim tensor

    def setUp(self):
        self.prev_symbolic_shapes_test_enabled = (
            torch._C._jit_symbolic_shapes_test_mode_enabled()
        )
        torch._C._jit_set_symbolic_shapes_test_mode(True)

    def tearDown(self):
        torch._C._jit_set_symbolic_shapes_test_mode(
            self.prev_symbolic_shapes_test_enabled
        )

    @staticmethod
    def node_output_dtypes(graph):
        dtypes = []
        for out in graph.outputs():
            if isinstance(out.type(), torch._C.TensorType):
                dtypes.append(out.type().dtype())
            else:
                dtypes.append(None)
        return dtypes

    @staticmethod
    def node_output_dtype_single(graph):
        dtypes = TestDtypeBase.node_output_dtypes(graph)
        assert len(dtypes) == 1
        return dtypes[0]

    def prop_dtype_on_graph(self, graph, example_inputs):
        # We need to clear shape information because torch.jit.script
        # will return a cached graph if the function is scripted twice.
        torch._C._jit_pass_erase_shape_information(graph)
        _property_propagation.apply_input_props_using_example(graph, example_inputs)
        torch._C._jit_pass_propagate_shapes_on_graph(graph)
        torch._C._jit_pass_propagate_dtype(graph)

    def assert_dtype_equal(self, fn, in_shapes, in_dtypes):
        inputs = [self.get_rand_tensor(s, d) for s, d in zip(in_shapes, in_dtypes)]
        try:
            self.assert_dtype_equal_custom_args(fn, inputs)
        except Exception as e:
            fail_text = f"Failed for shapes {in_shapes} and dtypes {in_dtypes}"
            raise AssertionError(fail_text) from e

    def assert_dtype_equal_custom_args(self, fn, args):
        try:
            # Eager execution
            expected_res = fn(*args)
        except RuntimeError:
            # If eager execution fails, there is nothing to verify against.
            return

        expected_dtype = expected_res.dtype

        # Run the Dtype Analysis
        graph = torch.jit.script(fn).graph  # Note this is a cached graph
        self.prop_dtype_on_graph(graph, args)
        actual_dtype = self.node_output_dtype_single(graph)

        self.assertEqual(actual_dtype, expected_dtype, "Failed Verification")

    def get_rand_tensor(self, shape, dtype):
        if shape is self.SCALAR:
            if dtype is float32:
                return 1.1
            elif dtype is int64:
                return 2
            else:
                raise RuntimeError(
                    "Testing of scalars only supported for fp32 and int64"
                )

        if dtype in (int32, int64):
            rand_tensor = torch.randint(0, 10, shape, dtype=dtype)
        else:
            rand_tensor = torch.rand(shape, dtype=dtype)

        # Sanity check!
        self.assertEqual(rand_tensor.dtype, dtype)
        return rand_tensor


class TestDtypeAnalysis(TestDtypeBase):
    def test_unary(self):
        # Testing the Unary Implementation that uses metatensors

        def relu_inplace(x):
            return x.relu_()

        def log(x):
            return torch.log(x)

        functions = [relu_inplace, log]

        input_shapes = [
            ((2, 2),),  # Simple Case
            ((0, 2),),  # Size 0 Tensor
            ((),),  # zerodim
        ]

        input_dtypes = [
            (float32,),  # Simple Case
            (int64,),  # Test how some unary ops implicitly convert to float
            (complex32,),  # Show we can handle complex vals as well
        ]

        for fn, in_shapes, in_dtypes in product(functions, input_shapes, input_dtypes):
            self.assert_dtype_equal(fn, in_shapes, in_dtypes)

    def test_binary_tensors(self):
        # Testing using Metatensors
        def add(x, y):
            return x + y

        def div(x, y):
            return x / y

        functions = [add, div]

        input_shapes = [
            ((1, 1, 2), (1, 2)),  # Different Dim, non-zerodim
            ((), (1, 2)),  # One zerodim
            ((1, 2), ()),  # Other zerodim
            ((2, 0, 3), (1, 3)),  # Test a tensor with a dim of 0
            ((), ()),  # Both zerodim
        ]

        input_dtypes = [
            (float32, float32),  # Simple Case
            (int32, int64),  # Size Promotion (complicated case for 0dim tensors)
            (float32, int32),  # Type promotion
            (int64, float32),  # Type promotion with size change
            (float64, complex32),  # Show we can handle complex vals as well
        ]

        for fn, in_shapes, in_dtypes in product(functions, input_shapes, input_dtypes):
            self.assert_dtype_equal(fn, in_shapes, in_dtypes)

    def test_binary_scalar(self):
        # Test the mixing of scalar and non-scalar args

        input_shapes = [
            ((2, 2), self.SCALAR),  # Non-Zerodim vs scalar
            ((), self.SCALAR),  # Zerodim vs scalar
            # Scalar vs Scalar is automatically inferred.
        ]

        input_dtypes = [
            (float32, float32),  # Simple Case
            (int32, int64),  # Size Promotion (complicated case for 0dim tensors)
            (int32, float32),  # Type promotion
        ]

        with set_default_dtype(float32):
            for in_shapes, in_dtypes in product(input_shapes, input_dtypes):
                scalar_type = in_dtypes[1]

                if scalar_type == float32:

                    def add(x, y: float):
                        return x + y

                else:

                    def add(x, y: int):
                        return x + y

                self.assert_dtype_equal(add, in_shapes, in_dtypes)

    def test_custom_rules(self):
        # Test some of the ops that are not covered by Metatensors

        # Note that unlike the Conv2d module, the function conv2d
        # does not take dtype/device arguments.

        def conv2d_fn(input, weight, bias):
            return torch.nn.functional.conv2d(input, weight, bias)

        def adaptive_avg_pool2d_fn(input, output_size: Tuple[int]):
            return torch._C._nn.adaptive_avg_pool2d(input, output_size)

        for fn, inputs_fn in (
            (conv2d_fn, sample_inputs_conv2d),
            (adaptive_avg_pool2d_fn, sample_inputs_adaptive_avg_pool2d),
        ):
            for dtype in (torch.int8, torch.float64):
                # Gets default version for conv2d
                sample_input: SampleInput = list(inputs_fn(None, "cpu", dtype, False))[
                    -1
                ]
                input_args = [sample_input.input, *sample_input.args]
                self.assert_dtype_equal_custom_args(fn, input_args)

    def test_conv_no_mixed_args(self):
        def conv2d_fn(input, weight, bias):
            return torch.nn.functional.conv2d(input, weight, bias)

        # Now make sure that conv2d doesn't support mixed args
        conv_ins = sample_inputs_conv2d(None, "cpu", torch.float, False)
        conv_in = list(conv_ins)[-1]
        weight, bias = conv_in.args
        weight = weight.type(torch.long)

        with self.assertRaises(RuntimeError):
            conv2d_fn(conv_in.input, weight, bias)

        # Check that we also don't propagate
        graph = torch.jit.script(conv2d_fn).graph  # Note this is a cached graph
        self.prop_dtype_on_graph(graph, [conv_in.input, weight, bias])
        actual_dtype = self.node_output_dtype_single(graph)
        self.assertEqual(actual_dtype, None)

    def test_combined(self):
        # Test a case with both custom rules and metatensors

        def func(input, weight, bias, y):
            conv_out = torch.nn.functional.conv2d(input, weight, bias)
            conv_2 = conv_out + y
            flattened = torch.flatten(conv_2, start_dim=2)
            add_res = flattened + y
            return add_res

        conv_ins = sample_inputs_conv2d(None, "cpu", torch.int8, False)
        conv_in = list(conv_ins)[-1]
        y_val = torch.rand((1,), dtype=torch.float32)
        input_args = [conv_in.input, *conv_in.args, y_val]
        self.assert_dtype_equal_custom_args(func, input_args)


class TestDtypeCustomRules(TestDtypeBase):
    def assert_output_dtype_equal(self, expected_res, prop_graph):
        actual_dtype = self.node_output_dtypes(prop_graph)
        if len(actual_dtype) == 1:
            # For len=1, there is no tuple packing for expected_res.
            self.assert_tensor_dtype_equal(expected_res, actual_dtype[0])
        else:
            self.assertEqual(len(expected_res), len(actual_dtype))
            for expected, actual in zip(expected_res, actual_dtype):
                self.assert_tensor_dtype_equal(expected, actual)

    def assert_tensor_dtype_equal(self, tensor_output, graph_dtype):
        if not isinstance(tensor_output, torch.Tensor):
            return
        self.assertEqual(tensor_output.dtype, graph_dtype)

    def custom_rules_test_base(self, device, dtype, op, allow_eager_fail=False):
        try:
            samples = op.sample_inputs(device, dtype, requires_grad=False)
            sample_input = first_sample(self, samples)
            input_args = [sample_input.input, *sample_input.args]
            expected_res = op(*input_args, **sample_input.kwargs)
        except Exception:
            if allow_eager_fail:
                return
            raise

        func = op.get_op()
        traced_fn = create_traced_fn(self, func)

        # Have to run the traced function to actually generate the trace
        traced_fn(sample_input.input, *sample_input.args, **sample_input.kwargs)

        # Run the Dtype Analysis
        graph = traced_fn.graph  # Note this is a cached graph
        input_tensors = [t for t in input_args if isinstance(t, torch.Tensor)]
        input_tensors += [
            v for v in sample_input.kwargs.values() if isinstance(v, torch.Tensor)
        ]
        self.prop_dtype_on_graph(graph, input_tensors)
        self.assert_output_dtype_equal(expected_res, graph)

    @ops([op for op in op_db if op.aten_name in custom_rules_works_list])
    def test_custom_rules(self, device, dtype, op):
        self.custom_rules_test_base(device, dtype, op)

    @ops([op for op in op_db if op.aten_name in custom_rules_works_list])
    def test_custom_rules_ints(self, device, dtype, op):
        # OpInfos currently only run with float dtypes, so swap in integer
        # dtypes here to exercise the integer code paths as well.
        if dtype == torch.float32:
            dtype = torch.int32
        else:
            dtype = torch.int64

        # Because ints are not always implemented, we need to allow for eager to fail
        self.custom_rules_test_base(device, dtype, op, allow_eager_fail=True)

    @expectedFailure
    @ops([op for op in op_db if op.aten_name in custom_rules_expected_failure_list])
    def test_custom_rules_expected_failure(self, device, dtype, op):
        self.custom_rules_test_base(device, dtype, op)


TestDtypeCustomRulesCPU = None
# This creates TestDtypeCustomRulesCPU
instantiate_device_type_tests(TestDtypeCustomRules, globals(), only_for=("cpu",))
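
# For reference, a minimal sketch of the dtype-propagation pipeline that
# prop_dtype_on_graph exercises above, assuming symbolic shapes test mode is
# enabled (as in TestDtypeBase.setUp). It is kept as a comment so importing
# this module has no side effects; the `add` function and tensor shapes are
# illustrative only.
#
#     def add(x, y):
#         return x + y
#
#     example_inputs = [torch.rand(2, 2), torch.rand(2, 2)]
#     graph = torch.jit.script(add).graph
#     torch._C._jit_pass_erase_shape_information(graph)
#     _property_propagation.apply_input_props_using_example(graph, example_inputs)
#     torch._C._jit_pass_propagate_shapes_on_graph(graph)
#     torch._C._jit_pass_propagate_dtype(graph)
#     # Graph outputs typed as TensorType should now carry the same dtype as the
#     # eager result (torch.float32 for these inputs).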