# Owner(s): ["module: unknown"]

from functools import partial
from textwrap import dedent

import torch
from torch.testing import FileCheck
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    OpDTypes,
    ops,
)
from torch.testing._internal.common_jit import (
    check_against_reference,
    JitCommonTestCase,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import (
    clone_input_helper,
    first_sample,
    IS_SANDCASTLE,
    run_tests,
    TestCase,
    unMarkDynamoStrictTest,
)
from torch.testing._internal.jit_metaprogramming_utils import (
    check_alias_annotation,
    create_script_fn,
    create_traced_fn,
)
from torch.testing._internal.jit_utils import (
    disable_autodiff_subgraph_inlining,
    is_lambda,
)


# variant testing is only done with torch.float and torch.cfloat to avoid
# excessive test times and maximize signal to noise ratio
_variant_ops = partial(
    ops, dtypes=OpDTypes.supported, allowed_dtypes=(torch.float, torch.cfloat)
)


# Tests operators for consistency between JIT and eager, also checks
# correctness of JIT specific alias schemas and intended
# autodifferentiation behavior.
# Inherits from JitCommonTestCase instead of TestCase directly to share
# functionality with original test_jit.py method operator tests
@unMarkDynamoStrictTest
class TestJit(JitCommonTestCase):
    exact_dtype = True

    # Tests that the forward and backward passes of operations produce the
    # same values for the cross-product of op variants (function, method, inplace)
    # and runtimes (eager, traced, scripted).
    # TODO WARNING: inplace x {traced, scripted} not currently tested
    @_variant_ops(op_db)
    def test_variant_consistency_jit(self, device, dtype, op):
        _requires_grad = dtype in op.supported_backward_dtypes(
            torch.device(device).type
        )

        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
        samples = op.sample_inputs(
            device,
            dtype,
            requires_grad=_requires_grad,
            include_conjugated_inputs=include_conjugated_inputs,
        )

        # Acquires variants to test
        func = op.get_op()
        method = op.get_method()
        variants = {
            # TODO: inplace tests currently fail, fix and add inplace variant
            "function": func,
            "method": method,
        }

        # scripting strips the torch.ops prefix from these operators
        # incorrectly; don't bother testing this case.  Count this
        # as "testing"
        if isinstance(func, torch._ops.OpOverload):
            self.skipTest("variant consistency doesn't work on torch.ops")

        # TODO: find better way to standardize on op registration itself..
        has_fake_function = op.name in ["resize_", "resize_as_"]

        if has_fake_function:
            # These ops are only exercised through their Tensor method, and
            # with samples that do not require grad.
            variants = {"method": getattr(torch.Tensor, op.name)}
            samples = op.sample_inputs(device, dtype, requires_grad=False)

        tested = False
        for sample in samples:
            # Test traced and scripted consistency
            for func_type, variant in variants.items():
                if variant is None:
                    continue

                # scripting and check_alias_analysis do not work with lambdas
                # lambdas are typically used as a way to simulate methods without
                # functional variants, so rely on the other variant for testing
                # for now
                if is_lambda(variant):
                    continue

                tested = True
                try:
                    self.indiv_variant_test_jit(
                        device, dtype, op, sample, func_type, variant, has_fake_function
                    )
                except Exception as e:
                    # Re-raise with the op/variant/dtype/sample that failed so
                    # the report identifies the failing combination.
                    variant_error_info = dedent(
                        f"""
                        Error testing {op.name} {func_type} variant
                        with dtype: {dtype}
                        with inputs {sample}:
                    """
                    )
                    raise Exception(variant_error_info) from e  # noqa: TRY002

        assert tested, "JIT Test does not execute any logic"

    def indiv_variant_test_jit(
        self, device, dtype, op, sample, func_type, variant, has_fake_function
    ):
        """Check one (sample, variant) pair of ``op`` under TorchScript and tracing.

        Runs, depending on what ``op`` supports:

        * scripted and traced forward/grad/gradgrad against ``op.get_op()``
          via ``check_against_reference``;
        * (float32 only) alias-annotation schema checking and traced shape
          analysis;
        * (float32 only) autodiff node assertions on the last executed
          scripted/traced graphs.

        Args:
            device: device string the test was instantiated for.
            dtype: dtype the test was instantiated for.
            op: the OpInfo under test.
            sample: one SampleInput produced by ``op.sample_inputs``.
            func_type: variant kind key (e.g. "function", "method", "inplace").
            variant: the callable for that variant.
            has_fake_function: True for ops that are only reachable as Tensor
                methods (tracing is skipped for them).
        """
        _requires_grad = dtype in op.supported_backward_dtypes(
            torch.device(device).type
        )
        support_script = op.supports_scripting
        # Create accessor for script function variant
        name = op.name + "_" if func_type == "inplace" else op.name

        # run with disable_autodiff_subgraph_inlining(True) to test
        #   autodiff support. Context manager forces the graph to contain
        #   DifferentiableGraph nodes if they are present
        with disable_autodiff_subgraph_inlining():
            # Check scripted forward, grad, and grad grad
            if support_script:
                script_fn = create_script_fn(self, name, func_type)

            def out_fn(output):
                # Processes the output for autograd
                if sample.output_process_fn_grad is not None:
                    return sample.output_process_fn_grad(output)
                return output

            def get_sample():
                # In-place ops (trailing "_") mutate their input, so each use
                # gets a fresh clone.
                return (
                    clone_input_helper(sample.input)
                    if op.name[-1] == "_"
                    else sample.input
                )

            if support_script:
                check_against_reference(
                    self,
                    script_fn,
                    op.get_op(),
                    out_fn,
                    (get_sample(),) + sample.args,
                    sample.kwargs,
                    no_grad=not _requires_grad,
                    no_gradgrad=not op.supports_gradgrad,
                )

            # Check traced forward, grad, and grad grad
            # TODO: fix tracing here
            supports_tracing = op.supports_tracing and not has_fake_function
            if op.assert_jit_shape_analysis:
                self.assertTrue(supports_tracing)

            if supports_tracing:
                traced_fn = create_traced_fn(self, variant)
                check_against_reference(
                    self,
                    traced_fn,
                    op.get_op(),
                    out_fn,
                    (get_sample(),) + sample.args,
                    sample.kwargs,
                    no_grad=not _requires_grad,
                    no_gradgrad=not op.supports_gradgrad,
                )

            # Check alias annotation schema for correctness (make
            #   sure inputs that aren't supposed to be modified aren't)
            # Note: only runs in float32 because schema isn't affected by dtype,
            #   so running it on all dtypes would be excessive
            if dtype == torch.float32:
                # TODO: no reason why we can't run this with tracing graph
                if support_script and op.name != "rsub":
                    check_alias_annotation(
                        name,
                        (get_sample(),) + sample.args,
                        sample.kwargs,
                        func_type=func_type,
                        aten_name=op.aten_name,
                    )

                # TODO: use script graph as well
                checked_shape_analysis = False
                if supports_tracing:
                    out = variant(get_sample(), *sample.args, **sample.kwargs)

                    # right now, tuple of outputs and tensor output supported
                    # TODO: list of tensor outputs
                    tuple_of_tensors = isinstance(out, tuple) and all(
                        isinstance(elem, torch.Tensor) for elem in out
                    )

                    if isinstance(out, torch.Tensor) or tuple_of_tensors:
                        if tuple_of_tensors:
                            sizes = [elem.size() for elem in out]
                        else:
                            sizes = out.size()
                        self.checkShapeAnalysis(
                            sizes, traced_fn.graph, op.assert_jit_shape_analysis
                        )
                        checked_shape_analysis = True
                if op.assert_jit_shape_analysis:
                    self.assertTrue(checked_shape_analysis)

            # Check autodifferentiation of nodes for traced and scripted graphs, only need to check once per sample
            if dtype is torch.float32:
                # Sandcastle doesn't fuse nodes
                if IS_SANDCASTLE:
                    # fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
                    nonfusible_nodes = (
                        op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
                    )
                    fusible_nodes = []
                else:
                    nonfusible_nodes = op.autodiff_nonfusible_nodes
                    fusible_nodes = op.autodiff_fusible_nodes

                if supports_tracing:
                    self.assertAutodiffNode(
                        traced_fn.last_graph,
                        op.assert_autodiffed,
                        nonfusible_nodes,
                        fusible_nodes,
                    )
                if support_script:
                    self.assertAutodiffNode(
                        script_fn.last_graph,
                        op.assert_autodiffed,
                        nonfusible_nodes,
                        fusible_nodes,
                    )

    # alias testing is only done with torch.float for the same reason
    _alias_ops = partial(ops, dtypes=OpDTypes.supported, allowed_dtypes=(torch.float,))

    # Checks that every alias of an op (function, method and in-place forms)
    # is remapped to the original aten op in both scripted and traced graphs.
    @_alias_ops(op for op in op_db if op.aliases)
    def test_jit_alias_remapping(self, device, dtype, op):
        # NOTE: only tests on first sample
        samples = op.sample_inputs(device, dtype, requires_grad=True)
        sample = first_sample(self, samples)

        # [Scripting Data Preparation]
        # Prepare data for test scripting
        # Below we prepare strings of args/kwargs with and without type annotations.
        # These strings are inserted into function template strings which is then torch scripted.
        # - args string is ["t0"] corresponding to the "input" tensor required by the op
        # - args_kw is the value of args and strings of kwargs used to call the op (without type annotations), for example,
        # ["t0", "1.0", "(1,)", "True", "tensor(1.0)"] -> def fn(t0): return variant(t0, 1.0, (1,), True, tensor(1.0))
        args = ["t0"]

        def quote_strs(v):
            if isinstance(v, str):
                return f"'{v}'"

            return str(v)

        args_kw = (
            args
            + [f"{v}" for v in sample.args]
            + [f"{k}={quote_strs(v)}" for k, v in sample.kwargs.items()]
        )

        # Prepare data for test tracing.
        # torch.jit.trace only accepts (nested containers of) tensors as example
        # inputs, so the sample args/kwargs are bound into the traced function
        # instead of being passed as trace inputs. Tensors among them are
        # detached because the tracer refuses to bake a tensor that requires
        # grad into the graph as a constant.
        trace_args = tuple(
            a.detach() if isinstance(a, torch.Tensor) else a for a in sample.args
        )
        trace_kwargs = {
            k: v.detach() if isinstance(v, torch.Tensor) else v
            for k, v in sample.kwargs.items()
        }

        original_name = op.aten_name
        original_name_inplace = original_name + "_"
        expected_dtype = op(sample.input, *sample.args, **sample.kwargs).dtype

        for a_op in op.aliases:
            inplace = a_op.inplace_variant
            method_or_inplace = [a_op.inplace_variant, a_op.method_variant]
            # Materialized as a tuple: it is iterated twice (scripting, then
            # tracing). A generator here would leave the tracing loop empty.
            variants = tuple(
                v
                for v in (a_op.op, a_op.method_variant, a_op.inplace_variant)
                if v is not None
            )

            # Test scripting:
            # NOTE: the local name `variant` is referenced by the function
            # template below and resolved from this frame by CompilationUnit,
            # so it must not be renamed.
            for variant in variants:
                variant_name = variant.__name__

                if variant in method_or_inplace:
                    fn_template = """
                        def _fn(t0{c}):
                            return t0.{alias_name}({args_kw})
                    """
                    # remove the first input tensor
                    # NOTE(review): `c` only adds a trailing comma to the
                    # parameter list (`def _fn(t0, ):`), which has no effect.
                    script = fn_template.format(
                        c=", " if len(args_kw[1:]) > 1 else "",
                        args_kw=", ".join(args_kw[1:]),
                        alias_name=variant_name,
                    )
                else:
                    fn_template = """
                        def _fn({args}):
                            return variant({args_kw})
                    """
                    script = fn_template.format(
                        args=", ".join(args),
                        args_kw=", ".join(args_kw),
                    )

                # Required to avoid undefined value: tensor error in JIT
                # compilation of the function template
                script = script.replace("tensor(", "torch.tensor(")

                scripted = torch.jit.CompilationUnit(script)._fn

                if variant is inplace and not torch.can_cast(expected_dtype, dtype):
                    # The result dtype cannot be written back into the input,
                    # so the in-place alias is expected to raise.
                    try:
                        inp = clone_input_helper(sample.input)
                        scripted(inp)
                    except Exception:
                        continue
                    self.fail(
                        "Inplace operation on integer tensor that should be promoted to float didn't fail!"
                    )

                inp = clone_input_helper(sample.input)
                scripted(inp)
                inp = clone_input_helper(sample.input)
                graph = scripted.graph_for(inp)
                FileCheck().check(op.aten_name).check_not(variant_name).run(graph)

            # Test tracing:
            for variant in variants:
                variant_name = variant.__name__
                op_name = original_name_inplace if variant is inplace else original_name

                # Same promotion case as above; the scripting loop has already
                # asserted that it raises, so there is no graph to inspect.
                if variant is inplace and not torch.can_cast(expected_dtype, dtype):
                    continue

                # `variant` is read while this iteration is still current
                # (trace + graph_for below), so late binding is not an issue.
                def _fn(t0):
                    return variant(t0, *trace_args, **trace_kwargs)

                # check_trace is disabled: only the op names recorded in the
                # graph matter here, not trace re-execution consistency.
                traced = torch.jit.trace(
                    _fn, (clone_input_helper(sample.input),), check_trace=False
                )
                traced(clone_input_helper(sample.input))
                graph = traced.graph_for(clone_input_helper(sample.input))
                FileCheck().check(op_name).check_not(variant_name).run(graph)


instantiate_device_type_tests(TestJit, globals())

if __name__ == "__main__":
    TestCase._default_dtype_check_enabled = True
    run_tests()