1# Owner(s): ["oncall: jit"] 2 3import os 4import sys 5import torch 6from torch.utils._pytree import tree_map 7import unittest 8 9from torch.testing._internal.common_utils import run_tests, TEST_WITH_TORCHDYNAMO 10from torch.fx.operator_schemas import normalize_function 11from torch._subclasses.schema_check_mode import SchemaCheckMode 12from torch.utils._python_dispatch import TorchDispatchMode 13from torch.testing._internal.common_methods_invocations import op_db 14from torch.testing._internal.jit_utils import JitTestCase 15from torch.testing._internal.common_device_type import ops, OpDTypes, instantiate_device_type_tests 16pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 17sys.path.append(pytorch_test_dir) 18 19def secretly_aliasing(x): 20 return x.view(-1) 21 22def secretly_mutating(x): 23 x.mul_(2) 24 return x * 3 25 26def output_is_input(x): 27 return x 28 29custom_lib = torch.library.Library("bad_schemas", "DEF") # noqa: TOR901 30custom_lib.define("secretly_aliasing(Tensor x) -> Tensor") 31custom_lib.define("secretly_mutating(Tensor x) -> Tensor") 32custom_lib.define("output_is_input(Tensor(a) x) -> Tensor(a)") 33 34custom_lib_cpu = torch.library.Library("bad_schemas", "IMPL", "CPU") # noqa: TOR901 35custom_lib_cpu.impl("secretly_aliasing", secretly_aliasing) 36custom_lib_cpu.impl("secretly_mutating", secretly_mutating) 37custom_lib_cpu.impl("output_is_input", output_is_input) 38 39custom_lib_meta = torch.library.Library("bad_schemas", "IMPL", "Meta") # noqa: TOR901 40custom_lib_meta.impl("secretly_aliasing", secretly_aliasing) 41custom_lib_meta.impl("secretly_mutating", secretly_mutating) 42custom_lib_meta.impl("output_is_input", output_is_input) 43 44# This TorchDispatchTensor Subclass is used to simulate an incorrect schema 45# which is then used to test that SchemaCheckMode behaves as expected 46 47class IncorrectAliasTensor(torch.Tensor): 48 ALIAS_ARG_OUT = {"aten::add"} 49 ALIAS_OUT_OUT = {"aten::aminmax"} 50 
MUTATE_ARGS_OUT = {"aten::sub"} 51 52 elem: torch.Tensor 53 54 __slots__ = ['elem'] 55 56 @staticmethod 57 def __new__(cls, elem, *args, **kwargs): 58 # The wrapping tensor (IncorrectAliasTensor) shouldn't hold any 59 # memory for the class in question, but it should still 60 # advertise the same device as before 61 r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] 62 cls, elem.size(), 63 strides=elem.stride(), storage_offset=elem.storage_offset(), 64 # TODO: clone storage aliasing 65 dtype=elem.dtype, layout=elem.layout, 66 device=elem.device, requires_grad=kwargs.get("requires_grad", False) 67 ) 68 # ...the real tensor is held as an element on the tensor. 69 r.elem = elem.detach() if r.requires_grad else elem 70 return r 71 72 def __repr__(self): 73 return super().__repr__(tensor_contents=f"{self.elem}") 74 75 @classmethod 76 def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 77 def unwrap(e): 78 return e.elem if isinstance(e, cls) else e 79 80 def wrap(e): 81 return cls(e) if isinstance(e, torch.Tensor) else e 82 unwrapped_args = tree_map(unwrap, args) 83 out = func(*unwrapped_args, **tree_map(unwrap, kwargs)) 84 if func._schema.name in IncorrectAliasTensor.ALIAS_ARG_OUT: 85 args[0].elem = out 86 if func._schema.name in IncorrectAliasTensor.MUTATE_ARGS_OUT: 87 args[0].elem = torch.rand(args[0].elem.shape) 88 if func._schema.name in IncorrectAliasTensor.ALIAS_OUT_OUT: 89 incorrect_out = list(out) 90 incorrect_out[0] = incorrect_out[1] 91 return tree_map(wrap, tuple(incorrect_out)) 92 93 return tree_map(wrap, out) 94 95# Tests various schema checking functionalities. 
class TestSchemaCheck(JitTestCase):
    def setUp(self):
        if TEST_WITH_TORCHDYNAMO:
            self.skipTest("SchemaCheckMode is ignored by dynamo")
        super().setUp()

    # Tests that SchemaCheckMode records operator order with grad
    def test_schema_check_mode_operator_order(self):
        with SchemaCheckMode() as schema_check:
            x = torch.rand((3, 3), requires_grad=True)
            x.relu().sin()
        self.assertEqual(["aten::rand", "aten::relu", "aten::detach", "aten::sin"], schema_check.ops)

    # Tests that SchemaCheckMode records operator order without grad
    def test_schema_check_mode_operator_order_without_grad(self):
        with SchemaCheckMode() as schema_check:
            x = torch.rand((3, 3), requires_grad=False)
            x.relu().sin()
        self.assertEqual(["aten::rand", "aten::relu", "aten::sin"], schema_check.ops)

    # Tests that SchemaCheckMode records mutations and aliases with none expected
    def test_schema_check_mode_mutated_aliasing_none(self):
        # NB: previously requires_grad=True, but this induces a detach for
        # saved variable
        x = torch.rand((3, 3))
        with SchemaCheckMode() as schema_check:
            x.relu().sin()
        self.assertEqual([], schema_check.mutated)
        self.assertEqual([], schema_check.aliasing)

    # Tests that SchemaCheckMode records mutations and aliases with mutation expected
    def test_schema_check_mode_mutated_aliasing_mutation(self):
        actual = torch.rand((3, 3), requires_grad=False)
        with SchemaCheckMode() as schema_check:
            actual.sinh_()
        self.assertEqual([('aten::sinh_', 'input')], schema_check.mutated)
        self.assertEqual([('aten::sinh_', 'input', 'output_0')], schema_check.aliasing)

    # Tests that SchemaCheckMode records mutations and aliases with resize_
    def test_schema_check_mode_mutated_aliasing_resize_(self):
        actual = torch.rand((3, 3), requires_grad=False)
        with SchemaCheckMode() as schema_check:
            actual.resize_(9)
        self.assertEqual([('aten::resize_', 'input')], schema_check.mutated)
        self.assertEqual([('aten::resize_', 'input', 'output_0')], schema_check.aliasing)

    # Tests that SchemaCheckMode records mutations and aliases with aliasing inputs
    def test_schema_check_mode_mutated_aliasing_aliasing_inputs(self):
        actual = torch.rand((3, 3))
        y = actual
        with SchemaCheckMode() as schema_check:
            actual.add_(y)
        self.assertEqual(
            [
                ('aten::add_', 'input'),
                ('aten::add_', 'other')
            ],
            schema_check.mutated
        )
        self.assertEqual(
            [
                ('aten::add_', 'input', 'output_0'),
                ('aten::add_', 'other', 'output_0')
            ],
            schema_check.aliasing
        )

    # Tests that SchemaCheckMode records mutations and alias with as_strided
    def test_schema_check_mode_mutated_aliasing_as_strided(self):
        x = torch.rand((3, 6, 4))
        with SchemaCheckMode() as schema_check:
            x.as_strided_([3, 6, 4], [9, 1, 1])
        self.assertEqual(
            [
                ('aten::as_strided_', 'input')
            ],
            schema_check.mutated
        )
        self.assertEqual(
            [
                ('aten::as_strided_', 'input', 'output_0')
            ],
            schema_check.aliasing
        )

    # Tests that SchemaCheckMode records mutations and aliases with multiple outputs
    def test_schema_check_mode_mutated_aliasing_multiple_outputs(self):
        x = torch.arange(9.)
        m_actual = torch.arange(9.)
        e_actual = torch.zeros([9], dtype=torch.int32)
        with SchemaCheckMode() as schema_check:
            torch.frexp(x, out=(m_actual, e_actual))
        self.assertEqual(
            [
                ('aten::frexp', 'mantissa'),
                ('aten::frexp', 'exponent')
            ],
            schema_check.mutated
        )
        self.assertEqual(
            [
                ('aten::frexp', 'mantissa', 'output_0'),
                ('aten::frexp', 'exponent', 'output_1')
            ],
            schema_check.aliasing
        )

    # Tests that SchemaCheckMode records mutations and aliases with aliasing outputs
    def test_schema_check_mode_mutated_aliasing_aliasing_outputs(self):
        x = torch.rand((3, 3))
        actual = torch.zeros(3)
        with SchemaCheckMode() as schema_check:
            torch.aminmax(x, dim=0, out=[actual, actual])
        self.assertEqual(
            [
                ('aten::aminmax', 'min'),
                ('aten::aminmax', 'max')
            ],
            schema_check.mutated
        )
        self.assertEqual(
            [
                ('aten::aminmax', 'min', 'output_0'),
                ('aten::aminmax', 'min', 'output_1'),
                ('aten::aminmax', 'max', 'output_0'),
                ('aten::aminmax', 'max', 'output_1')
            ],
            schema_check.aliasing
        )

    # Tests that SchemaCheckMode wraps torch.Tensor
    def test_schema_check_mode_functionality(self):
        x = torch.rand((3, 3), requires_grad=True)
        expected = x.relu().sin()
        with SchemaCheckMode():
            actual = x.relu().sin()
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor when an argument's default is overridden
    def test_schema_check_mode_functionality_default_replaced(self):
        x = torch.rand((3, 3), requires_grad=True)
        expected = x.add(x, alpha=2)
        with SchemaCheckMode():
            actual = x.add(x, alpha=2)
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor when there is a Tensor[] argument
    def test_schema_check_mode_functionality_list_input(self):
        a = torch.rand((3, 3))
        b = torch.rand((3, 3))
        c = torch.rand((3, 3))
        expected = torch.linalg.multi_dot([a, b, c])
        with SchemaCheckMode():
            actual = torch.linalg.multi_dot([a, b, c])
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor with an op that has the (a -> *) notation
    def test_schema_check_mode_functionality_wildcard_after(self):
        x = torch.rand((3, 3))
        expected = x.chunk(6)
        with SchemaCheckMode():
            actual = x.chunk(6)
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor when there is a kwarg tensor input
    @unittest.skipIf(not torch._C.has_spectral, "ATen not built with FFT.")
    def test_schema_check_mode_functionality_kwarg_tensor(self):
        x = torch.rand((3, 5))
        w = torch.rand(4)
        expected = torch.stft(x, 4, win_length=4, window=w, return_complex=True)
        with SchemaCheckMode():
            actual = torch.stft(x, 4, win_length=4, window=w, return_complex=True)
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor with a mutable op
    def test_schema_check_mode_functionality_mutable_inputs(self):
        expected = torch.rand((3, 3), requires_grad=False)
        actual = torch.clone(expected)
        expected.sinh_()
        with SchemaCheckMode():
            actual.sinh_()
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor when inputs alias
    def test_schema_check_mode_functionality_aliasing_inputs(self):
        expected = torch.rand((3, 3))
        x = expected
        actual = torch.clone(expected)
        y = actual
        expected.add_(x)
        with SchemaCheckMode():
            actual.add_(y)
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor with multiple tensor outputs
    def test_schema_check_mode_functionality_with_multiple_outputs(self):
        x = torch.arange(9.)
        m_expected, e_expected = torch.frexp(x)
        m_actual = torch.arange(9.)
        e_actual = torch.zeros([9], dtype=torch.int32)
        with SchemaCheckMode():
            torch.frexp(x, out=(m_actual, e_actual))
        self.assertEqual(m_expected, m_actual)
        self.assertEqual(e_expected, e_actual)

    # Tests that SchemaCheckMode wraps torch.Tensor with aliasing outputs due to aliasing inputs
    def test_schema_check_mode_functionality_with_multiple_outputs_aliasing(self):
        x = torch.rand((3, 3))
        actual = torch.zeros(3)
        with SchemaCheckMode():
            torch.aminmax(x, dim=0, out=[actual, actual])
        self.assertEqual(torch.amax(x, dim=0), actual)

    # Tests that SchemaCheckMode wraps torch.Tensor in ops with real Device input
    def test_schema_check_mode_functionality_device_input(self):
        with SchemaCheckMode():
            x = torch.rand((3, 3), device="cpu", dtype=torch.double)
            y = x + x
        self.assertEqual(x + x, y)

    # Tests that SchemaCheckMode wraps torch.Tensor in special training op edge case
    def test_schema_check_mode_functionality_training_op(self):
        x = torch.rand((3, 3), requires_grad=True)
        batch = torch.nn.BatchNorm1d(3, track_running_stats=True)
        expected = batch(x)
        with SchemaCheckMode():
            actual = batch(x)
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor with nested training op edge case
    def test_schema_check_mode_functionality_nested_training_op(self):
        actual = torch.rand((3, 3))
        batch = torch.nn.BatchNorm1d(3, track_running_stats=True)
        expected = torch.clone(actual)
        expected.sinh_()
        expected.tanh_()
        expected.relu_()
        expected = batch(expected)

        with SchemaCheckMode():
            actual.sinh_()
            actual.tanh_()
            actual.relu_()
            actual = batch(actual)
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor with empty list input
    def test_schema_check_mode_empty_list_input(self):
        expected = torch.atleast_1d([])
        with SchemaCheckMode():
            actual = torch.atleast_1d([])
        self.assertEqual(expected, actual)

    # Tests that an exception is raised for a mismatching mutation
    def test_mutation_check_fail(self):
        with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"):
            x = torch.rand((3, 3))
            y = torch.rand((3, 3))
            with SchemaCheckMode():
                IncorrectAliasTensor(x).sub(IncorrectAliasTensor(y))

    # Tests that an exception is raised for a mismatching mutation over multiple ops
    def test_mutation_check_fail_multiple_operators(self):
        with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"):
            x = torch.rand((3, 3))
            y = torch.rand((3, 3))
            with SchemaCheckMode():
                IncorrectAliasTensor(x).sin().cos().sub(IncorrectAliasTensor(y))

    # Tests that an exception is raised for a mismatching alias
    def test_alias_check_fail_simple(self):
        with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
            x = torch.rand((3, 3), requires_grad=True)
            y = torch.rand((3, 3))
            with SchemaCheckMode():
                IncorrectAliasTensor(x).add(IncorrectAliasTensor(y), alpha=2)

    # Tests that an exception is raised for a mismatching alias over multiple ops
    def test_alias_check_fail_multiple_operators(self):
        with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
            x = torch.rand((3, 3), requires_grad=True)
            y = torch.zeros((3, 3), requires_grad=True)
            with SchemaCheckMode():
                IncorrectAliasTensor(x).sin().relu().add(IncorrectAliasTensor(y), alpha=2)

    # Tests that an exception is raised for a centered mismatching alias over multiple ops
    def test_alias_check_fail_multiple_operators_centered(self):
        with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
            x = torch.rand((3, 3), requires_grad=True)
            y = torch.zeros((3, 3), requires_grad=True)
            with SchemaCheckMode():
                IncorrectAliasTensor(x).sin().add(IncorrectAliasTensor(y), alpha=2).relu()

    # Tests that an exception is raised when two outputs of an op unexpectedly
    # alias each other
    def test_alias_check_fail_outputs_unexpectedly_aliasing(self):
        with self.assertRaisesRegex(RuntimeError, "Outputs 0 and 1 alias unexpectedly"):
            x = torch.rand((3, 3))
            with SchemaCheckMode():
                IncorrectAliasTensor(x).aminmax(dim=0)

    # When this file was written, python op registration didn't exist.
    # It's probably worth re-writing the entire file to use it,
    # but instead I just added extra tests.
    def test_alias_check_fail_custom_ops_secretly_aliasing(self):
        def f(x):
            return torch.ops.bad_schemas.secretly_aliasing(x)

        x = torch.rand((3, 3))
        with self.assertRaisesRegex(RuntimeError, "not defined to alias output but was aliasing"):
            with SchemaCheckMode():
                f(x)

    def test_alias_check_fail_custom_ops_secretly_mutating(self):
        def f(x):
            return torch.ops.bad_schemas.secretly_mutating(x)

        x = torch.rand((3, 3))
        with self.assertRaisesRegex(RuntimeError, "not defined as mutable but was mutated"):
            with SchemaCheckMode():
                f(x)

    def test_alias_check_fail_custom_ops_output_is_input(self):
        def f(x):
            return torch.ops.bad_schemas.output_is_input(x)

        x = torch.rand((3, 3))
        with self.assertRaisesRegex(RuntimeError, "are not allowed to directly return inputs"):
            with SchemaCheckMode():
                f(x)

    # Tests that is_alias_of returns as expected
    def test_is_alias_of_basic(self):
        x = torch.rand((3, 3), requires_grad=True)
        y = x.add(x, alpha=2)
        self.assertTrue(torch._C._is_alias_of(x, x))
        self.assertFalse(torch._C._is_alias_of(x, y))

    # Tests that is_alias_of returns as expected with empty containers
    def test_is_alias_of_empty_container(self):
        x = []
        y = torch.rand((3, 3), requires_grad=True)
        self.assertFalse(torch._C._is_alias_of(x, x))
        self.assertFalse(torch._C._is_alias_of(x, y))

    # Tests that overlaps returns as expected
    def test_overlaps_basic(self):
        x = torch.rand((3, 3), requires_grad=True)
        y = torch.rand((3, 3), requires_grad=True)
        z = [x, y]
        self.assertTrue(torch._C._overlaps(x, x))
        self.assertFalse(torch._C._overlaps(x, y))
        self.assertTrue(torch._C._overlaps(z, x))
        self.assertTrue(torch._C._overlaps(z, y))

    # Tests that overlaps returns correctly with empty containers
    def test_overlaps_empty_container(self):
        x = []
        y = [torch.rand((3, 3), requires_grad=True)]
        # Empty containers return false
        self.assertFalse(torch._C._overlaps(y, x))
        self.assertTrue(torch._C._overlaps(y, y))

    # Tests that SchemaInfo Bindings work as expected
    def test_schema_info_bind_basic(self):
        class SchemaInfoBindTestMode(TorchDispatchMode):
            def __init__(self, test_self):
                # Initialize TorchDispatchMode's own state before adding ours.
                super().__init__()
                self.test_self = test_self

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                named_arg_list = normalize_function(
                    func,
                    args,
                    kwargs,
                    normalize_to_only_use_kwargs=True
                ).kwargs
                # One SchemaInfo is fed values one at a time, the other all at
                # once; both must agree before and after binding.
                schema_info_value_test = torch._C._SchemaInfo(func._schema)
                schema_info_values_test = torch._C._SchemaInfo(func._schema)
                self.test_self.assertFalse(schema_info_value_test.may_alias(
                    torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0),
                    torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1)))
                self.test_self.assertFalse(schema_info_values_test.may_alias(
                    torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0),
                    torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1)))
                for i in named_arg_list:
                    schema_info_value_test.add_argument_value(i, named_arg_list[i])
                schema_info_values_test.add_argument_values(named_arg_list)
                self.test_self.assertTrue(schema_info_value_test.may_alias(
                    torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0),
                    torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1)))
                self.test_self.assertTrue(schema_info_values_test.may_alias(
                    torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0),
                    torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1)))

                return func(*args, **kwargs)

        x = torch.rand((3, 3))
        with SchemaInfoBindTestMode(self):
            x.add(x)


class TestSchemaCheckModeOpInfo(JitTestCase):
    @ops(op_db, dtypes=OpDTypes.supported)
    def test_schema_correctness(self, device, dtype, op):
        # Currently torch.equal isn't supported with torch.complex32.
        # There are also errors with complex64 and complex128.
        # NOTE(review): only complex32 is skipped here — confirm whether the
        # other complex dtypes still fail.
        if dtype == torch.complex32:
            # Report as skipped rather than silently passing.
            self.skipTest("torch.equal isn't supported with torch.complex32")
        for sample in op.sample_inputs(device, dtype, requires_grad=False):
            with SchemaCheckMode():
                op(sample.input, *sample.args, **sample.kwargs)


instantiate_device_type_tests(TestSchemaCheckModeOpInfo, globals(), only_for=("cpu", "cuda"))

if __name__ == '__main__':
    run_tests()