1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport os 4*da0073e9SAndroid Build Coastguard Workerimport sys 5*da0073e9SAndroid Build Coastguard Workerimport torch 6*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._pytree import tree_map 7*da0073e9SAndroid Build Coastguard Workerimport unittest 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, TEST_WITH_TORCHDYNAMO 10*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.operator_schemas import normalize_function 11*da0073e9SAndroid Build Coastguard Workerfrom torch._subclasses.schema_check_mode import SchemaCheckMode 12*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._python_dispatch import TorchDispatchMode 13*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_methods_invocations import op_db 14*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase 15*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import ops, OpDTypes, instantiate_device_type_tests 16*da0073e9SAndroid Build Coastguard Workerpytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 17*da0073e9SAndroid Build Coastguard Workersys.path.append(pytorch_test_dir) 18*da0073e9SAndroid Build Coastguard Worker 19*da0073e9SAndroid Build Coastguard Workerdef secretly_aliasing(x): 20*da0073e9SAndroid Build Coastguard Worker return x.view(-1) 21*da0073e9SAndroid Build Coastguard Worker 22*da0073e9SAndroid Build Coastguard Workerdef secretly_mutating(x): 23*da0073e9SAndroid Build Coastguard Worker x.mul_(2) 24*da0073e9SAndroid Build Coastguard Worker return x * 3 25*da0073e9SAndroid Build Coastguard Worker 26*da0073e9SAndroid Build Coastguard Workerdef output_is_input(x): 27*da0073e9SAndroid Build Coastguard Worker return x 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard Workercustom_lib = torch.library.Library("bad_schemas", "DEF") # noqa: TOR901 30*da0073e9SAndroid Build Coastguard Workercustom_lib.define("secretly_aliasing(Tensor x) -> Tensor") 31*da0073e9SAndroid Build Coastguard Workercustom_lib.define("secretly_mutating(Tensor x) -> Tensor") 32*da0073e9SAndroid Build Coastguard Workercustom_lib.define("output_is_input(Tensor(a) x) -> Tensor(a)") 33*da0073e9SAndroid Build Coastguard Worker 34*da0073e9SAndroid Build Coastguard Workercustom_lib_cpu = torch.library.Library("bad_schemas", "IMPL", "CPU") # noqa: TOR901 35*da0073e9SAndroid Build Coastguard Workercustom_lib_cpu.impl("secretly_aliasing", secretly_aliasing) 36*da0073e9SAndroid Build Coastguard Workercustom_lib_cpu.impl("secretly_mutating", secretly_mutating) 37*da0073e9SAndroid Build Coastguard Workercustom_lib_cpu.impl("output_is_input", output_is_input) 38*da0073e9SAndroid Build Coastguard Worker 39*da0073e9SAndroid Build Coastguard Workercustom_lib_meta = torch.library.Library("bad_schemas", "IMPL", "Meta") # noqa: TOR901 40*da0073e9SAndroid Build Coastguard Workercustom_lib_meta.impl("secretly_aliasing", secretly_aliasing) 41*da0073e9SAndroid Build Coastguard Workercustom_lib_meta.impl("secretly_mutating", secretly_mutating) 42*da0073e9SAndroid Build Coastguard Workercustom_lib_meta.impl("output_is_input", output_is_input) 43*da0073e9SAndroid Build Coastguard Worker 44*da0073e9SAndroid Build Coastguard Worker# This TorchDispatchTensor Subclass is used to simulate an incorrect schema 45*da0073e9SAndroid Build Coastguard Worker# which is then used to test that SchemaCheckMode behaves as expected 46*da0073e9SAndroid Build Coastguard Worker 47*da0073e9SAndroid Build Coastguard Workerclass IncorrectAliasTensor(torch.Tensor): 48*da0073e9SAndroid Build Coastguard Worker ALIAS_ARG_OUT = {"aten::add"} 49*da0073e9SAndroid Build Coastguard Worker ALIAS_OUT_OUT = {"aten::aminmax"} 50*da0073e9SAndroid Build Coastguard Worker MUTATE_ARGS_OUT = {"aten::sub"} 51*da0073e9SAndroid Build Coastguard Worker 52*da0073e9SAndroid Build Coastguard Worker elem: torch.Tensor 53*da0073e9SAndroid Build Coastguard Worker 54*da0073e9SAndroid Build Coastguard Worker __slots__ = ['elem'] 55*da0073e9SAndroid Build Coastguard Worker 56*da0073e9SAndroid Build Coastguard Worker @staticmethod 57*da0073e9SAndroid Build Coastguard Worker def __new__(cls, elem, *args, **kwargs): 58*da0073e9SAndroid Build Coastguard Worker # The wrapping tensor (IncorrectAliasTensor) shouldn't hold any 59*da0073e9SAndroid Build Coastguard Worker # memory for the class in question, but it should still 60*da0073e9SAndroid Build Coastguard Worker # advertise the same device as before 61*da0073e9SAndroid Build Coastguard Worker r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] 62*da0073e9SAndroid Build Coastguard Worker cls, elem.size(), 63*da0073e9SAndroid Build Coastguard Worker strides=elem.stride(), storage_offset=elem.storage_offset(), 64*da0073e9SAndroid Build Coastguard Worker # TODO: clone storage aliasing 65*da0073e9SAndroid Build Coastguard Worker dtype=elem.dtype, layout=elem.layout, 66*da0073e9SAndroid Build Coastguard Worker device=elem.device, requires_grad=kwargs.get("requires_grad", False) 67*da0073e9SAndroid Build Coastguard Worker ) 68*da0073e9SAndroid Build Coastguard Worker # ...the real tensor is held as an element on the tensor. 69*da0073e9SAndroid Build Coastguard Worker r.elem = elem.detach() if r.requires_grad else elem 70*da0073e9SAndroid Build Coastguard Worker return r 71*da0073e9SAndroid Build Coastguard Worker 72*da0073e9SAndroid Build Coastguard Worker def __repr__(self): 73*da0073e9SAndroid Build Coastguard Worker return super().__repr__(tensor_contents=f"{self.elem}") 74*da0073e9SAndroid Build Coastguard Worker 75*da0073e9SAndroid Build Coastguard Worker @classmethod 76*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 77*da0073e9SAndroid Build Coastguard Worker def unwrap(e): 78*da0073e9SAndroid Build Coastguard Worker return e.elem if isinstance(e, cls) else e 79*da0073e9SAndroid Build Coastguard Worker 80*da0073e9SAndroid Build Coastguard Worker def wrap(e): 81*da0073e9SAndroid Build Coastguard Worker return cls(e) if isinstance(e, torch.Tensor) else e 82*da0073e9SAndroid Build Coastguard Worker unwrapped_args = tree_map(unwrap, args) 83*da0073e9SAndroid Build Coastguard Worker out = func(*unwrapped_args, **tree_map(unwrap, kwargs)) 84*da0073e9SAndroid Build Coastguard Worker if func._schema.name in IncorrectAliasTensor.ALIAS_ARG_OUT: 85*da0073e9SAndroid Build Coastguard Worker args[0].elem = out 86*da0073e9SAndroid Build Coastguard Worker if func._schema.name in IncorrectAliasTensor.MUTATE_ARGS_OUT: 87*da0073e9SAndroid Build Coastguard Worker args[0].elem = torch.rand(args[0].elem.shape) 88*da0073e9SAndroid Build Coastguard Worker if func._schema.name in IncorrectAliasTensor.ALIAS_OUT_OUT: 89*da0073e9SAndroid Build Coastguard Worker incorrect_out = list(out) 90*da0073e9SAndroid Build Coastguard Worker incorrect_out[0] = incorrect_out[1] 91*da0073e9SAndroid Build Coastguard Worker return tree_map(wrap, tuple(incorrect_out)) 92*da0073e9SAndroid Build Coastguard Worker 93*da0073e9SAndroid Build Coastguard Worker return tree_map(wrap, out) 94*da0073e9SAndroid Build Coastguard Worker 95*da0073e9SAndroid Build Coastguard Worker# Tests various schema checking functionalities. 96*da0073e9SAndroid Build Coastguard Workerclass TestSchemaCheck(JitTestCase): 97*da0073e9SAndroid Build Coastguard Worker def setUp(self): 98*da0073e9SAndroid Build Coastguard Worker if TEST_WITH_TORCHDYNAMO: 99*da0073e9SAndroid Build Coastguard Worker self.skipTest("SchemaCheckMode is ignored by dynamo") 100*da0073e9SAndroid Build Coastguard Worker super().setUp() 101*da0073e9SAndroid Build Coastguard Worker 102*da0073e9SAndroid Build Coastguard Worker # Tests that SchemaCheckMode records operator order with grad 103*da0073e9SAndroid Build Coastguard Worker def test_schema_check_mode_operator_order(self): 104*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode() as schema_check: 105*da0073e9SAndroid Build Coastguard Worker x = torch.rand((3, 3), requires_grad=True) 106*da0073e9SAndroid Build Coastguard Worker x.relu().sin() 107*da0073e9SAndroid Build Coastguard Worker self.assertEqual(["aten::rand", "aten::relu", "aten::detach", "aten::sin"], schema_check.ops) 108*da0073e9SAndroid Build Coastguard Worker 109*da0073e9SAndroid Build Coastguard Worker # Tests that SchemaCheckMode records operator order without grad 110*da0073e9SAndroid Build Coastguard Worker def test_schema_check_mode_operator_order_without_grad(self): 111*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode() as schema_check: 112*da0073e9SAndroid Build Coastguard Worker x = torch.rand((3, 3), requires_grad=False) 113*da0073e9SAndroid Build Coastguard Worker x.relu().sin() 114*da0073e9SAndroid Build Coastguard Worker self.assertEqual(["aten::rand", "aten::relu", "aten::sin"], schema_check.ops) 115*da0073e9SAndroid Build Coastguard Worker 116*da0073e9SAndroid Build Coastguard Worker # Tests that SchemaCheckMode records mutations and aliases with none expected 117*da0073e9SAndroid Build Coastguard Worker def test_schema_check_mode_mutated_aliasing_none(self): 118*da0073e9SAndroid Build Coastguard Worker # NB: previously requires_grad=True, but this induces a detach for 119*da0073e9SAndroid Build Coastguard Worker # saved variable 120*da0073e9SAndroid Build Coastguard Worker x = torch.rand((3, 3)) 121*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode() as schema_check: 122*da0073e9SAndroid Build Coastguard Worker actual = x.relu().sin() 123*da0073e9SAndroid Build Coastguard Worker self.assertEqual([], schema_check.mutated) 124*da0073e9SAndroid Build Coastguard Worker self.assertEqual([], schema_check.aliasing) 125*da0073e9SAndroid Build Coastguard Worker 126*da0073e9SAndroid Build Coastguard Worker # Tests that SchemaCheckMode records mutations and aliases with mutation expected 127*da0073e9SAndroid Build Coastguard Worker def test_schema_check_mode_mutated_aliasing_mutation(self): 128*da0073e9SAndroid Build Coastguard Worker actual = torch.rand((3, 3), requires_grad=False) 129*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode() as schema_check: 130*da0073e9SAndroid Build Coastguard Worker actual.sinh_() 131*da0073e9SAndroid Build Coastguard Worker self.assertEqual([('aten::sinh_', 'input')], schema_check.mutated) 132*da0073e9SAndroid Build Coastguard Worker self.assertEqual([('aten::sinh_', 'input', 'output_0')], schema_check.aliasing) 133*da0073e9SAndroid Build Coastguard Worker 134*da0073e9SAndroid Build Coastguard Worker # Tests that SchemaCheckMode records mutations and aliases with resize_ 135*da0073e9SAndroid Build Coastguard Worker def test_schema_check_mode_mutated_aliasing_resize_(self): 136*da0073e9SAndroid Build Coastguard Worker actual = torch.rand((3, 3), requires_grad=False) 137*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode() as schema_check: 138*da0073e9SAndroid Build Coastguard Worker actual.resize_(9) 139*da0073e9SAndroid Build Coastguard Worker self.assertEqual([('aten::resize_', 'input')], schema_check.mutated) 140*da0073e9SAndroid Build Coastguard Worker self.assertEqual([('aten::resize_', 'input', 'output_0')], schema_check.aliasing) 141*da0073e9SAndroid Build Coastguard Worker 142*da0073e9SAndroid Build Coastguard Worker # Tests that SchemaCheckMode records mutations and aliases with aliasing inputs 143*da0073e9SAndroid Build Coastguard Worker def test_schema_check_mode_mutated_aliasing_aliasing_inputs(self): 144*da0073e9SAndroid Build Coastguard Worker actual = torch.rand((3, 3)) 145*da0073e9SAndroid Build Coastguard Worker y = actual 146*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode() as schema_check: 147*da0073e9SAndroid Build Coastguard Worker actual.add_(y) 148*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 149*da0073e9SAndroid Build Coastguard Worker [ 150*da0073e9SAndroid Build Coastguard Worker ('aten::add_', 'input'), 151*da0073e9SAndroid Build Coastguard Worker ('aten::add_', 'other') 152*da0073e9SAndroid Build Coastguard Worker ], 153*da0073e9SAndroid Build Coastguard Worker schema_check.mutated 154*da0073e9SAndroid Build Coastguard Worker ) 155*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 156*da0073e9SAndroid Build Coastguard Worker [ 157*da0073e9SAndroid Build Coastguard Worker ('aten::add_', 'input', 'output_0'), 158*da0073e9SAndroid Build Coastguard Worker ('aten::add_', 'other', 'output_0') 159*da0073e9SAndroid Build Coastguard Worker ], 160*da0073e9SAndroid Build Coastguard Worker schema_check.aliasing 161*da0073e9SAndroid Build Coastguard Worker ) 162*da0073e9SAndroid Build Coastguard Worker 163*da0073e9SAndroid Build Coastguard Worker # Tests that SchemaCheckMode records mutations and alias with as_strided 164*da0073e9SAndroid Build Coastguard Worker def test_schema_check_mode_mutated_aliasing_as_strided(self): 165*da0073e9SAndroid Build Coastguard Worker x = torch.rand((3, 6, 4)) 166*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode() as schema_check: 167*da0073e9SAndroid Build Coastguard Worker x.as_strided_([3, 6, 4], [9, 1, 1]) 168*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 169*da0073e9SAndroid Build Coastguard Worker [ 170*da0073e9SAndroid Build Coastguard Worker ('aten::as_strided_', 'input') 171*da0073e9SAndroid Build Coastguard Worker ], 172*da0073e9SAndroid Build Coastguard Worker schema_check.mutated 173*da0073e9SAndroid Build Coastguard Worker ) 174*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 175*da0073e9SAndroid Build Coastguard Worker [ 176*da0073e9SAndroid Build Coastguard Worker ('aten::as_strided_', 'input', 'output_0') 177*da0073e9SAndroid Build Coastguard Worker ], 178*da0073e9SAndroid Build Coastguard Worker schema_check.aliasing 179*da0073e9SAndroid Build Coastguard Worker ) 180*da0073e9SAndroid Build Coastguard Worker 181*da0073e9SAndroid Build Coastguard Worker # Tests that SchemaCheckMode records mutations and aliases with multiple outputs 182*da0073e9SAndroid Build Coastguard Worker def test_schema_check_mode_mutated_aliasing_multiple_outputs(self): 183*da0073e9SAndroid Build Coastguard Worker x = torch.arange(9.) 184*da0073e9SAndroid Build Coastguard Worker m_actual = torch.arange(9.) 185*da0073e9SAndroid Build Coastguard Worker e_actual = torch.zeros([9], dtype=torch.int32) 186*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode() as schema_check: 187*da0073e9SAndroid Build Coastguard Worker torch.frexp(x, out=(m_actual, e_actual)) 188*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 189*da0073e9SAndroid Build Coastguard Worker [ 190*da0073e9SAndroid Build Coastguard Worker ('aten::frexp', 'mantissa'), 191*da0073e9SAndroid Build Coastguard Worker ('aten::frexp', 'exponent') 192*da0073e9SAndroid Build Coastguard Worker ], 193*da0073e9SAndroid Build Coastguard Worker schema_check.mutated 194*da0073e9SAndroid Build Coastguard Worker ) 195*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 196*da0073e9SAndroid Build Coastguard Worker [ 197*da0073e9SAndroid Build Coastguard Worker ('aten::frexp', 'mantissa', 'output_0'), 198*da0073e9SAndroid Build Coastguard Worker ('aten::frexp', 'exponent', 'output_1') 199*da0073e9SAndroid Build Coastguard Worker ], 200*da0073e9SAndroid Build Coastguard Worker schema_check.aliasing 201*da0073e9SAndroid Build Coastguard Worker ) 202*da0073e9SAndroid Build Coastguard Worker 203*da0073e9SAndroid Build Coastguard Worker # Tests that SchemaCheckMode records mutations and aliases with aliasing outputs 204*da0073e9SAndroid Build Coastguard Worker def test_schema_check_mode_mutated_aliasing_aliasing_outputs(self): 205*da0073e9SAndroid Build Coastguard Worker x = torch.rand((3, 3)) 206*da0073e9SAndroid Build Coastguard Worker actual = torch.zeros(3) 207*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode() as schema_check: 208*da0073e9SAndroid Build Coastguard Worker torch.aminmax(x, dim=0, out=[actual, actual]) 209*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 210*da0073e9SAndroid Build Coastguard Worker [ 211*da0073e9SAndroid Build Coastguard Worker ('aten::aminmax', 'min'), 212*da0073e9SAndroid Build Coastguard Worker ('aten::aminmax', 'max') 213*da0073e9SAndroid Build Coastguard Worker ], 214*da0073e9SAndroid Build Coastguard Worker schema_check.mutated 215*da0073e9SAndroid Build Coastguard Worker ) 216*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 217*da0073e9SAndroid Build Coastguard Worker [ 218*da0073e9SAndroid Build Coastguard Worker ('aten::aminmax', 'min', 'output_0'), 219*da0073e9SAndroid Build Coastguard Worker ('aten::aminmax', 'min', 'output_1'), 220*da0073e9SAndroid Build Coastguard Worker ('aten::aminmax', 'max', 'output_0'), 221*da0073e9SAndroid Build Coastguard Worker ('aten::aminmax', 'max', 'output_1') 222*da0073e9SAndroid Build Coastguard Worker ], 223*da0073e9SAndroid Build Coastguard Worker schema_check.aliasing 224*da0073e9SAndroid Build Coastguard Worker ) 225*da0073e9SAndroid Build Coastguard Worker 226*da0073e9SAndroid Build Coastguard Worker # Tests that SchemaCheckMode wraps torch.Tensor 227*da0073e9SAndroid Build Coastguard Worker def test_schema_check_mode_functionality(self): 228*da0073e9SAndroid Build Coastguard Worker x = torch.rand((3, 3), requires_grad=True) 229*da0073e9SAndroid Build Coastguard Worker expected = x.relu().sin() 230*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode(): 231*da0073e9SAndroid Build Coastguard Worker actual = x.relu().sin() 232*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 233*da0073e9SAndroid Build Coastguard Worker 234*da0073e9SAndroid Build Coastguard Worker # Tests that SchemaCheckMode wraps torch.Tensor when an argument's default is overriden 235*da0073e9SAndroid Build Coastguard Worker def test_schema_check_mode_functionality_default_replaced(self): 236*da0073e9SAndroid Build Coastguard Worker x = torch.rand((3, 3), requires_grad=True) 237*da0073e9SAndroid Build Coastguard Worker expected = x.add(x, alpha=2) 238*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode(): 239*da0073e9SAndroid Build Coastguard Worker actual = x.add(x, alpha=2) 240*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 241*da0073e9SAndroid Build Coastguard Worker 242*da0073e9SAndroid Build Coastguard Worker # Tests that SchemaCheckMode wraps torch.Tensor when there is a Tensor[] argument 243*da0073e9SAndroid Build Coastguard Worker def test_schema_check_mode_functionality_list_input(self): 244*da0073e9SAndroid Build Coastguard Worker a = torch.rand((3, 3)) 245*da0073e9SAndroid Build Coastguard Worker b = torch.rand((3, 3)) 246*da0073e9SAndroid Build Coastguard Worker c = torch.rand((3, 3)) 247*da0073e9SAndroid Build Coastguard Worker expected = torch.linalg.multi_dot([a, b, c]) 248*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode(): 249*da0073e9SAndroid Build Coastguard Worker actual = torch.linalg.multi_dot([a, b, c]) 250*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 251*da0073e9SAndroid Build Coastguard Worker 252*da0073e9SAndroid Build Coastguard Worker # Tests that SchemaCheckMode wraps torch.Tensor with an op that has the (a -> *) notation 253*da0073e9SAndroid Build Coastguard Worker def test_schema_check_mode_functionality_wildcard_after(self): 254*da0073e9SAndroid Build Coastguard Worker x = torch.rand((3, 3)) 255*da0073e9SAndroid Build Coastguard Worker expected = x.chunk(6) 256*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode(): 257*da0073e9SAndroid Build Coastguard Worker actual = x.chunk(6) 258*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 259*da0073e9SAndroid Build Coastguard Worker 260*da0073e9SAndroid Build Coastguard Worker # Tests that SchemaCheckMode wraps torch.Tensor when there is a kwarg tensor input 261*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not torch._C.has_spectral, "ATen not built with FFT.") 262*da0073e9SAndroid Build Coastguard Worker def test_schema_check_mode_functionality_kwarg_tensor(self): 263*da0073e9SAndroid Build Coastguard Worker x = torch.rand((3, 5)) 264*da0073e9SAndroid Build Coastguard Worker w = torch.rand(4) 265*da0073e9SAndroid Build Coastguard Worker expected = torch.stft(x, 4, win_length=4, window=w, return_complex=True) 266*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode(): 267*da0073e9SAndroid Build Coastguard Worker actual = torch.stft(x, 4, win_length=4, window=w, return_complex=True) 268*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 269*da0073e9SAndroid Build Coastguard Worker 270*da0073e9SAndroid Build Coastguard Worker # Tests that SchemaCheckMode wraps torch.Tensor with a mutable op 271*da0073e9SAndroid Build Coastguard Worker def test_schema_check_mode_functionality_mutable_inputs(self): 272*da0073e9SAndroid Build Coastguard Worker expected = torch.rand((3, 3), requires_grad=False) 273*da0073e9SAndroid Build Coastguard Worker actual = torch.clone(expected) 274*da0073e9SAndroid Build Coastguard Worker expected.sinh_() 275*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode(): 276*da0073e9SAndroid Build Coastguard Worker actual.sinh_() 277*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 278*da0073e9SAndroid Build Coastguard Worker 279*da0073e9SAndroid Build Coastguard Worker # Tests that SchemaCheckMode wraps Torch.tensor when inputs alias 280*da0073e9SAndroid Build Coastguard Worker def test_schema_check_mode_functionality_aliasing_inputs(self): 281*da0073e9SAndroid Build Coastguard Worker expected = torch.rand((3, 3)) 282*da0073e9SAndroid Build Coastguard Worker x = expected 283*da0073e9SAndroid Build Coastguard Worker actual = torch.clone(expected) 284*da0073e9SAndroid Build Coastguard Worker y = actual 285*da0073e9SAndroid Build Coastguard Worker expected.add_(x) 286*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode(): 287*da0073e9SAndroid Build Coastguard Worker actual.add_(y) 288*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 289*da0073e9SAndroid Build Coastguard Worker 290*da0073e9SAndroid Build Coastguard Worker # Tests that SchemaCheckMode wraps Torch.tensor with multiple tensor outputs 291*da0073e9SAndroid Build Coastguard Worker def test_schema_check_mode_functionality_with_multiple_outputs(self): 292*da0073e9SAndroid Build Coastguard Worker x = torch.arange(9.) 293*da0073e9SAndroid Build Coastguard Worker m_expected, e_expected = torch.frexp(x) 294*da0073e9SAndroid Build Coastguard Worker m_actual = torch.arange(9.) 295*da0073e9SAndroid Build Coastguard Worker e_actual = torch.zeros([9], dtype=torch.int32) 296*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode(): 297*da0073e9SAndroid Build Coastguard Worker torch.frexp(x, out=(m_actual, e_actual)) 298*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m_expected, m_actual) 299*da0073e9SAndroid Build Coastguard Worker self.assertEqual(e_expected, e_actual) 300*da0073e9SAndroid Build Coastguard Worker 301*da0073e9SAndroid Build Coastguard Worker # Tests that SchemaCheckMode wraps Torch.tensor with aliasing outputs due to aliasing inputs 302*da0073e9SAndroid Build Coastguard Worker def test_schema_check_mode_functionality_with_multiple_outputs_aliasing(self): 303*da0073e9SAndroid Build Coastguard Worker x = torch.rand((3, 3)) 304*da0073e9SAndroid Build Coastguard Worker actual = torch.zeros(3) 305*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode(): 306*da0073e9SAndroid Build Coastguard Worker torch.aminmax(x, dim=0, out=[actual, actual]) 307*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.amax(x, dim=0), actual) 308*da0073e9SAndroid Build Coastguard Worker 309*da0073e9SAndroid Build Coastguard Worker # Tests that SchemaCheckMode wraps Torch.tensor in ops with real Device input 310*da0073e9SAndroid Build Coastguard Worker def test_schema_check_mode_functionality_device_input(self): 311*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode(): 312*da0073e9SAndroid Build Coastguard Worker x = torch.rand((3, 3), device="cpu", dtype=torch.double) 313*da0073e9SAndroid Build Coastguard Worker y = x + x 314*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x + x, y) 315*da0073e9SAndroid Build Coastguard Worker 316*da0073e9SAndroid Build Coastguard Worker # Tests that SchemaCheckMode wraps Torch.tensor in special training op edge case 317*da0073e9SAndroid Build Coastguard Worker def test_schema_check_mode_functionality_training_op(self): 318*da0073e9SAndroid Build Coastguard Worker x = torch.rand((3, 3), requires_grad=True) 319*da0073e9SAndroid Build Coastguard Worker batch = torch.nn.BatchNorm1d(3, track_running_stats=True) 320*da0073e9SAndroid Build Coastguard Worker expected = batch(x) 321*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode(): 322*da0073e9SAndroid Build Coastguard Worker actual = batch(x) 323*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 324*da0073e9SAndroid Build Coastguard Worker 325*da0073e9SAndroid Build Coastguard Worker # Tests that SchemaCheckMode wraps Torch.tensor with nested training op edge case 326*da0073e9SAndroid Build Coastguard Worker def test_schema_check_mode_functionality_nested_training_op(self): 327*da0073e9SAndroid Build Coastguard Worker actual = torch.rand((3, 3)) 328*da0073e9SAndroid Build Coastguard Worker batch = torch.nn.BatchNorm1d(3, track_running_stats=True) 329*da0073e9SAndroid Build Coastguard Worker expected = torch.clone(actual) 330*da0073e9SAndroid Build Coastguard Worker expected.sinh_() 331*da0073e9SAndroid Build Coastguard Worker expected.tanh_() 332*da0073e9SAndroid Build Coastguard Worker expected.relu_() 333*da0073e9SAndroid Build Coastguard Worker expected = batch(expected) 334*da0073e9SAndroid Build Coastguard Worker 335*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode(): 336*da0073e9SAndroid Build Coastguard Worker actual.sinh_() 337*da0073e9SAndroid Build Coastguard Worker actual.tanh_() 338*da0073e9SAndroid Build Coastguard Worker actual.relu_() 339*da0073e9SAndroid Build Coastguard Worker actual = batch(actual) 340*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 341*da0073e9SAndroid Build Coastguard Worker 342*da0073e9SAndroid Build Coastguard Worker # Tests that SchemaCheckMode wraps Torch.tensor with empty list input 343*da0073e9SAndroid Build Coastguard Worker def test_schema_check_mode_empty_list_input(self): 344*da0073e9SAndroid Build Coastguard Worker expected = torch.atleast_1d([]) 345*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode(): 346*da0073e9SAndroid Build Coastguard Worker actual = torch.atleast_1d([]) 347*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 348*da0073e9SAndroid Build Coastguard Worker 349*da0073e9SAndroid Build Coastguard Worker # Tests that an exception is raised for a mismatching mutation 350*da0073e9SAndroid Build Coastguard Worker def test_mutation_check_fail(self): 351*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"): 352*da0073e9SAndroid Build Coastguard Worker x = torch.rand((3, 3)) 353*da0073e9SAndroid Build Coastguard Worker y = torch.rand((3, 3)) 354*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode(): 355*da0073e9SAndroid Build Coastguard Worker IncorrectAliasTensor(x).sub(IncorrectAliasTensor(y)) 356*da0073e9SAndroid Build Coastguard Worker 357*da0073e9SAndroid Build Coastguard Worker # # Tests that an exception is raised for a mismatching mutation over multiple ops 358*da0073e9SAndroid Build Coastguard Worker def test_mutation_check_fail_multiple_operators(self): 359*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"): 360*da0073e9SAndroid Build Coastguard Worker x = torch.rand((3, 3)) 361*da0073e9SAndroid Build Coastguard Worker y = torch.rand((3, 3)) 362*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode(): 363*da0073e9SAndroid Build Coastguard Worker IncorrectAliasTensor(x).sin().cos().sub(IncorrectAliasTensor(y)) 364*da0073e9SAndroid Build Coastguard Worker 365*da0073e9SAndroid Build Coastguard Worker # Tests that an exception is raised for a mismatching alias 366*da0073e9SAndroid Build Coastguard Worker def test_alias_check_fail_simple(self): 367*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"): 368*da0073e9SAndroid Build Coastguard Worker x = torch.rand((3, 3), requires_grad=True) 369*da0073e9SAndroid Build Coastguard Worker y = torch.rand((3, 3)) 370*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode(): 371*da0073e9SAndroid Build Coastguard Worker IncorrectAliasTensor(x).add(IncorrectAliasTensor(y), alpha=2) 372*da0073e9SAndroid Build Coastguard Worker 373*da0073e9SAndroid Build Coastguard Worker # Tests that an exception is raised for a mismatching alias over multiple ops 374*da0073e9SAndroid Build Coastguard Worker def test_alias_check_fail_multiple_operators(self): 375*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"): 376*da0073e9SAndroid Build Coastguard Worker x = torch.rand((3, 3), requires_grad=True) 377*da0073e9SAndroid Build Coastguard Worker y = torch.zeros((3, 3), requires_grad=True) 378*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode(): 379*da0073e9SAndroid Build Coastguard Worker IncorrectAliasTensor(x).sin().relu().add(IncorrectAliasTensor(y), alpha=2) 380*da0073e9SAndroid Build Coastguard Worker 381*da0073e9SAndroid Build Coastguard Worker # Tests that an exception is raised for a centered mismatching alias over multiple ops 382*da0073e9SAndroid Build Coastguard Worker def test_alias_check_fail_multiple_operators_centered(self): 383*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"): 384*da0073e9SAndroid Build Coastguard Worker x = torch.rand((3, 3), requires_grad=True) 385*da0073e9SAndroid Build Coastguard Worker y = torch.zeros((3, 3), requires_grad=True) 386*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode(): 387*da0073e9SAndroid Build Coastguard Worker IncorrectAliasTensor(x).sin().add(IncorrectAliasTensor(y), alpha=2).relu() 388*da0073e9SAndroid Build Coastguard Worker 389*da0073e9SAndroid Build Coastguard Worker # Tests that an exception is raised for a centered mismatching alias over multiple ops 390*da0073e9SAndroid Build Coastguard Worker def test_alias_check_fail_outputs_unexpectedly_aliasing(self): 391*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Outputs 0 and 1 alias unexpectedly"): 392*da0073e9SAndroid Build Coastguard Worker x = torch.rand((3, 3)) 393*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode() as s: 394*da0073e9SAndroid Build Coastguard Worker IncorrectAliasTensor(x).aminmax(dim=0) 395*da0073e9SAndroid Build Coastguard Worker 396*da0073e9SAndroid Build Coastguard Worker # When this file was written, python op registration didn't exist. 397*da0073e9SAndroid Build Coastguard Worker # It's probably worth re-writing the entire file to use it, 398*da0073e9SAndroid Build Coastguard Worker # but instead I just added extra tests. 399*da0073e9SAndroid Build Coastguard Worker def test_alias_check_fail_custom_ops_secretly_aliasing(self): 400*da0073e9SAndroid Build Coastguard Worker def f(x): 401*da0073e9SAndroid Build Coastguard Worker return torch.ops.bad_schemas.secretly_aliasing(x) 402*da0073e9SAndroid Build Coastguard Worker 403*da0073e9SAndroid Build Coastguard Worker x = torch.rand((3, 3)) 404*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "not defined to alias output but was aliasing"): 405*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode() as s: 406*da0073e9SAndroid Build Coastguard Worker out = f(x) 407*da0073e9SAndroid Build Coastguard Worker 408*da0073e9SAndroid Build Coastguard Worker def test_alias_check_fail_custom_ops_secretly_mutating(self): 409*da0073e9SAndroid Build Coastguard Worker def f(x): 410*da0073e9SAndroid Build Coastguard Worker return torch.ops.bad_schemas.secretly_mutating(x) 411*da0073e9SAndroid Build Coastguard Worker 412*da0073e9SAndroid Build Coastguard Worker x = torch.rand((3, 3)) 413*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "not defined as mutable but was mutated"): 414*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode() as s: 415*da0073e9SAndroid Build Coastguard Worker out = f(x) 416*da0073e9SAndroid Build Coastguard Worker 417*da0073e9SAndroid Build Coastguard Worker def test_alias_check_fail_custom_ops_output_is_input(self): 418*da0073e9SAndroid Build Coastguard Worker def f(x): 419*da0073e9SAndroid Build Coastguard Worker return torch.ops.bad_schemas.output_is_input(x) 420*da0073e9SAndroid Build Coastguard Worker 421*da0073e9SAndroid Build Coastguard Worker x = torch.rand((3, 3)) 422*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "are not allowed to directly return inputs"): 423*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode() as s: 424*da0073e9SAndroid Build Coastguard Worker out = f(x) 425*da0073e9SAndroid Build Coastguard Worker 426*da0073e9SAndroid Build Coastguard Worker # Tests that is_alias_of returns as expected 427*da0073e9SAndroid Build Coastguard Worker def test_is_alias_of_basic(self): 428*da0073e9SAndroid Build Coastguard Worker x = torch.rand((3, 3), requires_grad=True) 429*da0073e9SAndroid Build Coastguard Worker y = torch.rand((3, 3), requires_grad=True) 430*da0073e9SAndroid Build Coastguard Worker y = x.add(x, alpha=2) 431*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch._C._is_alias_of(x, x)) 432*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch._C._is_alias_of(x, y)) 433*da0073e9SAndroid Build Coastguard Worker 434*da0073e9SAndroid Build Coastguard Worker # Tests that is_alias_of returns as expected with empty containers 435*da0073e9SAndroid Build Coastguard Worker def test_is_alias_of_empty_container(self): 436*da0073e9SAndroid Build Coastguard Worker x = [] 437*da0073e9SAndroid Build Coastguard Worker y = torch.rand((3, 3), requires_grad=True) 438*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch._C._is_alias_of(x, x)) 439*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch._C._is_alias_of(x, y)) 440*da0073e9SAndroid Build Coastguard Worker 441*da0073e9SAndroid Build Coastguard Worker # Tests that overlaps returns as expected 442*da0073e9SAndroid Build Coastguard Worker def test_overlaps_basic(self): 443*da0073e9SAndroid Build Coastguard Worker x = torch.rand((3, 3), requires_grad=True) 444*da0073e9SAndroid Build Coastguard Worker y = torch.rand((3, 3), requires_grad=True) 445*da0073e9SAndroid Build Coastguard Worker z = [x, y] 446*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch._C._overlaps(x, x)) 447*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch._C._overlaps(x, y)) 448*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch._C._overlaps(z, x)) 449*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch._C._overlaps(z, y)) 450*da0073e9SAndroid Build Coastguard Worker 451*da0073e9SAndroid Build Coastguard Worker # Tests that overlaps returns correctly with empty containers 452*da0073e9SAndroid Build Coastguard Worker def test_overlaps_empty_container(self): 453*da0073e9SAndroid Build Coastguard Worker x = [] 454*da0073e9SAndroid Build Coastguard Worker y = [torch.rand((3, 3), requires_grad=True)] 455*da0073e9SAndroid Build Coastguard Worker # Empty containers return false 456*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch._C._overlaps(y, x)) 457*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch._C._overlaps(y, y)) 458*da0073e9SAndroid Build Coastguard Worker 459*da0073e9SAndroid Build Coastguard Worker # Tests that SchemaInfo Bindings work as expected 460*da0073e9SAndroid Build Coastguard Worker def test_schema_info_bind_basic(self): 461*da0073e9SAndroid Build Coastguard Worker class SchemaInfoBindTestMode(TorchDispatchMode): 462*da0073e9SAndroid Build Coastguard Worker def __init__(self, test_self): 463*da0073e9SAndroid Build Coastguard Worker self.test_self = test_self 464*da0073e9SAndroid Build Coastguard Worker 465*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(self, func, types, args=(), kwargs=None): 466*da0073e9SAndroid Build Coastguard Worker named_arg_list = normalize_function( 467*da0073e9SAndroid Build Coastguard Worker func, 468*da0073e9SAndroid Build Coastguard Worker args, 469*da0073e9SAndroid Build Coastguard Worker kwargs, 470*da0073e9SAndroid Build Coastguard Worker normalize_to_only_use_kwargs=True 471*da0073e9SAndroid Build Coastguard Worker ).kwargs 472*da0073e9SAndroid Build Coastguard Worker schema_info_value_test = torch._C._SchemaInfo(func._schema) 473*da0073e9SAndroid Build Coastguard Worker schema_info_values_test = torch._C._SchemaInfo(func._schema) 474*da0073e9SAndroid Build Coastguard Worker self.test_self.assertFalse(schema_info_value_test.may_alias( 475*da0073e9SAndroid Build Coastguard Worker torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0), 476*da0073e9SAndroid Build Coastguard Worker torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1))) 477*da0073e9SAndroid Build Coastguard Worker self.test_self.assertFalse(schema_info_values_test.may_alias( 478*da0073e9SAndroid Build Coastguard Worker torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0), 479*da0073e9SAndroid Build Coastguard Worker torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1))) 480*da0073e9SAndroid Build Coastguard Worker for i in named_arg_list: 481*da0073e9SAndroid Build Coastguard Worker schema_info_value_test.add_argument_value(i, named_arg_list[i]) 482*da0073e9SAndroid Build Coastguard Worker schema_info_values_test.add_argument_values(named_arg_list) 483*da0073e9SAndroid Build Coastguard Worker self.test_self.assertTrue(schema_info_value_test.may_alias( 484*da0073e9SAndroid Build Coastguard Worker torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0), 485*da0073e9SAndroid Build Coastguard Worker torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1))) 486*da0073e9SAndroid Build Coastguard Worker self.test_self.assertTrue(schema_info_values_test.may_alias( 487*da0073e9SAndroid Build Coastguard Worker torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0), 488*da0073e9SAndroid Build Coastguard Worker torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1))) 489*da0073e9SAndroid Build Coastguard Worker 490*da0073e9SAndroid Build Coastguard Worker return func(*args, **kwargs) 491*da0073e9SAndroid Build Coastguard Worker x = torch.rand((3, 3)) 492*da0073e9SAndroid Build Coastguard Worker with SchemaInfoBindTestMode(self) as schemaInfoCheck: 493*da0073e9SAndroid Build Coastguard Worker x.add(x) 494*da0073e9SAndroid Build Coastguard Worker 495*da0073e9SAndroid Build Coastguard Worker 496*da0073e9SAndroid Build Coastguard Workerclass TestSchemaCheckModeOpInfo(JitTestCase): 497*da0073e9SAndroid Build Coastguard Worker @ops(op_db, dtypes=OpDTypes.supported) 498*da0073e9SAndroid Build Coastguard Worker def test_schema_correctness(self, device, dtype, op): 499*da0073e9SAndroid Build Coastguard Worker # Currently torch.equal isn't supported with torch.complex32 500*da0073e9SAndroid Build Coastguard Worker # There's also errors with complex64 and complex128 501*da0073e9SAndroid Build Coastguard Worker if (dtype == torch.complex32): 502*da0073e9SAndroid Build Coastguard Worker return 503*da0073e9SAndroid Build Coastguard Worker for sample in op.sample_inputs(device, dtype, requires_grad=False): 504*da0073e9SAndroid Build Coastguard Worker with SchemaCheckMode(): 505*da0073e9SAndroid Build Coastguard Worker op(sample.input, *sample.args, **sample.kwargs) 506*da0073e9SAndroid Build Coastguard Worker 507*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestSchemaCheckModeOpInfo, globals(), only_for=("cpu", "cuda")) 508*da0073e9SAndroid Build Coastguard Worker 509*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__': 510*da0073e9SAndroid Build Coastguard Worker run_tests() 511