1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: unknown"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport contextlib 4*da0073e9SAndroid Build Coastguard Workerimport copy 5*da0073e9SAndroid Build Coastguard Workerimport inspect 6*da0073e9SAndroid Build Coastguard Workerimport itertools 7*da0073e9SAndroid Build Coastguard Workerimport os 8*da0073e9SAndroid Build Coastguard Workerimport re 9*da0073e9SAndroid Build Coastguard Workerimport unittest 10*da0073e9SAndroid Build Coastguard Workerimport warnings 11*da0073e9SAndroid Build Coastguard Workerfrom collections import defaultdict 12*da0073e9SAndroid Build Coastguard Workerfrom collections.abc import Sequence 13*da0073e9SAndroid Build Coastguard Workerfrom functools import partial 14*da0073e9SAndroid Build Coastguard Workerfrom importlib import import_module 15*da0073e9SAndroid Build Coastguard Workerfrom typing import Dict, List 16*da0073e9SAndroid Build Coastguard Worker 17*da0073e9SAndroid Build Coastguard Workerimport torch 18*da0073e9SAndroid Build Coastguard Workerimport torch._prims as prims 19*da0073e9SAndroid Build Coastguard Workerimport torch.utils._pytree as pytree 20*da0073e9SAndroid Build Coastguard Workerfrom torch._prims.context import TorchRefsMode 21*da0073e9SAndroid Build Coastguard Workerfrom torch._prims_common.wrappers import _maybe_remove_out_wrapper 22*da0073e9SAndroid Build Coastguard Workerfrom torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode 23*da0073e9SAndroid Build Coastguard Workerfrom torch._subclasses.fake_utils import outputs_alias_inputs 24*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import make_tensor 25*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal import composite_compliance, opinfo 26*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import ( 27*da0073e9SAndroid Build Coastguard Worker deviceCountAtLeast, 28*da0073e9SAndroid Build 
Coastguard Worker instantiate_device_type_tests, 29*da0073e9SAndroid Build Coastguard Worker onlyCPU, 30*da0073e9SAndroid Build Coastguard Worker onlyCUDA, 31*da0073e9SAndroid Build Coastguard Worker onlyNativeDeviceTypesAnd, 32*da0073e9SAndroid Build Coastguard Worker OpDTypes, 33*da0073e9SAndroid Build Coastguard Worker ops, 34*da0073e9SAndroid Build Coastguard Worker skipMeta, 35*da0073e9SAndroid Build Coastguard Worker) 36*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_dtype import ( 37*da0073e9SAndroid Build Coastguard Worker all_types_and_complex_and, 38*da0073e9SAndroid Build Coastguard Worker floating_and_complex_types_and, 39*da0073e9SAndroid Build Coastguard Worker integral_types_and, 40*da0073e9SAndroid Build Coastguard Worker) 41*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_methods_invocations import ( 42*da0073e9SAndroid Build Coastguard Worker BinaryUfuncInfo, 43*da0073e9SAndroid Build Coastguard Worker op_db, 44*da0073e9SAndroid Build Coastguard Worker ops_and_refs, 45*da0073e9SAndroid Build Coastguard Worker python_ref_db, 46*da0073e9SAndroid Build Coastguard Worker ReductionOpInfo, 47*da0073e9SAndroid Build Coastguard Worker ReductionPythonRefInfo, 48*da0073e9SAndroid Build Coastguard Worker skip, 49*da0073e9SAndroid Build Coastguard Worker skipOps, 50*da0073e9SAndroid Build Coastguard Worker SpectralFuncInfo, 51*da0073e9SAndroid Build Coastguard Worker UnaryUfuncInfo, 52*da0073e9SAndroid Build Coastguard Worker xfail, 53*da0073e9SAndroid Build Coastguard Worker) 54*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 55*da0073e9SAndroid Build Coastguard Worker clone_input_helper, 56*da0073e9SAndroid Build Coastguard Worker first_sample, 57*da0073e9SAndroid Build Coastguard Worker IS_CI, 58*da0073e9SAndroid Build Coastguard Worker IS_FBCODE, 59*da0073e9SAndroid Build Coastguard Worker is_iterable_of_tensors, 60*da0073e9SAndroid Build Coastguard Worker 
IS_SANDCASTLE, 61*da0073e9SAndroid Build Coastguard Worker IS_WINDOWS, 62*da0073e9SAndroid Build Coastguard Worker noncontiguous_like, 63*da0073e9SAndroid Build Coastguard Worker parametrize, 64*da0073e9SAndroid Build Coastguard Worker run_tests, 65*da0073e9SAndroid Build Coastguard Worker set_default_dtype, 66*da0073e9SAndroid Build Coastguard Worker skipIfTorchInductor, 67*da0073e9SAndroid Build Coastguard Worker slowTest, 68*da0073e9SAndroid Build Coastguard Worker suppress_warnings, 69*da0073e9SAndroid Build Coastguard Worker TEST_WITH_ASAN, 70*da0073e9SAndroid Build Coastguard Worker TEST_WITH_ROCM, 71*da0073e9SAndroid Build Coastguard Worker TEST_WITH_TORCHDYNAMO, 72*da0073e9SAndroid Build Coastguard Worker TEST_WITH_TORCHINDUCTOR, 73*da0073e9SAndroid Build Coastguard Worker TEST_WITH_UBSAN, 74*da0073e9SAndroid Build Coastguard Worker TestCase, 75*da0073e9SAndroid Build Coastguard Worker unMarkDynamoStrictTest, 76*da0073e9SAndroid Build Coastguard Worker) 77*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._python_dispatch import TorchDispatchMode 78*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._pytree import tree_map 79*da0073e9SAndroid Build Coastguard Worker 80*da0073e9SAndroid Build Coastguard Worker 81*da0073e9SAndroid Build Coastguard Workerassert torch.get_default_dtype() == torch.float32 82*da0073e9SAndroid Build Coastguard Worker 83*da0073e9SAndroid Build Coastguard Worker# variant testing is only done with torch.float and torch.cfloat to avoid 84*da0073e9SAndroid Build Coastguard Worker# excessive test times and maximize signal to noise ratio 85*da0073e9SAndroid Build Coastguard Worker_variant_ops = partial( 86*da0073e9SAndroid Build Coastguard Worker ops, dtypes=OpDTypes.supported, allowed_dtypes=(torch.float, torch.cfloat) 87*da0073e9SAndroid Build Coastguard Worker) 88*da0073e9SAndroid Build Coastguard Worker 89*da0073e9SAndroid Build Coastguard Worker# Get names of all the operators which have ref in their entry in OpInfo 
(testing infra) 90*da0073e9SAndroid Build Coastguard Worker# except for elementwise unary operators (separately implemented in test/test_unary_ufuncs.py), 91*da0073e9SAndroid Build Coastguard Worker# elementwise binary operators (separately implemented in test_binary_ufuncs.py), 92*da0073e9SAndroid Build Coastguard Worker# reduction operations (separately impelemented in test_reductions.py), 93*da0073e9SAndroid Build Coastguard Worker# and Spectral Functions (separately implemented for only 1D as of now, in test/test_spectral_ops.py) 94*da0073e9SAndroid Build Coastguard Worker_ref_test_ops = tuple( 95*da0073e9SAndroid Build Coastguard Worker filter( 96*da0073e9SAndroid Build Coastguard Worker lambda op: not isinstance( 97*da0073e9SAndroid Build Coastguard Worker op, (UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo, BinaryUfuncInfo) 98*da0073e9SAndroid Build Coastguard Worker ) 99*da0073e9SAndroid Build Coastguard Worker and op.ref is not None, 100*da0073e9SAndroid Build Coastguard Worker op_db, 101*da0073e9SAndroid Build Coastguard Worker ) 102*da0073e9SAndroid Build Coastguard Worker) 103*da0073e9SAndroid Build Coastguard Worker 104*da0073e9SAndroid Build Coastguard Worker 105*da0073e9SAndroid Build Coastguard Workerdef reduction_dtype_filter(op): 106*da0073e9SAndroid Build Coastguard Worker if ( 107*da0073e9SAndroid Build Coastguard Worker not isinstance(op, ReductionPythonRefInfo) 108*da0073e9SAndroid Build Coastguard Worker or not op.supports_out 109*da0073e9SAndroid Build Coastguard Worker or torch.int16 not in op.dtypes 110*da0073e9SAndroid Build Coastguard Worker ): 111*da0073e9SAndroid Build Coastguard Worker return False 112*da0073e9SAndroid Build Coastguard Worker return "dtype" in inspect.getfullargspec(op.op).kwonlyargs 113*da0073e9SAndroid Build Coastguard Worker 114*da0073e9SAndroid Build Coastguard Worker 115*da0073e9SAndroid Build Coastguard Worker# Create a list of operators that are a subset of _ref_test_ops but don't have a 116*da0073e9SAndroid 
# Operators (and their Python references) from ops_and_refs that have no NumPy
# reference to compare against. Ops that do have one are compared to NumPy on
# both CPU and CUDA (test_numpy_ref), so they do not also need to be compared
# to each other; only these are compared CPU-vs-CUDA in test_compare_cpu.
_ops_and_refs_with_no_numpy_ref = [op for op in ops_and_refs if op.ref is None]

aten = torch.ops.aten


# Tests that apply to all operators and aren't related to any particular
# system
@unMarkDynamoStrictTest
class TestCommon(TestCase):
    exact_dtype = True

    # Verifies, on teardown, that no OpInfo is still using dynamic dtypes in CI
    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

        if IS_CI:
            # The trailing space keeps the two implicitly-concatenated
            # sentences from running together.
            err_msg = (
                "The operator(s) below is(are) using dynamic_dtypes in the OpInfo entries. "
                "This is OK for testing, but be sure to set the dtypes manually before landing your PR!"
            )
            # Assure no opinfo entry has dynamic_dtypes
            filtered_ops = list(filter(opinfo.utils.is_dynamic_dtype_set, op_db))
            for op in filtered_ops:
                fmt_str = opinfo.utils.str_format_dynamic_dtype(op)
                err_msg += "\n" + fmt_str

            assert len(filtered_ops) == 0, err_msg

    # Validates that each OpInfo works correctly on different CUDA devices
    @onlyCUDA
    @deviceCountAtLeast(2)
    @ops(op_db, allowed_dtypes=(torch.float32, torch.long))
    def test_multiple_devices(self, devices, dtype, op):
        for cuda_device_str in devices:
            cuda_device = torch.device(cuda_device_str)
            # NOTE: only tests on first sample
            samples = op.sample_inputs(cuda_device, dtype)
            sample = first_sample(self, samples)
            result = op(sample.input, *sample.args, **sample.kwargs)

            if isinstance(result, torch.Tensor):
                self.assertTrue(result.device == cuda_device)
            elif is_iterable_of_tensors(result):
                self.assertTrue(all(t.device == cuda_device for t in result))
            else:
                self.skipTest(
                    "Skipped! Only supports single tensor or iterable of tensor outputs."
                )

    # Checks that every aten overload backing a DEFINE_DISPATCH(<name>_stub) in
    # the ATen source files listed below carries torch.Tag.pointwise. Overloads
    # without a registered kernel, generated overloads, and the explicit
    # allow-list are exempt.
    def test_pointwise_tag_coverage(self):
        # Repository root, assuming this file lives one directory below it.
        pytorch_dir = os.path.abspath(__file__ + "/../../")
        files = [
            "aten/src/ATen/native/UnaryOps.cpp",
            "aten/src/ATen/native/BinaryOps.cpp",
            "aten/src/ATen/native/PointwiseOps.cpp",
            "aten/src/ATen/native/TensorCompare.cpp",
        ]

        allowed_functions = (
            # reduction version of these operators
            "aten.max.default",
            "aten.max.dim",
            "aten.max.dim_max",
            "aten.max.names_dim",
            "aten.max.names_dim_max",
            "aten.max.unary_out",
            "aten.min.default",
            "aten.min.dim",
            "aten.min.dim_min",
            "aten.min.names_dim",
            "aten.min.names_dim_min",
            "aten.min.unary_out",
            # not pointwise
            "aten.isin.Tensor_Tensor",
            "aten.isin.Tensor_Tensor_out",
            "aten.isin.Tensor_Scalar",
            "aten.isin.Tensor_Scalar_out",
            "aten.isin.Scalar_Tensor",
            "aten.isin.Scalar_Tensor_out",
            "aten.mode.default",
            "aten.mode.dimname",
            "aten.mode.dimname_out",
            "aten.mode.values",
        )

        regex = re.compile(r"DEFINE_DISPATCH\(.*_stub")

        def get_opoverloadpacket_from_dispatch(kernel):
            # Maps a dispatch-stub base name to an attribute of torch.ops.aten,
            # trying in order: the name itself, its dunder form, its special_
            # form, and the name with its last "_"-separated component dropped.
            if hasattr(torch.ops.aten, kernel):
                return kernel
            if hasattr(torch.ops.aten, f"__{kernel}__"):
                return f"__{kernel}__"
            if hasattr(torch.ops.aten, f"special_{kernel}"):
                return f"special_{kernel}"
            if "_" in kernel:
                kernel_split = kernel.split("_")
                new_kernel = "_".join(kernel_split[:-1])
                if hasattr(torch.ops.aten, new_kernel):
                    return new_kernel

            # could not find op from kernel dispatch string
            self.fail(f"Could not find an aten op for dispatch stub '{kernel}_stub'")

        for file_name in files:
            with open(os.path.join(pytorch_dir, file_name), encoding="utf-8") as f:
                lines = f.read()
                matches = regex.findall(lines)
                for match in matches:
                    # Strip the "DEFINE_DISPATCH(" prefix and "_stub" suffix.
                    kernel = match[len("DEFINE_DISPATCH(") : -len("_stub")]

                    # trigamma has a DEFINE_DISPATCH stub but no aten op to map it to
                    if kernel == "trigamma":
                        continue

                    kernel = get_opoverloadpacket_from_dispatch(kernel)
                    overloadpacket = getattr(torch.ops.aten, kernel)

                    for overload_name in overloadpacket.overloads():
                        overload = getattr(overloadpacket, overload_name)

                        if not torch._C._dispatch_has_kernel(overload.name()):
                            continue

                        # TODO: tags are not propagated to generated overload,
                        # and there's no way of specifying them
                        if torch.Tag.generated in overload.tags:
                            continue

                        if str(overload) in allowed_functions:
                            continue

                        self.assertIn(
                            torch.Tag.pointwise,
                            overload.tags,
                            msg=f"{overload} is missing the pointwise tag",
                        )

    # Tests that the function and its (ndarray-accepting) reference produce the same
    # values on the tensors from the op's reference_inputs.
    # This test runs in double and complex double precision because
    # NumPy does computation internally using double precision for many functions
    # resulting in possible equality check failures.
    # skip windows case on CPU due to https://github.com/pytorch/pytorch/issues/129947
    @onlyNativeDeviceTypesAnd(["hpu"])
    @suppress_warnings
    @ops(_ref_test_ops, allowed_dtypes=(torch.float64, torch.long, torch.complex128))
    def test_numpy_ref(self, device, dtype, op):
        inductor_signal_window_case = (
            TEST_WITH_TORCHINDUCTOR
            and op.formatted_name
            in ("signal_windows_exponential", "signal_windows_bartlett")
            and dtype == torch.float64
            and "cuda" in device
        )
        # The CPU skip is gated on IS_WINDOWS, per the issue linked above.
        # Without that gate, and because `and` binds tighter than `or`, every
        # CPU run of this test would be skipped.
        windows_cpu_case = IS_WINDOWS and "cpu" in device
        if inductor_signal_window_case or windows_cpu_case:
            raise unittest.SkipTest("XXX: raises tensor-likes are not close.")

        # Sets the default dtype to NumPy's default dtype of double
        with set_default_dtype(torch.double):
            for sample_input in op.reference_inputs(device, dtype):
                self.compare_with_reference(
                    op, op.ref, sample_input, exact_dtype=(dtype is not torch.long)
                )

    # Tests that the cpu and gpu results are consistent
    @onlyCUDA
    @suppress_warnings
    @slowTest
    @ops(_ops_and_refs_with_no_numpy_ref, dtypes=OpDTypes.any_common_cpu_cuda_one)
    def test_compare_cpu(self, device, dtype, op):
        def to_cpu(arg):
            if isinstance(arg, torch.Tensor):
                return arg.to(device="cpu")
            return arg

        samples = op.reference_inputs(device, dtype)

        for sample in samples:
            cpu_sample = sample.transform(to_cpu)
            cuda_results = op(sample.input, *sample.args, **sample.kwargs)
            cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)

            # output_process_fn_grad has a very unfortunate name
            # We use this function in linalg extensively to postprocess the outputs of functions
            # that are not completely well-defined. Think svd and multiplying the singular vectors by -1.
            # CPU and CUDA implementations of the SVD can return valid SVDs that are different.
            # We use this function to compare them.
            cuda_results = sample.output_process_fn_grad(cuda_results)
            cpu_results = cpu_sample.output_process_fn_grad(cpu_results)

            # Looser tolerance (1e-3) because we are running this as a `@slowTest`
            # Don't want the periodic tests to fail frequently
            self.assertEqual(cuda_results, cpu_results, atol=1e-3, rtol=1e-3)

    # Tests that experimental Python References can propagate shape, dtype,
    # and device metadata properly.
    # See https://github.com/pytorch/pytorch/issues/78050 for a discussion of stride propagation.
312*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypesAnd(["hpu"]) 313*da0073e9SAndroid Build Coastguard Worker @ops(python_ref_db) 314*da0073e9SAndroid Build Coastguard Worker @skipIfTorchInductor("Takes too long for inductor") 315*da0073e9SAndroid Build Coastguard Worker def test_python_ref_meta(self, device, dtype, op): 316*da0073e9SAndroid Build Coastguard Worker CHECK_CONJ_SKIPS = { 317*da0073e9SAndroid Build Coastguard Worker torch._refs.linalg.svd, 318*da0073e9SAndroid Build Coastguard Worker } 319*da0073e9SAndroid Build Coastguard Worker 320*da0073e9SAndroid Build Coastguard Worker with FakeTensorMode() as mode: 321*da0073e9SAndroid Build Coastguard Worker pass 322*da0073e9SAndroid Build Coastguard Worker 323*da0073e9SAndroid Build Coastguard Worker def _to_tensormeta(x): 324*da0073e9SAndroid Build Coastguard Worker if isinstance(x, torch.Tensor): 325*da0073e9SAndroid Build Coastguard Worker out = FakeTensor.from_tensor(x, mode) 326*da0073e9SAndroid Build Coastguard Worker return out 327*da0073e9SAndroid Build Coastguard Worker return x 328*da0073e9SAndroid Build Coastguard Worker 329*da0073e9SAndroid Build Coastguard Worker # TODO: iterate over requires_grad true/false 330*da0073e9SAndroid Build Coastguard Worker for sample in op.reference_inputs(device, dtype, requires_grad=False): 331*da0073e9SAndroid Build Coastguard Worker result = op(sample.input, *sample.args, **sample.kwargs) 332*da0073e9SAndroid Build Coastguard Worker 333*da0073e9SAndroid Build Coastguard Worker meta_sample = sample.transform(_to_tensormeta) 334*da0073e9SAndroid Build Coastguard Worker try: 335*da0073e9SAndroid Build Coastguard Worker with mode: 336*da0073e9SAndroid Build Coastguard Worker meta_result = op( 337*da0073e9SAndroid Build Coastguard Worker meta_sample.input, *meta_sample.args, **meta_sample.kwargs 338*da0073e9SAndroid Build Coastguard Worker ) 339*da0073e9SAndroid Build Coastguard Worker except torch._subclasses.fake_tensor.UnsupportedFakeTensorException: 
340*da0073e9SAndroid Build Coastguard Worker continue 341*da0073e9SAndroid Build Coastguard Worker except torch._subclasses.fake_tensor.DataDependentOutputException: 342*da0073e9SAndroid Build Coastguard Worker continue 343*da0073e9SAndroid Build Coastguard Worker except torch._subclasses.fake_tensor.UnsupportedOperatorException: 344*da0073e9SAndroid Build Coastguard Worker continue 345*da0073e9SAndroid Build Coastguard Worker 346*da0073e9SAndroid Build Coastguard Worker if isinstance(result, torch.Tensor): 347*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(meta_result, FakeTensor)) 348*da0073e9SAndroid Build Coastguard Worker prims.utils.compare_tensor_meta( 349*da0073e9SAndroid Build Coastguard Worker result, meta_result, check_conj=op.op not in CHECK_CONJ_SKIPS 350*da0073e9SAndroid Build Coastguard Worker ) 351*da0073e9SAndroid Build Coastguard Worker elif isinstance(result, Sequence): 352*da0073e9SAndroid Build Coastguard Worker for a, b in zip(result, meta_result): 353*da0073e9SAndroid Build Coastguard Worker if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor): 354*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(b, FakeTensor)) 355*da0073e9SAndroid Build Coastguard Worker prims.utils.compare_tensor_meta( 356*da0073e9SAndroid Build Coastguard Worker a, b, check_conj=op.op not in CHECK_CONJ_SKIPS 357*da0073e9SAndroid Build Coastguard Worker ) 358*da0073e9SAndroid Build Coastguard Worker 359*da0073e9SAndroid Build Coastguard Worker def _ref_test_helper( 360*da0073e9SAndroid Build Coastguard Worker self, 361*da0073e9SAndroid Build Coastguard Worker ctx, 362*da0073e9SAndroid Build Coastguard Worker device, 363*da0073e9SAndroid Build Coastguard Worker dtype, 364*da0073e9SAndroid Build Coastguard Worker op, 365*da0073e9SAndroid Build Coastguard Worker skip_zero_numel=False, 366*da0073e9SAndroid Build Coastguard Worker skip_zero_dim=False, 367*da0073e9SAndroid Build Coastguard Worker skip_bfloat=False, 
368*da0073e9SAndroid Build Coastguard Worker skip_view_consistency=False, 369*da0073e9SAndroid Build Coastguard Worker ): 370*da0073e9SAndroid Build Coastguard Worker # NOTE: this test works by comparing the reference 371*da0073e9SAndroid Build Coastguard Worker ex = None 372*da0073e9SAndroid Build Coastguard Worker for sample in op.reference_inputs(device, dtype, requires_grad=False): 373*da0073e9SAndroid Build Coastguard Worker if ( 374*da0073e9SAndroid Build Coastguard Worker isinstance(sample.input, torch.Tensor) 375*da0073e9SAndroid Build Coastguard Worker and sample.input.numel() == 0 376*da0073e9SAndroid Build Coastguard Worker and skip_zero_numel 377*da0073e9SAndroid Build Coastguard Worker ): 378*da0073e9SAndroid Build Coastguard Worker continue 379*da0073e9SAndroid Build Coastguard Worker if ( 380*da0073e9SAndroid Build Coastguard Worker isinstance(sample.input, torch.Tensor) 381*da0073e9SAndroid Build Coastguard Worker and sample.input.ndim == 0 382*da0073e9SAndroid Build Coastguard Worker and skip_zero_dim 383*da0073e9SAndroid Build Coastguard Worker ): 384*da0073e9SAndroid Build Coastguard Worker continue 385*da0073e9SAndroid Build Coastguard Worker 386*da0073e9SAndroid Build Coastguard Worker if skip_bfloat and ( 387*da0073e9SAndroid Build Coastguard Worker ( 388*da0073e9SAndroid Build Coastguard Worker isinstance(sample.input, torch.Tensor) 389*da0073e9SAndroid Build Coastguard Worker and sample.input.dtype == torch.bfloat16 390*da0073e9SAndroid Build Coastguard Worker ) 391*da0073e9SAndroid Build Coastguard Worker or any( 392*da0073e9SAndroid Build Coastguard Worker isinstance(arg, torch.Tensor) and arg.dtype == torch.bfloat16 393*da0073e9SAndroid Build Coastguard Worker for arg in sample.args 394*da0073e9SAndroid Build Coastguard Worker ) 395*da0073e9SAndroid Build Coastguard Worker ): 396*da0073e9SAndroid Build Coastguard Worker continue 397*da0073e9SAndroid Build Coastguard Worker with ctx(): 398*da0073e9SAndroid Build Coastguard Worker 
ref_result = op(sample.input, *sample.args, **sample.kwargs) 399*da0073e9SAndroid Build Coastguard Worker torch_result = op.torch_opinfo(sample.input, *sample.args, **sample.kwargs) 400*da0073e9SAndroid Build Coastguard Worker 401*da0073e9SAndroid Build Coastguard Worker for a, b in zip( 402*da0073e9SAndroid Build Coastguard Worker pytree.tree_leaves(ref_result), pytree.tree_leaves(torch_result) 403*da0073e9SAndroid Build Coastguard Worker ): 404*da0073e9SAndroid Build Coastguard Worker if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor): 405*da0073e9SAndroid Build Coastguard Worker prims.utils.compare_tensor_meta(a, b) 406*da0073e9SAndroid Build Coastguard Worker if ( 407*da0073e9SAndroid Build Coastguard Worker getattr(op, "validate_view_consistency", True) 408*da0073e9SAndroid Build Coastguard Worker and not skip_view_consistency 409*da0073e9SAndroid Build Coastguard Worker ): 410*da0073e9SAndroid Build Coastguard Worker msg = ( 411*da0073e9SAndroid Build Coastguard Worker f"The torch implementation {'returns' if b._is_view() else 'does not return'} " 412*da0073e9SAndroid Build Coastguard Worker f"a view, while the reference {'does' if a._is_view() else 'does not'}" 413*da0073e9SAndroid Build Coastguard Worker ) 414*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a._is_view(), b._is_view(), msg) 415*da0073e9SAndroid Build Coastguard Worker 416*da0073e9SAndroid Build Coastguard Worker # Computes the dtype the more precise computatino would occur in 417*da0073e9SAndroid Build Coastguard Worker precise_dtype = torch.bool 418*da0073e9SAndroid Build Coastguard Worker if prims.utils.is_integer_dtype(dtype): 419*da0073e9SAndroid Build Coastguard Worker # Note: bool and integer dtypes do not have more 420*da0073e9SAndroid Build Coastguard Worker # precise dtypes -- they simply must be close 421*da0073e9SAndroid Build Coastguard Worker precise_dtype = dtype 422*da0073e9SAndroid Build Coastguard Worker if prims.utils.is_float_dtype(dtype): 
423*da0073e9SAndroid Build Coastguard Worker precise_dtype = torch.double 424*da0073e9SAndroid Build Coastguard Worker if prims.utils.is_complex_dtype(dtype): 425*da0073e9SAndroid Build Coastguard Worker precise_dtype = torch.cdouble 426*da0073e9SAndroid Build Coastguard Worker 427*da0073e9SAndroid Build Coastguard Worker # Checks if the results are close 428*da0073e9SAndroid Build Coastguard Worker try: 429*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 430*da0073e9SAndroid Build Coastguard Worker ref_result, 431*da0073e9SAndroid Build Coastguard Worker torch_result, 432*da0073e9SAndroid Build Coastguard Worker exact_stride=False, 433*da0073e9SAndroid Build Coastguard Worker exact_device=True, 434*da0073e9SAndroid Build Coastguard Worker exact_layout=True, 435*da0073e9SAndroid Build Coastguard Worker exact_is_coalesced=True, 436*da0073e9SAndroid Build Coastguard Worker ) 437*da0073e9SAndroid Build Coastguard Worker except AssertionError as e: 438*da0073e9SAndroid Build Coastguard Worker # Raises the error if the precise dtype comparison wouldn't be 439*da0073e9SAndroid Build Coastguard Worker # different 440*da0073e9SAndroid Build Coastguard Worker if dtype is precise_dtype: 441*da0073e9SAndroid Build Coastguard Worker raise e 442*da0073e9SAndroid Build Coastguard Worker 443*da0073e9SAndroid Build Coastguard Worker ex = e 444*da0073e9SAndroid Build Coastguard Worker 445*da0073e9SAndroid Build Coastguard Worker # Goes to next sample if these results are close 446*da0073e9SAndroid Build Coastguard Worker if not ex: 447*da0073e9SAndroid Build Coastguard Worker continue 448*da0073e9SAndroid Build Coastguard Worker 449*da0073e9SAndroid Build Coastguard Worker # If the results are not close, checks that the 450*da0073e9SAndroid Build Coastguard Worker # reference is more accurate than the torch op 451*da0073e9SAndroid Build Coastguard Worker def _make_precise(x): 452*da0073e9SAndroid Build Coastguard Worker if isinstance(x, torch.dtype): 453*da0073e9SAndroid 
Build Coastguard Worker return precise_dtype 454*da0073e9SAndroid Build Coastguard Worker if isinstance(x, torch.Tensor) and x.dtype is dtype: 455*da0073e9SAndroid Build Coastguard Worker return x.to(precise_dtype) 456*da0073e9SAndroid Build Coastguard Worker return x 457*da0073e9SAndroid Build Coastguard Worker 458*da0073e9SAndroid Build Coastguard Worker precise_sample = sample.transform(_make_precise) 459*da0073e9SAndroid Build Coastguard Worker precise_result = op.torch_opinfo( 460*da0073e9SAndroid Build Coastguard Worker precise_sample.input, *precise_sample.args, **precise_sample.kwargs 461*da0073e9SAndroid Build Coastguard Worker ) 462*da0073e9SAndroid Build Coastguard Worker 463*da0073e9SAndroid Build Coastguard Worker def _distance(a, b): 464*da0073e9SAndroid Build Coastguard Worker # Special-cases boolean comparisons 465*da0073e9SAndroid Build Coastguard Worker if prims.utils.is_boolean_dtype(a.dtype): 466*da0073e9SAndroid Build Coastguard Worker assert b.dtype is torch.bool 467*da0073e9SAndroid Build Coastguard Worker return (a ^ b).sum() 468*da0073e9SAndroid Build Coastguard Worker 469*da0073e9SAndroid Build Coastguard Worker same = a == b 470*da0073e9SAndroid Build Coastguard Worker if prims.utils.is_float_dtype(a.dtype) or prims.utils.is_complex_dtype( 471*da0073e9SAndroid Build Coastguard Worker a.dtype 472*da0073e9SAndroid Build Coastguard Worker ): 473*da0073e9SAndroid Build Coastguard Worker same = torch.logical_or( 474*da0073e9SAndroid Build Coastguard Worker same, torch.logical_and(torch.isnan(a), torch.isnan(b)) 475*da0073e9SAndroid Build Coastguard Worker ) 476*da0073e9SAndroid Build Coastguard Worker 477*da0073e9SAndroid Build Coastguard Worker actual_error = torch.where(same, 0, torch.abs(a - b)).sum() 478*da0073e9SAndroid Build Coastguard Worker return actual_error 479*da0073e9SAndroid Build Coastguard Worker 480*da0073e9SAndroid Build Coastguard Worker ref_distance = 0 481*da0073e9SAndroid Build Coastguard Worker for a, b in zip( 
482*da0073e9SAndroid Build Coastguard Worker pytree.tree_leaves(ref_result), pytree.tree_leaves(precise_result) 483*da0073e9SAndroid Build Coastguard Worker ): 484*da0073e9SAndroid Build Coastguard Worker ref_distance = ref_distance + _distance(a, b) 485*da0073e9SAndroid Build Coastguard Worker 486*da0073e9SAndroid Build Coastguard Worker torch_distance = 0 487*da0073e9SAndroid Build Coastguard Worker for a, b in zip( 488*da0073e9SAndroid Build Coastguard Worker pytree.tree_leaves(torch_result), pytree.tree_leaves(precise_result) 489*da0073e9SAndroid Build Coastguard Worker ): 490*da0073e9SAndroid Build Coastguard Worker torch_distance = torch_distance + _distance(a, b) 491*da0073e9SAndroid Build Coastguard Worker 492*da0073e9SAndroid Build Coastguard Worker # TODO: consider adding some tolerance to this comparison 493*da0073e9SAndroid Build Coastguard Worker msg = ( 494*da0073e9SAndroid Build Coastguard Worker f"Reference result was farther ({ref_distance}) from the precise " 495*da0073e9SAndroid Build Coastguard Worker f"computation than the torch result was ({torch_distance})!" 496*da0073e9SAndroid Build Coastguard Worker ) 497*da0073e9SAndroid Build Coastguard Worker self.assertTrue(ref_distance <= torch_distance, msg=msg) 498*da0073e9SAndroid Build Coastguard Worker 499*da0073e9SAndroid Build Coastguard Worker # Reports numerical accuracy discrepancies 500*da0073e9SAndroid Build Coastguard Worker if ex is not None: 501*da0073e9SAndroid Build Coastguard Worker msg = "Test passed because the reference was more accurate than the torch operator." 
502*da0073e9SAndroid Build Coastguard Worker warnings.warn(msg) 503*da0073e9SAndroid Build Coastguard Worker 504*da0073e9SAndroid Build Coastguard Worker # Tests that experimental Python References perform the same computation 505*da0073e9SAndroid Build Coastguard Worker # as the operators they reference, when operator calls in the torch 506*da0073e9SAndroid Build Coastguard Worker # namesapce are remapped to the refs namespace (torch.foo becomes refs.foo). 507*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypesAnd(["hpu"]) 508*da0073e9SAndroid Build Coastguard Worker @ops(python_ref_db) 509*da0073e9SAndroid Build Coastguard Worker @skipIfTorchInductor("Takes too long for inductor") 510*da0073e9SAndroid Build Coastguard Worker def test_python_ref(self, device, dtype, op): 511*da0073e9SAndroid Build Coastguard Worker # In this test, primTorch refs call into the refs namespace 512*da0073e9SAndroid Build Coastguard Worker # For example, a ref with torch.foo in it will calls refs.foo instead 513*da0073e9SAndroid Build Coastguard Worker # Direct calls to refs and prims are not affected 514*da0073e9SAndroid Build Coastguard Worker if ( 515*da0073e9SAndroid Build Coastguard Worker TEST_WITH_ROCM 516*da0073e9SAndroid Build Coastguard Worker and (op.name == "_refs.fft.ihfftn" or op.name == "_refs.fft.ihfft2") 517*da0073e9SAndroid Build Coastguard Worker and dtype == torch.float16 518*da0073e9SAndroid Build Coastguard Worker ): 519*da0073e9SAndroid Build Coastguard Worker self.skipTest("Skipped on ROCm") 520*da0073e9SAndroid Build Coastguard Worker self._ref_test_helper(lambda: TorchRefsMode(strict=True), device, dtype, op) 521*da0073e9SAndroid Build Coastguard Worker 522*da0073e9SAndroid Build Coastguard Worker # Tests that experimental Python References perform the same computation 523*da0073e9SAndroid Build Coastguard Worker # as the operators they reference, when operator calls in the torch 524*da0073e9SAndroid Build Coastguard Worker # namespace are 
preserved (torch.foo remains torch.foo). 525*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypesAnd(["hpu"]) 526*da0073e9SAndroid Build Coastguard Worker @ops(python_ref_db) 527*da0073e9SAndroid Build Coastguard Worker @skipIfTorchInductor("Takes too long for inductor") 528*da0073e9SAndroid Build Coastguard Worker def test_python_ref_torch_fallback(self, device, dtype, op): 529*da0073e9SAndroid Build Coastguard Worker # In this test, refs call into the torch namespace (after the initial invocation) 530*da0073e9SAndroid Build Coastguard Worker # For example, a ref with torch.foo in it will call torch.foo instead of refs.foo 531*da0073e9SAndroid Build Coastguard Worker # Direct calls to refs and prims are not translated 532*da0073e9SAndroid Build Coastguard Worker if TEST_WITH_ROCM and op.name == "_refs.fft.ihfftn" and dtype == torch.float16: 533*da0073e9SAndroid Build Coastguard Worker self.skipTest("Skipped on ROCm") 534*da0073e9SAndroid Build Coastguard Worker self._ref_test_helper(contextlib.nullcontext, device, dtype, op) 535*da0073e9SAndroid Build Coastguard Worker 536*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") 537*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 538*da0073e9SAndroid Build Coastguard Worker @ops(python_ref_db) 539*da0073e9SAndroid Build Coastguard Worker @parametrize("executor", ["aten"]) 540*da0073e9SAndroid Build Coastguard Worker @skipIfTorchInductor("Takes too long for inductor") 541*da0073e9SAndroid Build Coastguard Worker def test_python_ref_executor(self, device, dtype, op, executor): 542*da0073e9SAndroid Build Coastguard Worker if ( 543*da0073e9SAndroid Build Coastguard Worker TEST_WITH_ROCM 544*da0073e9SAndroid Build Coastguard Worker and (op.name == "_refs.fft.ihfftn" or op.name == "_refs.fft.ihfft2") 545*da0073e9SAndroid Build Coastguard Worker and dtype == torch.float16 546*da0073e9SAndroid Build Coastguard Worker ): 547*da0073e9SAndroid Build Coastguard Worker 
self.skipTest("Skipped on ROCm") 548*da0073e9SAndroid Build Coastguard Worker # skip zero-dim tensors for some composites of reduction operations and view 549*da0073e9SAndroid Build Coastguard Worker skip_zero_dim_ops = [ 550*da0073e9SAndroid Build Coastguard Worker "_refs.logsumexp", 551*da0073e9SAndroid Build Coastguard Worker "_refs.log_softmax", 552*da0073e9SAndroid Build Coastguard Worker "_refs.native_group_norm", 553*da0073e9SAndroid Build Coastguard Worker "_refs.softmax", 554*da0073e9SAndroid Build Coastguard Worker "_refs.sum_to_size", 555*da0073e9SAndroid Build Coastguard Worker "ops.nvprims.view", 556*da0073e9SAndroid Build Coastguard Worker ] 557*da0073e9SAndroid Build Coastguard Worker 558*da0073e9SAndroid Build Coastguard Worker from copy import copy 559*da0073e9SAndroid Build Coastguard Worker 560*da0073e9SAndroid Build Coastguard Worker from torch._prims.executor import make_traced 561*da0073e9SAndroid Build Coastguard Worker 562*da0073e9SAndroid Build Coastguard Worker op = copy(op) 563*da0073e9SAndroid Build Coastguard Worker op.op = partial(make_traced(op.op), executor=executor) 564*da0073e9SAndroid Build Coastguard Worker self._ref_test_helper(contextlib.nullcontext, device, dtype, op) 565*da0073e9SAndroid Build Coastguard Worker 566*da0073e9SAndroid Build Coastguard Worker @skipMeta 567*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypesAnd(["hpu"]) 568*da0073e9SAndroid Build Coastguard Worker @ops([op for op in op_db if op.error_inputs_func is not None], dtypes=OpDTypes.none) 569*da0073e9SAndroid Build Coastguard Worker def test_errors(self, device, op): 570*da0073e9SAndroid Build Coastguard Worker error_inputs = op.error_inputs(device) 571*da0073e9SAndroid Build Coastguard Worker for ei in error_inputs: 572*da0073e9SAndroid Build Coastguard Worker si = ei.sample_input 573*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ei.error_type, ei.error_regex): 574*da0073e9SAndroid Build Coastguard Worker out = 
op(si.input, *si.args, **si.kwargs) 575*da0073e9SAndroid Build Coastguard Worker self.assertFalse(isinstance(out, type(NotImplemented))) 576*da0073e9SAndroid Build Coastguard Worker 577*da0073e9SAndroid Build Coastguard Worker @skipMeta 578*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypesAnd(["hpu"]) 579*da0073e9SAndroid Build Coastguard Worker @ops( 580*da0073e9SAndroid Build Coastguard Worker [op for op in op_db if op.error_inputs_sparse_func is not None], 581*da0073e9SAndroid Build Coastguard Worker dtypes=OpDTypes.none, 582*da0073e9SAndroid Build Coastguard Worker ) 583*da0073e9SAndroid Build Coastguard Worker @parametrize( 584*da0073e9SAndroid Build Coastguard Worker "layout", 585*da0073e9SAndroid Build Coastguard Worker ( 586*da0073e9SAndroid Build Coastguard Worker torch.sparse_csr, 587*da0073e9SAndroid Build Coastguard Worker torch.sparse_csc, 588*da0073e9SAndroid Build Coastguard Worker torch.sparse_bsr, 589*da0073e9SAndroid Build Coastguard Worker torch.sparse_bsc, 590*da0073e9SAndroid Build Coastguard Worker torch.sparse_coo, 591*da0073e9SAndroid Build Coastguard Worker ), 592*da0073e9SAndroid Build Coastguard Worker ) 593*da0073e9SAndroid Build Coastguard Worker def test_errors_sparse(self, device, op, layout): 594*da0073e9SAndroid Build Coastguard Worker for ei in op.error_inputs_sparse(device, layout): 595*da0073e9SAndroid Build Coastguard Worker si = ei.sample_input 596*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ei.error_type, ei.error_regex): 597*da0073e9SAndroid Build Coastguard Worker out = op(si.input, *si.args, **si.kwargs) 598*da0073e9SAndroid Build Coastguard Worker self.assertFalse(isinstance(out, type(NotImplemented))) 599*da0073e9SAndroid Build Coastguard Worker 600*da0073e9SAndroid Build Coastguard Worker @skipMeta 601*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypesAnd(["hpu"]) 602*da0073e9SAndroid Build Coastguard Worker @ops( 603*da0073e9SAndroid Build Coastguard Worker [op for op 
in python_ref_db if op.error_inputs_func is not None], 604*da0073e9SAndroid Build Coastguard Worker dtypes=OpDTypes.none, 605*da0073e9SAndroid Build Coastguard Worker ) 606*da0073e9SAndroid Build Coastguard Worker @skipIfTorchInductor("Takes too long for inductor") 607*da0073e9SAndroid Build Coastguard Worker def test_python_ref_errors(self, device, op): 608*da0073e9SAndroid Build Coastguard Worker mode = FakeTensorMode() 609*da0073e9SAndroid Build Coastguard Worker with mode: 610*da0073e9SAndroid Build Coastguard Worker pass 611*da0073e9SAndroid Build Coastguard Worker 612*da0073e9SAndroid Build Coastguard Worker def _to_tensormeta(x): 613*da0073e9SAndroid Build Coastguard Worker if isinstance(x, torch.Tensor): 614*da0073e9SAndroid Build Coastguard Worker return FakeTensor.from_tensor(x, mode) 615*da0073e9SAndroid Build Coastguard Worker return x 616*da0073e9SAndroid Build Coastguard Worker 617*da0073e9SAndroid Build Coastguard Worker error_inputs = op.error_inputs(device) 618*da0073e9SAndroid Build Coastguard Worker for ei in error_inputs: 619*da0073e9SAndroid Build Coastguard Worker si = ei.sample_input 620*da0073e9SAndroid Build Coastguard Worker meta_sample = si.transform(_to_tensormeta) 621*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ei.error_type, ei.error_regex): 622*da0073e9SAndroid Build Coastguard Worker op(meta_sample.input, *meta_sample.args, **meta_sample.kwargs) 623*da0073e9SAndroid Build Coastguard Worker 624*da0073e9SAndroid Build Coastguard Worker # Tests that the function produces the same result when called with 625*da0073e9SAndroid Build Coastguard Worker # noncontiguous tensors. 
    # TODO: get working with Windows by addressing failing operators
    # TODO: get working with ASAN by addressing failing operators
    @unittest.skipIf(IS_WINDOWS, "Skipped under Windows")
    @onlyNativeDeviceTypesAnd(["hpu"])
    @suppress_warnings
    @ops(op_db, allowed_dtypes=(torch.float32, torch.long, torch.complex64))
    def test_noncontiguous_samples(self, device, dtype, op):
        # Backward is checked only when the op lists `dtype` among its supported
        # backward dtypes for this device type. Samples are generated with
        # requires_grad accordingly.
        test_grad = dtype in op.supported_backward_dtypes(torch.device(device).type)
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=test_grad)
        for sample_input in sample_inputs:
            t_inp, t_args, t_kwargs = (
                sample_input.input,
                sample_input.args,
                sample_input.kwargs,
            )
            # Noncontiguous variant of the same sample, compared against the
            # original below.
            noncontig_sample = sample_input.noncontiguous()
            n_inp, n_args, n_kwargs = (
                noncontig_sample.input,
                noncontig_sample.args,
                noncontig_sample.kwargs,
            )

            # validates forward
            expected = op(t_inp, *t_args, **t_kwargs)
            actual = op(n_inp, *n_args, **n_kwargs)

            self.assertEqual(actual, expected)

            # Validate backward
            # Short-circuits if the op doesn't support grad in this device x dtype
            if not test_grad:
                continue

            expected = sample_input.output_process_fn_grad(expected)
            actual = sample_input.output_process_fn_grad(actual)

            # Builds upstream gradients: random for the contiguous run, and a
            # noncontiguous_like copy of the same values for the noncontiguous run.
            if isinstance(expected, torch.Tensor):
                grad_for_expected = torch.randn_like(expected)
                grad_for_actual = noncontiguous_like(grad_for_expected)
            elif isinstance(expected, Sequence):
                # Filter output elements that do not require grad
                expected = [
                    t
                    for t in expected
                    if isinstance(t, torch.Tensor) and t.requires_grad
                ]
                actual = [
                    n for n in actual if isinstance(n, torch.Tensor) and n.requires_grad
                ]
                grad_for_expected = [torch.randn_like(t) for t in expected]
                grad_for_actual = [noncontiguous_like(n) for n in grad_for_expected]
            else:
                # Nothing to do if it returns a scalar or things like that
                continue

            # Concatenate inputs into a tuple. A non-tensor `input` is assumed
            # to be iterable (e.g. a sequence of tensors) and is flattened
            # into the tuple.
            t_inputs = (
                (t_inp,) + t_args
                if isinstance(t_inp, torch.Tensor)
                else tuple(t_inp) + t_args
            )
            n_inputs = (
                (n_inp,) + n_args
                if isinstance(n_inp, torch.Tensor)
                else tuple(n_inp) + n_args
            )

            # Filter the elements that are tensors that require grad
            t_input_tensors = [
                t for t in t_inputs if isinstance(t, torch.Tensor) and t.requires_grad
            ]
            n_input_tensors = [
                n for n in n_inputs if isinstance(n, torch.Tensor) and n.requires_grad
            ]

            self.assertEqual(len(t_input_tensors), len(n_input_tensors))

            # Some functions may not use all the inputs to generate gradients. One of the
            # few examples of this "odd" behaviour is F.hinge_embedding_loss
            t_grads = torch.autograd.grad(
                expected, t_input_tensors, grad_for_expected, allow_unused=True
            )
            n_grads = torch.autograd.grad(
                actual, n_input_tensors, grad_for_actual, allow_unused=True
            )

            msg = "Got different gradients for contiguous / non-contiguous inputs wrt input {}."
            for i, (t, n) in enumerate(zip(t_grads, n_grads)):
                self.assertEqual(t, n, msg=msg.format(i))

    # Separates one case from the following test_out because many ops don't implement the
    # incorrectly sized out parameter warning properly yet
    # Cases tested here:
    # - out= with the correct dtype and device, but the wrong shape
    @ops(ops_and_refs, dtypes=OpDTypes.none)
    def test_out_warning(self, device, op):
        if TEST_WITH_TORCHDYNAMO and op.name == "_refs.clamp":
            self.skipTest("flaky")
        # Prefers running in float32 but has a fallback for the first listed supported dtype
        supported_dtypes = op.supported_dtypes(self.device_type)
        if len(supported_dtypes) == 0:
            self.skipTest("Skipped! Op has not supported dtypes on this device.")
        dtype = (
            torch.float32
            if torch.float32 in supported_dtypes
            else next(iter(supported_dtypes))
        )

        # Ops from python_ref_db point to python decomps that are potentially
        # wrapped with `torch._prims_common.wrappers.out_wrapper`. Unwrap these
        # ops before testing to avoid clashing with OpInfo.supports_out
        if not op.supports_out:
            op = copy.copy(op)
            op.op = _maybe_remove_out_wrapper(op.op)

        samples = op.sample_inputs(device, dtype)
        for sample in samples:
            # calls it normally to get the expected result
            expected = op(sample.input, *sample.args, **sample.kwargs)
            # op_out(out=...) re-invokes the op on this same sample with an out= argument
            op_out = partial(op, sample.input, *sample.args, **sample.kwargs)

            # Short-circuits if output is not a single tensor or an
            # iterable of tensors. skipTest ends the whole test, not just this sample.
            if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(
                expected, include_empty=True
            ):
                self.skipTest(
                    "Skipped! Only supports single tensor or iterable of tensor outputs."
                )

            # Validates the op doesn't support out if it claims not to.
            # The `return` ends the test after checking the first sample.
            if not op.supports_out:
                with self.assertRaises(Exception):
                    assert op_out(out=expected) != NotImplemented
                return

            # A wrapper around map that works with single tensors and always
            # instantiates the map. Used below to apply transforms to
            # single tensor and iterable tensor outputs.
            def _apply_out_transform(fn, out):
                if isinstance(out, torch.Tensor):
                    return fn(out)

                # assumes (see above) that out is an iterable of tensors
                return tuple(map(fn, out))

            # Extracts strides from a tensor or iterable of tensors into a tuple
            def _extract_strides(out):
                if isinstance(out, torch.Tensor):
                    return (out.stride(),)

                # assumes (see above) that out is an iterable of tensors
                return tuple(t.stride() for t in out)

            # Extracts data pointers from a tensor or iterable of tensors into a tuple
            # NOTE: only extracts on the CPU and CUDA device types since some
            # device types don't have storage
            def _extract_data_ptrs(out):
                if self.device_type != "cpu" and self.device_type != "cuda":
                    return ()

                if isinstance(out, torch.Tensor):
                    return (out.data_ptr(),)

                # assumes (see above) that out is an iterable of tensors
                return tuple(t.data_ptr() for t in out)

            # Builds an out= argument with `transform`, runs the op into it, and
            # checks that the values match `expected`. Optionally it also checks
            # that strides and data pointers were left unchanged. Warnings are
            # suppressed here; the resize warning is asserted separately below.
            @suppress_warnings
            def _compare_out(transform, *, compare_strides_and_data_ptrs=True):
                out = _apply_out_transform(transform, expected)
                original_strides = _extract_strides(out)
                original_ptrs = _extract_data_ptrs(out)

                op_out(out=out)
                final_strides = _extract_strides(out)
                final_ptrs = _extract_data_ptrs(out)

                self.assertEqual(expected, out)

                if compare_strides_and_data_ptrs:
                    stride_msg = (
                        f"Strides are not the same! Original strides were {original_strides} "
                        f"and strides are now {final_strides}"
                    )
                    self.assertEqual(original_strides, final_strides, msg=stride_msg)
                    self.assertEqual(original_ptrs, final_ptrs)

            # Case Zero: out= with the correct dtype and device, but the wrong shape
            # Expected behavior: if nonempty, resize with a warning.
            def _case_zero_transform(t):
                wrong_shape = list(t.shape)

                if len(wrong_shape) == 0:
                    # Handles scalar tensor case (empty list)
                    wrong_shape = [2]
                else:
                    wrong_shape[-1] = wrong_shape[-1] + 1
                return make_tensor(wrong_shape, dtype=t.dtype, device=t.device)

            # Verifies the out values are correct. Strides and pointers are not
            # compared because the wrongly shaped out= is expected to be resized.
            _compare_out(_case_zero_transform, compare_strides_and_data_ptrs=False)

            # Additionally validates that the appropriate warning is thrown if a nonempty
            # tensor is resized.
            def _any_nonempty(out):
                if isinstance(out, torch.Tensor):
                    return out.numel() > 0

                return any(x.numel() > 0 for x in out)

            out = _apply_out_transform(_case_zero_transform, expected)
            msg_fail = "Resized a non-empty tensor but did not warn about it."
            if _any_nonempty(out):
                with self.assertWarnsRegex(
                    UserWarning, "An output with one or more elements", msg=msg_fail
                ):
                    op_out(out=out)

    # Validates ops implement the correct out= behavior
    # See https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch
    # for a description of the correct behavior
    # Validates the following cases:
    # - Case 0: out has the correct shape, dtype, and device but is full of extremal values
    # - Case 1: out has the correct shape, dtype, and device but is noncontiguous
    # - Case 2: out has the correct dtype and device, but has zero elements
    # - Case 3: out has the correct shape and dtype, but is on a different device type
    # - Case 4: out has the correct shape and device, but a dtype that cannot
    #   "safely" cast to
    #
    # Case 3 and 4 are slightly different when the op is a factory function:
    # - if device, dtype are NOT passed, any combination of dtype/device should be OK for out
    # - if device, dtype are passed, device and dtype should match
    @ops(ops_and_refs, dtypes=OpDTypes.any_one)
859*da0073e9SAndroid Build Coastguard Worker def test_out(self, device, dtype, op): 860*da0073e9SAndroid Build Coastguard Worker # Prefers running in float32 but has a fallback for the first listed supported dtype 861*da0073e9SAndroid Build Coastguard Worker samples = op.sample_inputs(device, dtype) 862*da0073e9SAndroid Build Coastguard Worker 863*da0073e9SAndroid Build Coastguard Worker # Ops from python_ref_db point to python decomps that are potentially 864*da0073e9SAndroid Build Coastguard Worker # wrapped with `torch._prims_common.wrappers.out_wrapper`. Unwrap these 865*da0073e9SAndroid Build Coastguard Worker # ops before testing to avoid clashing with OpInfo.supports_out 866*da0073e9SAndroid Build Coastguard Worker if not op.supports_out: 867*da0073e9SAndroid Build Coastguard Worker op = copy.copy(op) 868*da0073e9SAndroid Build Coastguard Worker op.op = _maybe_remove_out_wrapper(op.op) 869*da0073e9SAndroid Build Coastguard Worker 870*da0073e9SAndroid Build Coastguard Worker for sample in samples: 871*da0073e9SAndroid Build Coastguard Worker # calls it normally to get the expected result 872*da0073e9SAndroid Build Coastguard Worker expected = op(sample.input, *sample.args, **sample.kwargs) 873*da0073e9SAndroid Build Coastguard Worker op_out = partial(op, sample.input, *sample.args, **sample.kwargs) 874*da0073e9SAndroid Build Coastguard Worker 875*da0073e9SAndroid Build Coastguard Worker # Short-circuits if output is not a single tensor or an 876*da0073e9SAndroid Build Coastguard Worker # iterable of tensors 877*da0073e9SAndroid Build Coastguard Worker if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors( 878*da0073e9SAndroid Build Coastguard Worker expected, include_empty=True 879*da0073e9SAndroid Build Coastguard Worker ): 880*da0073e9SAndroid Build Coastguard Worker self.skipTest( 881*da0073e9SAndroid Build Coastguard Worker "Skipped! Only supports single tensor or iterable of tensor outputs." 
882*da0073e9SAndroid Build Coastguard Worker ) 883*da0073e9SAndroid Build Coastguard Worker 884*da0073e9SAndroid Build Coastguard Worker # Validates the op doesn't support out if it claims not to 885*da0073e9SAndroid Build Coastguard Worker if not op.supports_out: 886*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(Exception): 887*da0073e9SAndroid Build Coastguard Worker assert op_out(out=expected) != NotImplemented 888*da0073e9SAndroid Build Coastguard Worker return 889*da0073e9SAndroid Build Coastguard Worker 890*da0073e9SAndroid Build Coastguard Worker # A wrapper around map that works with single tensors and always 891*da0073e9SAndroid Build Coastguard Worker # instantiates the map. Used below to apply transforms to 892*da0073e9SAndroid Build Coastguard Worker # single tensor and iterable tensor outputs. 893*da0073e9SAndroid Build Coastguard Worker def _apply_out_transform(fn, out): 894*da0073e9SAndroid Build Coastguard Worker if isinstance(out, torch.Tensor): 895*da0073e9SAndroid Build Coastguard Worker return fn(out) 896*da0073e9SAndroid Build Coastguard Worker 897*da0073e9SAndroid Build Coastguard Worker # assumes (see above) that out is an iterable of tensors 898*da0073e9SAndroid Build Coastguard Worker return tuple(map(fn, out)) 899*da0073e9SAndroid Build Coastguard Worker 900*da0073e9SAndroid Build Coastguard Worker # Extracts strides from a tensor or iterable of tensors into a tuple 901*da0073e9SAndroid Build Coastguard Worker def _extract_strides(out): 902*da0073e9SAndroid Build Coastguard Worker if isinstance(out, torch.Tensor): 903*da0073e9SAndroid Build Coastguard Worker return (out.stride(),) 904*da0073e9SAndroid Build Coastguard Worker 905*da0073e9SAndroid Build Coastguard Worker # assumes (see above) that out is an iterable of tensors 906*da0073e9SAndroid Build Coastguard Worker return tuple(t.stride() for t in out) 907*da0073e9SAndroid Build Coastguard Worker 908*da0073e9SAndroid Build Coastguard Worker # Extracts data pointers 
from a tensor or iterable of tensors into a tuple 909*da0073e9SAndroid Build Coastguard Worker # NOTE: only extracts on the CPU and CUDA device types since some 910*da0073e9SAndroid Build Coastguard Worker # device types don't have storage 911*da0073e9SAndroid Build Coastguard Worker def _extract_data_ptrs(out): 912*da0073e9SAndroid Build Coastguard Worker if self.device_type != "cpu" and self.device_type != "cuda": 913*da0073e9SAndroid Build Coastguard Worker return () 914*da0073e9SAndroid Build Coastguard Worker 915*da0073e9SAndroid Build Coastguard Worker if isinstance(out, torch.Tensor): 916*da0073e9SAndroid Build Coastguard Worker return (out.data_ptr(),) 917*da0073e9SAndroid Build Coastguard Worker 918*da0073e9SAndroid Build Coastguard Worker # assumes (see above) that out is an iterable of tensors 919*da0073e9SAndroid Build Coastguard Worker return tuple(t.data_ptr() for t in out) 920*da0073e9SAndroid Build Coastguard Worker 921*da0073e9SAndroid Build Coastguard Worker def _compare_out(transform, *, compare_strides_and_data_ptrs=True): 922*da0073e9SAndroid Build Coastguard Worker out = _apply_out_transform(transform, expected) 923*da0073e9SAndroid Build Coastguard Worker original_strides = _extract_strides(out) 924*da0073e9SAndroid Build Coastguard Worker original_ptrs = _extract_data_ptrs(out) 925*da0073e9SAndroid Build Coastguard Worker 926*da0073e9SAndroid Build Coastguard Worker op_out(out=out) 927*da0073e9SAndroid Build Coastguard Worker final_strides = _extract_strides(out) 928*da0073e9SAndroid Build Coastguard Worker final_ptrs = _extract_data_ptrs(out) 929*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, out) 930*da0073e9SAndroid Build Coastguard Worker 931*da0073e9SAndroid Build Coastguard Worker if compare_strides_and_data_ptrs: 932*da0073e9SAndroid Build Coastguard Worker stride_msg = ( 933*da0073e9SAndroid Build Coastguard Worker "Strides are not the same! 
" 934*da0073e9SAndroid Build Coastguard Worker f"Original strides were {original_strides} and strides are now {final_strides}" 935*da0073e9SAndroid Build Coastguard Worker ) 936*da0073e9SAndroid Build Coastguard Worker self.assertEqual(original_strides, final_strides, msg=stride_msg) 937*da0073e9SAndroid Build Coastguard Worker self.assertEqual(original_ptrs, final_ptrs) 938*da0073e9SAndroid Build Coastguard Worker 939*da0073e9SAndroid Build Coastguard Worker # Case 0: out= with the correct shape, dtype, and device 940*da0073e9SAndroid Build Coastguard Worker # but NaN values for floating point and complex tensors, and 941*da0073e9SAndroid Build Coastguard Worker # maximum values for integer tensors. 942*da0073e9SAndroid Build Coastguard Worker # Expected behavior: out= values have no effect on the computation. 943*da0073e9SAndroid Build Coastguard Worker def _case_zero_transform(t): 944*da0073e9SAndroid Build Coastguard Worker try: 945*da0073e9SAndroid Build Coastguard Worker info = torch.iinfo(t.dtype) 946*da0073e9SAndroid Build Coastguard Worker return torch.full_like(t, info.max) 947*da0073e9SAndroid Build Coastguard Worker except TypeError as te: 948*da0073e9SAndroid Build Coastguard Worker # for non-integer types fills with NaN 949*da0073e9SAndroid Build Coastguard Worker return torch.full_like(t, float("nan")) 950*da0073e9SAndroid Build Coastguard Worker 951*da0073e9SAndroid Build Coastguard Worker _compare_out(_case_zero_transform) 952*da0073e9SAndroid Build Coastguard Worker 953*da0073e9SAndroid Build Coastguard Worker # Case 1: out= with the correct shape, dtype, and device, 954*da0073e9SAndroid Build Coastguard Worker # but noncontiguous. 955*da0073e9SAndroid Build Coastguard Worker # Expected behavior: strides are respected and `out` storage is not changed. 
956*da0073e9SAndroid Build Coastguard Worker def _case_one_transform(t): 957*da0073e9SAndroid Build Coastguard Worker return make_tensor( 958*da0073e9SAndroid Build Coastguard Worker t.shape, dtype=t.dtype, device=t.device, noncontiguous=True 959*da0073e9SAndroid Build Coastguard Worker ) 960*da0073e9SAndroid Build Coastguard Worker 961*da0073e9SAndroid Build Coastguard Worker _compare_out(_case_one_transform) 962*da0073e9SAndroid Build Coastguard Worker 963*da0073e9SAndroid Build Coastguard Worker # Case 2: out= with the correct dtype and device, but has no elements. 964*da0073e9SAndroid Build Coastguard Worker # Expected behavior: resize without warning. 965*da0073e9SAndroid Build Coastguard Worker def _case_two_transform(t): 966*da0073e9SAndroid Build Coastguard Worker return make_tensor((0,), dtype=t.dtype, device=t.device) 967*da0073e9SAndroid Build Coastguard Worker 968*da0073e9SAndroid Build Coastguard Worker _compare_out(_case_two_transform, compare_strides_and_data_ptrs=False) 969*da0073e9SAndroid Build Coastguard Worker 970*da0073e9SAndroid Build Coastguard Worker # Also validates that no warning is thrown when this out is resized 971*da0073e9SAndroid Build Coastguard Worker out = _apply_out_transform(_case_two_transform, expected) 972*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as caught: 973*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("always") 974*da0073e9SAndroid Build Coastguard Worker op_out(out=out) 975*da0073e9SAndroid Build Coastguard Worker 976*da0073e9SAndroid Build Coastguard Worker # Verifies no warning is a resize warning 977*da0073e9SAndroid Build Coastguard Worker for w in caught: 978*da0073e9SAndroid Build Coastguard Worker if "An output with one or more elements" in str(w.message): 979*da0073e9SAndroid Build Coastguard Worker self.fail( 980*da0073e9SAndroid Build Coastguard Worker "Resizing an out= argument with no elements threw a resize warning!" 
981*da0073e9SAndroid Build Coastguard Worker ) 982*da0073e9SAndroid Build Coastguard Worker 983*da0073e9SAndroid Build Coastguard Worker # Case 3: out= with correct shape and dtype, but wrong device. 984*da0073e9SAndroid Build Coastguard Worker wrong_device = None 985*da0073e9SAndroid Build Coastguard Worker if torch.device(device).type != "cpu": 986*da0073e9SAndroid Build Coastguard Worker wrong_device = "cpu" 987*da0073e9SAndroid Build Coastguard Worker elif torch.cuda.is_available(): 988*da0073e9SAndroid Build Coastguard Worker wrong_device = "cuda" 989*da0073e9SAndroid Build Coastguard Worker 990*da0073e9SAndroid Build Coastguard Worker factory_fn_msg = ( 991*da0073e9SAndroid Build Coastguard Worker "\n\nNOTE: If your op is a factory function (i.e., it accepts TensorOptions) you should mark its " 992*da0073e9SAndroid Build Coastguard Worker "OpInfo with `is_factory_function=True`." 993*da0073e9SAndroid Build Coastguard Worker ) 994*da0073e9SAndroid Build Coastguard Worker if wrong_device is not None: 995*da0073e9SAndroid Build Coastguard Worker 996*da0073e9SAndroid Build Coastguard Worker def _case_three_transform(t): 997*da0073e9SAndroid Build Coastguard Worker return make_tensor(t.shape, dtype=t.dtype, device=wrong_device) 998*da0073e9SAndroid Build Coastguard Worker 999*da0073e9SAndroid Build Coastguard Worker out = _apply_out_transform(_case_three_transform, expected) 1000*da0073e9SAndroid Build Coastguard Worker 1001*da0073e9SAndroid Build Coastguard Worker if op.is_factory_function and sample.kwargs.get("device", None) is None: 1002*da0073e9SAndroid Build Coastguard Worker op_out(out=out) 1003*da0073e9SAndroid Build Coastguard Worker else: 1004*da0073e9SAndroid Build Coastguard Worker msg_fail = ( 1005*da0073e9SAndroid Build Coastguard Worker f"Expected RuntimeError when calling with input.device={device} and out.device={wrong_device}." 
1006*da0073e9SAndroid Build Coastguard Worker ) + factory_fn_msg 1007*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError, msg=msg_fail): 1008*da0073e9SAndroid Build Coastguard Worker op_out(out=out) 1009*da0073e9SAndroid Build Coastguard Worker 1010*da0073e9SAndroid Build Coastguard Worker # Case 4: out= with correct shape and device, but a dtype 1011*da0073e9SAndroid Build Coastguard Worker # that output cannot be "safely" cast to (long). 1012*da0073e9SAndroid Build Coastguard Worker # Expected behavior: error. 1013*da0073e9SAndroid Build Coastguard Worker # NOTE: this case is filtered by dtype since some ops produce 1014*da0073e9SAndroid Build Coastguard Worker # bool tensors, for example, which can be safely cast to any 1015*da0073e9SAndroid Build Coastguard Worker # dtype. It is applied when single tensors are floating point or complex 1016*da0073e9SAndroid Build Coastguard Worker # dtypes, or if an op returns multiple tensors when at least one such 1017*da0073e9SAndroid Build Coastguard Worker # tensor is a floating point or complex dtype. 
1018*da0073e9SAndroid Build Coastguard Worker _dtypes = floating_and_complex_types_and(torch.float16, torch.bfloat16) 1019*da0073e9SAndroid Build Coastguard Worker if ( 1020*da0073e9SAndroid Build Coastguard Worker isinstance(expected, torch.Tensor) 1021*da0073e9SAndroid Build Coastguard Worker and expected.dtype in _dtypes 1022*da0073e9SAndroid Build Coastguard Worker or ( 1023*da0073e9SAndroid Build Coastguard Worker not isinstance(expected, torch.Tensor) 1024*da0073e9SAndroid Build Coastguard Worker and any(t.dtype in _dtypes for t in expected) 1025*da0073e9SAndroid Build Coastguard Worker ) 1026*da0073e9SAndroid Build Coastguard Worker ): 1027*da0073e9SAndroid Build Coastguard Worker 1028*da0073e9SAndroid Build Coastguard Worker def _case_four_transform(t): 1029*da0073e9SAndroid Build Coastguard Worker return make_tensor(t.shape, dtype=torch.long, device=t.device) 1030*da0073e9SAndroid Build Coastguard Worker 1031*da0073e9SAndroid Build Coastguard Worker out = _apply_out_transform(_case_four_transform, expected) 1032*da0073e9SAndroid Build Coastguard Worker msg_fail = "Expected RuntimeError when doing an unsafe cast!" 
1033*da0073e9SAndroid Build Coastguard Worker msg_fail = ( 1034*da0073e9SAndroid Build Coastguard Worker msg_fail 1035*da0073e9SAndroid Build Coastguard Worker if not isinstance(expected, torch.Tensor) 1036*da0073e9SAndroid Build Coastguard Worker else ( 1037*da0073e9SAndroid Build Coastguard Worker "Expected RuntimeError when doing an unsafe cast from a result of dtype " 1038*da0073e9SAndroid Build Coastguard Worker f"{expected.dtype} into an out= with dtype torch.long" 1039*da0073e9SAndroid Build Coastguard Worker ) 1040*da0073e9SAndroid Build Coastguard Worker ) + factory_fn_msg 1041*da0073e9SAndroid Build Coastguard Worker 1042*da0073e9SAndroid Build Coastguard Worker if op.is_factory_function and sample.kwargs.get("dtype", None) is None: 1043*da0073e9SAndroid Build Coastguard Worker op_out(out=out) 1044*da0073e9SAndroid Build Coastguard Worker else: 1045*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError, msg=msg_fail): 1046*da0073e9SAndroid Build Coastguard Worker op_out(out=out) 1047*da0073e9SAndroid Build Coastguard Worker 1048*da0073e9SAndroid Build Coastguard Worker @ops( 1049*da0073e9SAndroid Build Coastguard Worker [ 1050*da0073e9SAndroid Build Coastguard Worker op 1051*da0073e9SAndroid Build Coastguard Worker for op in op_db 1052*da0073e9SAndroid Build Coastguard Worker if op.supports_out and (op.supports_autograd or op.is_factory_function) 1053*da0073e9SAndroid Build Coastguard Worker ], 1054*da0073e9SAndroid Build Coastguard Worker dtypes=OpDTypes.supported, 1055*da0073e9SAndroid Build Coastguard Worker allowed_dtypes=[torch.float, torch.cfloat], 1056*da0073e9SAndroid Build Coastguard Worker ) 1057*da0073e9SAndroid Build Coastguard Worker def test_out_requires_grad_error(self, device, dtype, op): 1058*da0073e9SAndroid Build Coastguard Worker sample = first_sample(self, op.sample_inputs(device, dtype)) 1059*da0073e9SAndroid Build Coastguard Worker 1060*da0073e9SAndroid Build Coastguard Worker # Call op to get prototype for 
out arguments 1061*da0073e9SAndroid Build Coastguard Worker expect = op(sample.input, *sample.args, **sample.kwargs) 1062*da0073e9SAndroid Build Coastguard Worker any_requires_grad = False 1063*da0073e9SAndroid Build Coastguard Worker 1064*da0073e9SAndroid Build Coastguard Worker def set_requires_grad(x): 1065*da0073e9SAndroid Build Coastguard Worker nonlocal any_requires_grad 1066*da0073e9SAndroid Build Coastguard Worker if isinstance(x, torch.Tensor) and ( 1067*da0073e9SAndroid Build Coastguard Worker x.is_floating_point() or x.is_complex() 1068*da0073e9SAndroid Build Coastguard Worker ): 1069*da0073e9SAndroid Build Coastguard Worker any_requires_grad = True 1070*da0073e9SAndroid Build Coastguard Worker x.requires_grad_(True) 1071*da0073e9SAndroid Build Coastguard Worker return x 1072*da0073e9SAndroid Build Coastguard Worker 1073*da0073e9SAndroid Build Coastguard Worker out = pytree.tree_map_(set_requires_grad, expect) 1074*da0073e9SAndroid Build Coastguard Worker if not any_requires_grad: 1075*da0073e9SAndroid Build Coastguard Worker # Skip ops without any floating point outputs, e.g. isnan 1076*da0073e9SAndroid Build Coastguard Worker return 1077*da0073e9SAndroid Build Coastguard Worker 1078*da0073e9SAndroid Build Coastguard Worker msg = ( 1079*da0073e9SAndroid Build Coastguard Worker "functions with out=... arguments don't support automatic " 1080*da0073e9SAndroid Build Coastguard Worker "differentiation, but one of the arguments requires grad." 
1081*da0073e9SAndroid Build Coastguard Worker ) 1082*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError, msg=msg): 1083*da0073e9SAndroid Build Coastguard Worker op(sample.input, *sample.args, **sample.kwargs, out=out) 1084*da0073e9SAndroid Build Coastguard Worker 1085*da0073e9SAndroid Build Coastguard Worker @ops(filter(reduction_dtype_filter, ops_and_refs), dtypes=(torch.int16,)) 1086*da0073e9SAndroid Build Coastguard Worker def test_out_integral_dtype(self, device, dtype, op): 1087*da0073e9SAndroid Build Coastguard Worker def helper(with_out, expectFail, op_to_test, inputs, *args, **kwargs): 1088*da0073e9SAndroid Build Coastguard Worker out = None 1089*da0073e9SAndroid Build Coastguard Worker try: 1090*da0073e9SAndroid Build Coastguard Worker if with_out: 1091*da0073e9SAndroid Build Coastguard Worker out = torch.empty(0, dtype=torch.int32, device=device) 1092*da0073e9SAndroid Build Coastguard Worker op_to_test(inputs, *args, out=out, **kwargs) 1093*da0073e9SAndroid Build Coastguard Worker else: 1094*da0073e9SAndroid Build Coastguard Worker out = op_to_test(inputs, *args, **kwargs) 1095*da0073e9SAndroid Build Coastguard Worker self.assertFalse(expectFail) 1096*da0073e9SAndroid Build Coastguard Worker except RuntimeError as err: 1097*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1098*da0073e9SAndroid Build Coastguard Worker str(err), "dtype argument and out dtype must match in reduction" 1099*da0073e9SAndroid Build Coastguard Worker ) 1100*da0073e9SAndroid Build Coastguard Worker self.assertTrue(expectFail) 1101*da0073e9SAndroid Build Coastguard Worker return out 1102*da0073e9SAndroid Build Coastguard Worker 1103*da0073e9SAndroid Build Coastguard Worker samples = op.sample_inputs(device, dtype) 1104*da0073e9SAndroid Build Coastguard Worker for sample in samples: 1105*da0073e9SAndroid Build Coastguard Worker if "dtype" not in sample.kwargs: 1106*da0073e9SAndroid Build Coastguard Worker helper(False, False, op, sample.input, 
*sample.args, **sample.kwargs) 1107*da0073e9SAndroid Build Coastguard Worker helper(True, False, op, sample.input, *sample.args, **sample.kwargs) 1108*da0073e9SAndroid Build Coastguard Worker sample.kwargs["dtype"] = torch.int16 1109*da0073e9SAndroid Build Coastguard Worker helper(False, False, op, sample.input, *sample.args, **sample.kwargs) 1110*da0073e9SAndroid Build Coastguard Worker helper(True, True, op, sample.input, *sample.args, **sample.kwargs) 1111*da0073e9SAndroid Build Coastguard Worker sample.kwargs["dtype"] = torch.int32 1112*da0073e9SAndroid Build Coastguard Worker helper(False, False, op, sample.input, *sample.args, **sample.kwargs) 1113*da0073e9SAndroid Build Coastguard Worker helper(True, False, op, sample.input, *sample.args, **sample.kwargs) 1114*da0073e9SAndroid Build Coastguard Worker else: 1115*da0073e9SAndroid Build Coastguard Worker helper(False, False, op, sample.input, *sample.args, **sample.kwargs) 1116*da0073e9SAndroid Build Coastguard Worker helper( 1117*da0073e9SAndroid Build Coastguard Worker True, 1118*da0073e9SAndroid Build Coastguard Worker sample.kwargs["dtype"] != torch.int32, 1119*da0073e9SAndroid Build Coastguard Worker op, 1120*da0073e9SAndroid Build Coastguard Worker sample.input, 1121*da0073e9SAndroid Build Coastguard Worker *sample.args, 1122*da0073e9SAndroid Build Coastguard Worker **sample.kwargs, 1123*da0073e9SAndroid Build Coastguard Worker ) 1124*da0073e9SAndroid Build Coastguard Worker 1125*da0073e9SAndroid Build Coastguard Worker # Tests that the forward and backward passes of operations produce the 1126*da0073e9SAndroid Build Coastguard Worker # same values for the cross-product of op variants (method, inplace) 1127*da0073e9SAndroid Build Coastguard Worker # against eager's gold standard op function variant 1128*da0073e9SAndroid Build Coastguard Worker @_variant_ops(op_db) 1129*da0073e9SAndroid Build Coastguard Worker def test_variant_consistency_eager(self, device, dtype, op): 1130*da0073e9SAndroid Build 
Coastguard Worker # Acquires variants (method variant, inplace variant, operator variant, inplace_operator variant, aliases) 1131*da0073e9SAndroid Build Coastguard Worker 1132*da0073e9SAndroid Build Coastguard Worker method = op.method_variant 1133*da0073e9SAndroid Build Coastguard Worker inplace = op.inplace_variant 1134*da0073e9SAndroid Build Coastguard Worker operator = op.operator_variant 1135*da0073e9SAndroid Build Coastguard Worker inplace_operator = op.inplace_operator_variant 1136*da0073e9SAndroid Build Coastguard Worker 1137*da0073e9SAndroid Build Coastguard Worker # list of all inplace ops: inplace variant + alias inplace variants if exist 1138*da0073e9SAndroid Build Coastguard Worker inplace_ops = [inplace, inplace_operator] 1139*da0073e9SAndroid Build Coastguard Worker variants = [method, inplace, operator, inplace_operator] 1140*da0073e9SAndroid Build Coastguard Worker operators = [operator, inplace_operator] 1141*da0073e9SAndroid Build Coastguard Worker 1142*da0073e9SAndroid Build Coastguard Worker for a_op in op.aliases: 1143*da0073e9SAndroid Build Coastguard Worker variants.append(a_op.op) 1144*da0073e9SAndroid Build Coastguard Worker variants.append(a_op.method_variant) 1145*da0073e9SAndroid Build Coastguard Worker variants.append(a_op.inplace_variant) 1146*da0073e9SAndroid Build Coastguard Worker inplace_ops.append(a_op.inplace_variant) 1147*da0073e9SAndroid Build Coastguard Worker 1148*da0073e9SAndroid Build Coastguard Worker inplace_variants = tuple(filter(None, inplace_ops)) 1149*da0073e9SAndroid Build Coastguard Worker variants = tuple(filter(None, variants)) 1150*da0073e9SAndroid Build Coastguard Worker operators = tuple(filter(None, operators)) 1151*da0073e9SAndroid Build Coastguard Worker 1152*da0073e9SAndroid Build Coastguard Worker _requires_grad = dtype in op.supported_backward_dtypes( 1153*da0073e9SAndroid Build Coastguard Worker torch.device(device).type 1154*da0073e9SAndroid Build Coastguard Worker ) 1155*da0073e9SAndroid Build 
Coastguard Worker 1156*da0073e9SAndroid Build Coastguard Worker include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex 1157*da0073e9SAndroid Build Coastguard Worker samples = op.sample_inputs( 1158*da0073e9SAndroid Build Coastguard Worker device, 1159*da0073e9SAndroid Build Coastguard Worker dtype, 1160*da0073e9SAndroid Build Coastguard Worker requires_grad=_requires_grad, 1161*da0073e9SAndroid Build Coastguard Worker include_conjugated_inputs=include_conjugated_inputs, 1162*da0073e9SAndroid Build Coastguard Worker ) 1163*da0073e9SAndroid Build Coastguard Worker samples = list(samples) 1164*da0073e9SAndroid Build Coastguard Worker 1165*da0073e9SAndroid Build Coastguard Worker def _test_consistency_helper(samples, variants): 1166*da0073e9SAndroid Build Coastguard Worker for sample in samples: 1167*da0073e9SAndroid Build Coastguard Worker # TODO: Check grad for all Tensors requiring grad if sample.input is TensorList 1168*da0073e9SAndroid Build Coastguard Worker tensor = ( 1169*da0073e9SAndroid Build Coastguard Worker sample.input 1170*da0073e9SAndroid Build Coastguard Worker if isinstance(sample.input, torch.Tensor) 1171*da0073e9SAndroid Build Coastguard Worker else sample.input[0] 1172*da0073e9SAndroid Build Coastguard Worker ) 1173*da0073e9SAndroid Build Coastguard Worker 1174*da0073e9SAndroid Build Coastguard Worker # Computes function forward and backward values 1175*da0073e9SAndroid Build Coastguard Worker tensor.grad = None 1176*da0073e9SAndroid Build Coastguard Worker expected_forward = op(sample.input, *sample.args, **sample.kwargs) 1177*da0073e9SAndroid Build Coastguard Worker expected_grad = None 1178*da0073e9SAndroid Build Coastguard Worker 1179*da0073e9SAndroid Build Coastguard Worker output_process_fn_grad = ( 1180*da0073e9SAndroid Build Coastguard Worker sample.output_process_fn_grad 1181*da0073e9SAndroid Build Coastguard Worker if sample.output_process_fn_grad 1182*da0073e9SAndroid Build Coastguard Worker else lambda x: x 
1183*da0073e9SAndroid Build Coastguard Worker ) 1184*da0073e9SAndroid Build Coastguard Worker 1185*da0073e9SAndroid Build Coastguard Worker # Skips inplace variants if the output dtype is not the same as 1186*da0073e9SAndroid Build Coastguard Worker # the input dtype 1187*da0073e9SAndroid Build Coastguard Worker skip_inplace = False 1188*da0073e9SAndroid Build Coastguard Worker if ( 1189*da0073e9SAndroid Build Coastguard Worker isinstance(expected_forward, torch.Tensor) 1190*da0073e9SAndroid Build Coastguard Worker and expected_forward.dtype is not tensor.dtype 1191*da0073e9SAndroid Build Coastguard Worker ): 1192*da0073e9SAndroid Build Coastguard Worker skip_inplace = True 1193*da0073e9SAndroid Build Coastguard Worker 1194*da0073e9SAndroid Build Coastguard Worker # TODO: backward consistency only supported for single tensor outputs 1195*da0073e9SAndroid Build Coastguard Worker # TODO: backward consistency only checked on sample.input, not all 1196*da0073e9SAndroid Build Coastguard Worker # tensor inputs 1197*da0073e9SAndroid Build Coastguard Worker # TODO: update to handle checking grads of all tensor inputs as 1198*da0073e9SAndroid Build Coastguard Worker # derived from each tensor output 1199*da0073e9SAndroid Build Coastguard Worker if isinstance( 1200*da0073e9SAndroid Build Coastguard Worker expected_forward, torch.Tensor 1201*da0073e9SAndroid Build Coastguard Worker ) and dtype in op.supported_backward_dtypes(torch.device(device).type): 1202*da0073e9SAndroid Build Coastguard Worker out = output_process_fn_grad(expected_forward).sum() 1203*da0073e9SAndroid Build Coastguard Worker if out.dtype.is_complex: 1204*da0073e9SAndroid Build Coastguard Worker out = out.abs() 1205*da0073e9SAndroid Build Coastguard Worker out.backward() 1206*da0073e9SAndroid Build Coastguard Worker expected_grad = tensor.grad 1207*da0073e9SAndroid Build Coastguard Worker 1208*da0073e9SAndroid Build Coastguard Worker # Test eager consistency 1209*da0073e9SAndroid Build Coastguard Worker for 
variant in variants: 1210*da0073e9SAndroid Build Coastguard Worker # Skips inplace ops 1211*da0073e9SAndroid Build Coastguard Worker if variant in inplace_ops and skip_inplace: 1212*da0073e9SAndroid Build Coastguard Worker continue 1213*da0073e9SAndroid Build Coastguard Worker 1214*da0073e9SAndroid Build Coastguard Worker # Compares variant's forward 1215*da0073e9SAndroid Build Coastguard Worker # Note: copies the to-be-modified input when testing the inplace variant 1216*da0073e9SAndroid Build Coastguard Worker tensor.grad = None 1217*da0073e9SAndroid Build Coastguard Worker cloned = ( 1218*da0073e9SAndroid Build Coastguard Worker clone_input_helper(sample.input) 1219*da0073e9SAndroid Build Coastguard Worker if variant in inplace_ops 1220*da0073e9SAndroid Build Coastguard Worker else sample.input 1221*da0073e9SAndroid Build Coastguard Worker ) 1222*da0073e9SAndroid Build Coastguard Worker 1223*da0073e9SAndroid Build Coastguard Worker if variant in inplace_ops and sample.broadcasts_input: 1224*da0073e9SAndroid Build Coastguard Worker with self.assertRaises( 1225*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1226*da0073e9SAndroid Build Coastguard Worker msg=( 1227*da0073e9SAndroid Build Coastguard Worker "inplace variant either incorrectly allowed " 1228*da0073e9SAndroid Build Coastguard Worker f"resizing or you have marked the sample {sample.summary()}" 1229*da0073e9SAndroid Build Coastguard Worker " incorrectly with `broadcasts_self=True" 1230*da0073e9SAndroid Build Coastguard Worker ), 1231*da0073e9SAndroid Build Coastguard Worker ): 1232*da0073e9SAndroid Build Coastguard Worker variant_forward = variant( 1233*da0073e9SAndroid Build Coastguard Worker cloned, *sample.args, **sample.kwargs 1234*da0073e9SAndroid Build Coastguard Worker ) 1235*da0073e9SAndroid Build Coastguard Worker continue 1236*da0073e9SAndroid Build Coastguard Worker 1237*da0073e9SAndroid Build Coastguard Worker if variant in operators and sample.kwargs: 1238*da0073e9SAndroid Build 
Coastguard Worker # skip samples with kwargs for operator variants 1239*da0073e9SAndroid Build Coastguard Worker continue 1240*da0073e9SAndroid Build Coastguard Worker 1241*da0073e9SAndroid Build Coastguard Worker variant_forward = variant(cloned, *sample.args, **sample.kwargs) 1242*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_forward, variant_forward) 1243*da0073e9SAndroid Build Coastguard Worker 1244*da0073e9SAndroid Build Coastguard Worker # Compares variant's backward 1245*da0073e9SAndroid Build Coastguard Worker if expected_grad is not None and ( 1246*da0073e9SAndroid Build Coastguard Worker variant not in inplace_ops or op.supports_inplace_autograd 1247*da0073e9SAndroid Build Coastguard Worker ): 1248*da0073e9SAndroid Build Coastguard Worker out = output_process_fn_grad(variant_forward).sum() 1249*da0073e9SAndroid Build Coastguard Worker if out.dtype.is_complex: 1250*da0073e9SAndroid Build Coastguard Worker out = out.abs() 1251*da0073e9SAndroid Build Coastguard Worker out.backward() 1252*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_grad, tensor.grad) 1253*da0073e9SAndroid Build Coastguard Worker 1254*da0073e9SAndroid Build Coastguard Worker _test_consistency_helper(samples, variants) 1255*da0073e9SAndroid Build Coastguard Worker 1256*da0073e9SAndroid Build Coastguard Worker def _test_inplace_preserve_storage(samples, variants): 1257*da0073e9SAndroid Build Coastguard Worker for sample in samples: 1258*da0073e9SAndroid Build Coastguard Worker # Skips inplace variants if the output dtype is not the same as 1259*da0073e9SAndroid Build Coastguard Worker # the input dtype 1260*da0073e9SAndroid Build Coastguard Worker expected_forward = op(sample.input, *sample.args, **sample.kwargs) 1261*da0073e9SAndroid Build Coastguard Worker tensor = ( 1262*da0073e9SAndroid Build Coastguard Worker sample.input 1263*da0073e9SAndroid Build Coastguard Worker if isinstance(sample.input, torch.Tensor) 1264*da0073e9SAndroid Build Coastguard 
Worker else sample.input[0] 1265*da0073e9SAndroid Build Coastguard Worker ) 1266*da0073e9SAndroid Build Coastguard Worker skip_inplace = False 1267*da0073e9SAndroid Build Coastguard Worker if ( 1268*da0073e9SAndroid Build Coastguard Worker isinstance(expected_forward, torch.Tensor) 1269*da0073e9SAndroid Build Coastguard Worker and expected_forward.dtype is not tensor.dtype 1270*da0073e9SAndroid Build Coastguard Worker ): 1271*da0073e9SAndroid Build Coastguard Worker skip_inplace = True 1272*da0073e9SAndroid Build Coastguard Worker if skip_inplace: 1273*da0073e9SAndroid Build Coastguard Worker return 1274*da0073e9SAndroid Build Coastguard Worker for variant in variants: 1275*da0073e9SAndroid Build Coastguard Worker cloned = ( 1276*da0073e9SAndroid Build Coastguard Worker clone_input_helper(sample.input) 1277*da0073e9SAndroid Build Coastguard Worker if variant in inplace_ops 1278*da0073e9SAndroid Build Coastguard Worker else sample.input 1279*da0073e9SAndroid Build Coastguard Worker ) 1280*da0073e9SAndroid Build Coastguard Worker inp_tensor = ( 1281*da0073e9SAndroid Build Coastguard Worker cloned if isinstance(cloned, torch.Tensor) else cloned[0] 1282*da0073e9SAndroid Build Coastguard Worker ) 1283*da0073e9SAndroid Build Coastguard Worker data_ptr = inp_tensor.data_ptr() 1284*da0073e9SAndroid Build Coastguard Worker if variant in operators and sample.kwargs: 1285*da0073e9SAndroid Build Coastguard Worker # skip samples with kwargs for operator variants 1286*da0073e9SAndroid Build Coastguard Worker continue 1287*da0073e9SAndroid Build Coastguard Worker 1288*da0073e9SAndroid Build Coastguard Worker variant_forward = variant(cloned, *sample.args, **sample.kwargs) 1289*da0073e9SAndroid Build Coastguard Worker # TODO Support non-tensor outputs if they exist for inplace ops 1290*da0073e9SAndroid Build Coastguard Worker if isinstance(variant_forward, torch.Tensor): 1291*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1292*da0073e9SAndroid Build Coastguard Worker 
data_ptr, variant_forward.data_ptr(), atol=0, rtol=0 1293*da0073e9SAndroid Build Coastguard Worker ) 1294*da0073e9SAndroid Build Coastguard Worker else: 1295*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 1296*da0073e9SAndroid Build Coastguard Worker False, 1297*da0073e9SAndroid Build Coastguard Worker "Non-tensor outputs for inplace ops are not supported", 1298*da0073e9SAndroid Build Coastguard Worker ) 1299*da0073e9SAndroid Build Coastguard Worker 1300*da0073e9SAndroid Build Coastguard Worker if len(inplace_ops) > 0: 1301*da0073e9SAndroid Build Coastguard Worker inplace_samples = list( 1302*da0073e9SAndroid Build Coastguard Worker filter(lambda sample: not sample.broadcasts_input, samples) 1303*da0073e9SAndroid Build Coastguard Worker ) 1304*da0073e9SAndroid Build Coastguard Worker _test_inplace_preserve_storage(inplace_samples, inplace_variants) 1305*da0073e9SAndroid Build Coastguard Worker 1306*da0073e9SAndroid Build Coastguard Worker # Reference testing for operations in complex32 against complex64. 1307*da0073e9SAndroid Build Coastguard Worker # NOTE: We test against complex64 as NumPy doesn't have a complex32 equivalent dtype. 1308*da0073e9SAndroid Build Coastguard Worker @ops(op_db, allowed_dtypes=(torch.complex32,)) 1309*da0073e9SAndroid Build Coastguard Worker def test_complex_half_reference_testing(self, device, dtype, op): 1310*da0073e9SAndroid Build Coastguard Worker if not op.supports_dtype(torch.complex32, device): 1311*da0073e9SAndroid Build Coastguard Worker unittest.skip("Does not support complex32") 1312*da0073e9SAndroid Build Coastguard Worker 1313*da0073e9SAndroid Build Coastguard Worker for sample in op.sample_inputs(device, dtype): 1314*da0073e9SAndroid Build Coastguard Worker actual = op(sample.input, *sample.args, **sample.kwargs) 1315*da0073e9SAndroid Build Coastguard Worker # sample.transform applies the lambda to torch.Tensor and torch.dtype. 
1316*da0073e9SAndroid Build Coastguard Worker # However, we only want to apply it to Tensors with dtype `torch.complex32`.. 1317*da0073e9SAndroid Build Coastguard Worker transformed_sample = sample.transform( 1318*da0073e9SAndroid Build Coastguard Worker lambda x: x.to(torch.complex64) 1319*da0073e9SAndroid Build Coastguard Worker if isinstance(x, torch.Tensor) and x.dtype is torch.complex32 1320*da0073e9SAndroid Build Coastguard Worker else x 1321*da0073e9SAndroid Build Coastguard Worker ) 1322*da0073e9SAndroid Build Coastguard Worker expected = op( 1323*da0073e9SAndroid Build Coastguard Worker transformed_sample.input, 1324*da0073e9SAndroid Build Coastguard Worker *transformed_sample.args, 1325*da0073e9SAndroid Build Coastguard Worker **transformed_sample.kwargs, 1326*da0073e9SAndroid Build Coastguard Worker ) 1327*da0073e9SAndroid Build Coastguard Worker # Since range of chalf is much less compared to cfloat, 1328*da0073e9SAndroid Build Coastguard Worker # we get `inf`s easily (eg. with `pow`, `exp`), 1329*da0073e9SAndroid Build Coastguard Worker # so we cast `cfloat` back to `chalf`. 
1330*da0073e9SAndroid Build Coastguard Worker expected = tree_map( 1331*da0073e9SAndroid Build Coastguard Worker lambda x: x.to(torch.complex32) 1332*da0073e9SAndroid Build Coastguard Worker if isinstance(x, torch.Tensor) and x.dtype is torch.complex64 1333*da0073e9SAndroid Build Coastguard Worker else x, 1334*da0073e9SAndroid Build Coastguard Worker expected, 1335*da0073e9SAndroid Build Coastguard Worker ) 1336*da0073e9SAndroid Build Coastguard Worker 1337*da0073e9SAndroid Build Coastguard Worker # `exact_dtype` is False because for ops like real, imag 1338*da0073e9SAndroid Build Coastguard Worker # we get different dtypes for `actual` and `expected` 1339*da0073e9SAndroid Build Coastguard Worker # `chalf` input -> `half` output 1340*da0073e9SAndroid Build Coastguard Worker # `cfloat` input -> `float` output 1341*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected, exact_dtype=False) 1342*da0073e9SAndroid Build Coastguard Worker 1343*da0073e9SAndroid Build Coastguard Worker @ops(op_db, allowed_dtypes=(torch.bool,)) 1344*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(TEST_WITH_UBSAN, "Test uses undefined behavior") 1345*da0073e9SAndroid Build Coastguard Worker def test_non_standard_bool_values(self, device, dtype, op): 1346*da0073e9SAndroid Build Coastguard Worker # Test boolean values other than 0x00 and 0x01 (gh-54789) 1347*da0073e9SAndroid Build Coastguard Worker def convert_boolean_tensors(x): 1348*da0073e9SAndroid Build Coastguard Worker if not isinstance(x, torch.Tensor) or x.dtype != torch.bool: 1349*da0073e9SAndroid Build Coastguard Worker return x 1350*da0073e9SAndroid Build Coastguard Worker 1351*da0073e9SAndroid Build Coastguard Worker # Map False -> 0 and True -> Random value in [2, 255] 1352*da0073e9SAndroid Build Coastguard Worker true_vals = torch.randint( 1353*da0073e9SAndroid Build Coastguard Worker 2, 255, x.shape, dtype=torch.uint8, device=x.device 1354*da0073e9SAndroid Build Coastguard Worker ) 
1355*da0073e9SAndroid Build Coastguard Worker false_vals = torch.zeros((), dtype=torch.uint8, device=x.device) 1356*da0073e9SAndroid Build Coastguard Worker x_int = torch.where(x, true_vals, false_vals) 1357*da0073e9SAndroid Build Coastguard Worker 1358*da0073e9SAndroid Build Coastguard Worker ret = x_int.view(torch.bool) 1359*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ret, x) 1360*da0073e9SAndroid Build Coastguard Worker return ret 1361*da0073e9SAndroid Build Coastguard Worker 1362*da0073e9SAndroid Build Coastguard Worker for sample in op.sample_inputs(device, dtype): 1363*da0073e9SAndroid Build Coastguard Worker expect = op(sample.input, *sample.args, **sample.kwargs) 1364*da0073e9SAndroid Build Coastguard Worker 1365*da0073e9SAndroid Build Coastguard Worker transformed = sample.transform(convert_boolean_tensors) 1366*da0073e9SAndroid Build Coastguard Worker actual = op(transformed.input, *transformed.args, **transformed.kwargs) 1367*da0073e9SAndroid Build Coastguard Worker 1368*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expect, actual) 1369*da0073e9SAndroid Build Coastguard Worker 1370*da0073e9SAndroid Build Coastguard Worker # Validates that each OpInfo specifies its forward and backward dtypes 1371*da0073e9SAndroid Build Coastguard Worker # correctly for CPU and CUDA devices 1372*da0073e9SAndroid Build Coastguard Worker @skipMeta 1373*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypesAnd(["hpu"]) 1374*da0073e9SAndroid Build Coastguard Worker @ops(ops_and_refs, dtypes=OpDTypes.none) 1375*da0073e9SAndroid Build Coastguard Worker def test_dtypes(self, device, op): 1376*da0073e9SAndroid Build Coastguard Worker # Check complex32 support only if the op claims. 1377*da0073e9SAndroid Build Coastguard Worker # TODO: Once the complex32 support is better, we should add check for complex32 unconditionally. 
1378*da0073e9SAndroid Build Coastguard Worker device_type = torch.device(device).type 1379*da0073e9SAndroid Build Coastguard Worker include_complex32 = ( 1380*da0073e9SAndroid Build Coastguard Worker (torch.complex32,) 1381*da0073e9SAndroid Build Coastguard Worker if op.supports_dtype(torch.complex32, device_type) 1382*da0073e9SAndroid Build Coastguard Worker else () 1383*da0073e9SAndroid Build Coastguard Worker ) 1384*da0073e9SAndroid Build Coastguard Worker 1385*da0073e9SAndroid Build Coastguard Worker # dtypes to try to backward in 1386*da0073e9SAndroid Build Coastguard Worker allowed_backward_dtypes = floating_and_complex_types_and( 1387*da0073e9SAndroid Build Coastguard Worker *((torch.half, torch.bfloat16) + include_complex32) 1388*da0073e9SAndroid Build Coastguard Worker ) 1389*da0073e9SAndroid Build Coastguard Worker 1390*da0073e9SAndroid Build Coastguard Worker # lists for (un)supported dtypes 1391*da0073e9SAndroid Build Coastguard Worker supported_dtypes = set() 1392*da0073e9SAndroid Build Coastguard Worker unsupported_dtypes = set() 1393*da0073e9SAndroid Build Coastguard Worker supported_backward_dtypes = set() 1394*da0073e9SAndroid Build Coastguard Worker unsupported_backward_dtypes = set() 1395*da0073e9SAndroid Build Coastguard Worker dtype_error: Dict[torch.dtype, Exception] = {} 1396*da0073e9SAndroid Build Coastguard Worker 1397*da0073e9SAndroid Build Coastguard Worker def unsupported(dtype, e): 1398*da0073e9SAndroid Build Coastguard Worker dtype_error[dtype] = e 1399*da0073e9SAndroid Build Coastguard Worker unsupported_dtypes.add(dtype) 1400*da0073e9SAndroid Build Coastguard Worker if dtype in allowed_backward_dtypes: 1401*da0073e9SAndroid Build Coastguard Worker unsupported_backward_dtypes.add(dtype) 1402*da0073e9SAndroid Build Coastguard Worker 1403*da0073e9SAndroid Build Coastguard Worker for dtype in all_types_and_complex_and( 1404*da0073e9SAndroid Build Coastguard Worker *((torch.half, torch.bfloat16, torch.bool) + include_complex32) 
1405*da0073e9SAndroid Build Coastguard Worker ): 1406*da0073e9SAndroid Build Coastguard Worker # tries to acquire samples - failure indicates lack of support 1407*da0073e9SAndroid Build Coastguard Worker requires_grad = dtype in allowed_backward_dtypes 1408*da0073e9SAndroid Build Coastguard Worker try: 1409*da0073e9SAndroid Build Coastguard Worker samples = tuple( 1410*da0073e9SAndroid Build Coastguard Worker op.sample_inputs(device, dtype, requires_grad=requires_grad) 1411*da0073e9SAndroid Build Coastguard Worker ) 1412*da0073e9SAndroid Build Coastguard Worker except Exception as e: 1413*da0073e9SAndroid Build Coastguard Worker unsupported(dtype, e) 1414*da0073e9SAndroid Build Coastguard Worker continue 1415*da0073e9SAndroid Build Coastguard Worker 1416*da0073e9SAndroid Build Coastguard Worker for sample in samples: 1417*da0073e9SAndroid Build Coastguard Worker # tries to call operator with the sample - failure indicates 1418*da0073e9SAndroid Build Coastguard Worker # lack of support 1419*da0073e9SAndroid Build Coastguard Worker try: 1420*da0073e9SAndroid Build Coastguard Worker result = op(sample.input, *sample.args, **sample.kwargs) 1421*da0073e9SAndroid Build Coastguard Worker supported_dtypes.add(dtype) 1422*da0073e9SAndroid Build Coastguard Worker except Exception as e: 1423*da0073e9SAndroid Build Coastguard Worker # NOTE: some ops will fail in forward if their inputs 1424*da0073e9SAndroid Build Coastguard Worker # require grad but they don't support computing the gradient 1425*da0073e9SAndroid Build Coastguard Worker # in that type! This is a bug in the op! 
1426*da0073e9SAndroid Build Coastguard Worker unsupported(dtype, e) 1427*da0073e9SAndroid Build Coastguard Worker continue 1428*da0073e9SAndroid Build Coastguard Worker 1429*da0073e9SAndroid Build Coastguard Worker # Checks for backward support in the same dtype, if the input has 1430*da0073e9SAndroid Build Coastguard Worker # one or more tensors requiring grad 1431*da0073e9SAndroid Build Coastguard Worker def _tensor_requires_grad(x): 1432*da0073e9SAndroid Build Coastguard Worker if isinstance(x, dict): 1433*da0073e9SAndroid Build Coastguard Worker for v in x.values(): 1434*da0073e9SAndroid Build Coastguard Worker if _tensor_requires_grad(v): 1435*da0073e9SAndroid Build Coastguard Worker return True 1436*da0073e9SAndroid Build Coastguard Worker if isinstance(x, (list, tuple)): 1437*da0073e9SAndroid Build Coastguard Worker for a in x: 1438*da0073e9SAndroid Build Coastguard Worker if _tensor_requires_grad(a): 1439*da0073e9SAndroid Build Coastguard Worker return True 1440*da0073e9SAndroid Build Coastguard Worker if isinstance(x, torch.Tensor) and x.requires_grad: 1441*da0073e9SAndroid Build Coastguard Worker return True 1442*da0073e9SAndroid Build Coastguard Worker 1443*da0073e9SAndroid Build Coastguard Worker return False 1444*da0073e9SAndroid Build Coastguard Worker 1445*da0073e9SAndroid Build Coastguard Worker requires_grad = ( 1446*da0073e9SAndroid Build Coastguard Worker _tensor_requires_grad(sample.input) 1447*da0073e9SAndroid Build Coastguard Worker or _tensor_requires_grad(sample.args) 1448*da0073e9SAndroid Build Coastguard Worker or _tensor_requires_grad(sample.kwargs) 1449*da0073e9SAndroid Build Coastguard Worker ) 1450*da0073e9SAndroid Build Coastguard Worker if not requires_grad: 1451*da0073e9SAndroid Build Coastguard Worker continue 1452*da0073e9SAndroid Build Coastguard Worker 1453*da0073e9SAndroid Build Coastguard Worker try: 1454*da0073e9SAndroid Build Coastguard Worker result = sample.output_process_fn_grad(result) 1455*da0073e9SAndroid Build 
Coastguard Worker if isinstance(result, torch.Tensor): 1456*da0073e9SAndroid Build Coastguard Worker backward_tensor = result 1457*da0073e9SAndroid Build Coastguard Worker elif isinstance(result, Sequence) and isinstance( 1458*da0073e9SAndroid Build Coastguard Worker result[0], torch.Tensor 1459*da0073e9SAndroid Build Coastguard Worker ): 1460*da0073e9SAndroid Build Coastguard Worker backward_tensor = result[0] 1461*da0073e9SAndroid Build Coastguard Worker else: 1462*da0073e9SAndroid Build Coastguard Worker continue 1463*da0073e9SAndroid Build Coastguard Worker 1464*da0073e9SAndroid Build Coastguard Worker # Note: this grad may not have the same dtype as dtype 1465*da0073e9SAndroid Build Coastguard Worker # For functions like complex (float -> complex) or abs 1466*da0073e9SAndroid Build Coastguard Worker # (complex -> float) the grad tensor will have a 1467*da0073e9SAndroid Build Coastguard Worker # different dtype than the input. 1468*da0073e9SAndroid Build Coastguard Worker # For simplicity, this is still modeled as these ops 1469*da0073e9SAndroid Build Coastguard Worker # supporting grad in the input dtype. 
1470*da0073e9SAndroid Build Coastguard Worker grad = torch.randn_like(backward_tensor) 1471*da0073e9SAndroid Build Coastguard Worker backward_tensor.backward(grad) 1472*da0073e9SAndroid Build Coastguard Worker supported_backward_dtypes.add(dtype) 1473*da0073e9SAndroid Build Coastguard Worker except Exception as e: 1474*da0073e9SAndroid Build Coastguard Worker dtype_error[dtype] = e 1475*da0073e9SAndroid Build Coastguard Worker unsupported_backward_dtypes.add(dtype) 1476*da0073e9SAndroid Build Coastguard Worker 1477*da0073e9SAndroid Build Coastguard Worker # Checks that dtypes are listed correctly and generates an informative 1478*da0073e9SAndroid Build Coastguard Worker # error message 1479*da0073e9SAndroid Build Coastguard Worker 1480*da0073e9SAndroid Build Coastguard Worker supported_forward = supported_dtypes - unsupported_dtypes 1481*da0073e9SAndroid Build Coastguard Worker partially_supported_forward = supported_dtypes & unsupported_dtypes 1482*da0073e9SAndroid Build Coastguard Worker unsupported_forward = unsupported_dtypes - supported_dtypes 1483*da0073e9SAndroid Build Coastguard Worker supported_backward = supported_backward_dtypes - unsupported_backward_dtypes 1484*da0073e9SAndroid Build Coastguard Worker partially_supported_backward = ( 1485*da0073e9SAndroid Build Coastguard Worker supported_backward_dtypes & unsupported_backward_dtypes 1486*da0073e9SAndroid Build Coastguard Worker ) 1487*da0073e9SAndroid Build Coastguard Worker unsupported_backward = unsupported_backward_dtypes - supported_backward_dtypes 1488*da0073e9SAndroid Build Coastguard Worker 1489*da0073e9SAndroid Build Coastguard Worker device_type = torch.device(device).type 1490*da0073e9SAndroid Build Coastguard Worker 1491*da0073e9SAndroid Build Coastguard Worker claimed_forward = set(op.supported_dtypes(device_type)) 1492*da0073e9SAndroid Build Coastguard Worker supported_but_unclaimed_forward = supported_forward - claimed_forward 1493*da0073e9SAndroid Build Coastguard Worker 
claimed_but_unsupported_forward = claimed_forward & unsupported_forward 1494*da0073e9SAndroid Build Coastguard Worker 1495*da0073e9SAndroid Build Coastguard Worker claimed_backward = set(op.supported_backward_dtypes(device_type)) 1496*da0073e9SAndroid Build Coastguard Worker supported_but_unclaimed_backward = supported_backward - claimed_backward 1497*da0073e9SAndroid Build Coastguard Worker claimed_but_unsupported_backward = claimed_backward & unsupported_backward 1498*da0073e9SAndroid Build Coastguard Worker 1499*da0073e9SAndroid Build Coastguard Worker # Partially supporting a dtype is not an error, but we print a warning 1500*da0073e9SAndroid Build Coastguard Worker if (len(partially_supported_forward) + len(partially_supported_backward)) > 0: 1501*da0073e9SAndroid Build Coastguard Worker msg = f"Some dtypes for {op.name} on device type {device_type} are only partially supported!\n" 1502*da0073e9SAndroid Build Coastguard Worker if len(partially_supported_forward) > 0: 1503*da0073e9SAndroid Build Coastguard Worker msg = ( 1504*da0073e9SAndroid Build Coastguard Worker msg 1505*da0073e9SAndroid Build Coastguard Worker + f"The following dtypes only worked on some samples during forward: {partially_supported_forward}.\n" 1506*da0073e9SAndroid Build Coastguard Worker ) 1507*da0073e9SAndroid Build Coastguard Worker if len(partially_supported_backward) > 0: 1508*da0073e9SAndroid Build Coastguard Worker msg = ( 1509*da0073e9SAndroid Build Coastguard Worker msg 1510*da0073e9SAndroid Build Coastguard Worker + f"The following dtypes only worked on some samples during backward: {partially_supported_backward}.\n" 1511*da0073e9SAndroid Build Coastguard Worker ) 1512*da0073e9SAndroid Build Coastguard Worker print(msg) 1513*da0073e9SAndroid Build Coastguard Worker 1514*da0073e9SAndroid Build Coastguard Worker if ( 1515*da0073e9SAndroid Build Coastguard Worker len(supported_but_unclaimed_forward) 1516*da0073e9SAndroid Build Coastguard Worker + 
len(claimed_but_unsupported_forward) 1517*da0073e9SAndroid Build Coastguard Worker + len(supported_but_unclaimed_backward) 1518*da0073e9SAndroid Build Coastguard Worker + len(claimed_but_unsupported_backward) 1519*da0073e9SAndroid Build Coastguard Worker ) == 0: 1520*da0073e9SAndroid Build Coastguard Worker return 1521*da0073e9SAndroid Build Coastguard Worker 1522*da0073e9SAndroid Build Coastguard Worker # Reference operators often support additional dtypes, and that's OK 1523*da0073e9SAndroid Build Coastguard Worker if op in python_ref_db: 1524*da0073e9SAndroid Build Coastguard Worker if ( 1525*da0073e9SAndroid Build Coastguard Worker len(claimed_but_unsupported_forward) 1526*da0073e9SAndroid Build Coastguard Worker + len(claimed_but_unsupported_backward) 1527*da0073e9SAndroid Build Coastguard Worker ) == 0: 1528*da0073e9SAndroid Build Coastguard Worker return 1529*da0073e9SAndroid Build Coastguard Worker 1530*da0073e9SAndroid Build Coastguard Worker # Generates error msg 1531*da0073e9SAndroid Build Coastguard Worker msg = f"The supported dtypes for {op.name} on device type {device_type} are incorrect!\n" 1532*da0073e9SAndroid Build Coastguard Worker if len(supported_but_unclaimed_forward) > 0: 1533*da0073e9SAndroid Build Coastguard Worker msg = ( 1534*da0073e9SAndroid Build Coastguard Worker msg 1535*da0073e9SAndroid Build Coastguard Worker + "The following dtypes worked in forward but are not listed by the OpInfo: " 1536*da0073e9SAndroid Build Coastguard Worker + f"{supported_but_unclaimed_forward}.\n" 1537*da0073e9SAndroid Build Coastguard Worker ) 1538*da0073e9SAndroid Build Coastguard Worker if len(supported_but_unclaimed_backward) > 0: 1539*da0073e9SAndroid Build Coastguard Worker msg = ( 1540*da0073e9SAndroid Build Coastguard Worker msg 1541*da0073e9SAndroid Build Coastguard Worker + "The following dtypes worked in backward but are not listed by the OpInfo: " 1542*da0073e9SAndroid Build Coastguard Worker + f"{supported_but_unclaimed_backward}.\n" 
1543*da0073e9SAndroid Build Coastguard Worker ) 1544*da0073e9SAndroid Build Coastguard Worker if len(claimed_but_unsupported_forward) > 0: 1545*da0073e9SAndroid Build Coastguard Worker msg = ( 1546*da0073e9SAndroid Build Coastguard Worker msg 1547*da0073e9SAndroid Build Coastguard Worker + "The following dtypes did not work in forward but are listed by the OpInfo: " 1548*da0073e9SAndroid Build Coastguard Worker + f"{claimed_but_unsupported_forward}.\n" 1549*da0073e9SAndroid Build Coastguard Worker ) 1550*da0073e9SAndroid Build Coastguard Worker if len(claimed_but_unsupported_backward) > 0: 1551*da0073e9SAndroid Build Coastguard Worker msg = ( 1552*da0073e9SAndroid Build Coastguard Worker msg 1553*da0073e9SAndroid Build Coastguard Worker + "The following dtypes did not work in backward " 1554*da0073e9SAndroid Build Coastguard Worker + f"but are listed by the OpInfo: {claimed_but_unsupported_backward}.\n" 1555*da0073e9SAndroid Build Coastguard Worker ) 1556*da0073e9SAndroid Build Coastguard Worker 1557*da0073e9SAndroid Build Coastguard Worker all_claimed_but_unsupported = set.union( 1558*da0073e9SAndroid Build Coastguard Worker claimed_but_unsupported_backward, claimed_but_unsupported_forward 1559*da0073e9SAndroid Build Coastguard Worker ) 1560*da0073e9SAndroid Build Coastguard Worker if all_claimed_but_unsupported: 1561*da0073e9SAndroid Build Coastguard Worker msg += "Unexpected failures raised the following errors:\n" 1562*da0073e9SAndroid Build Coastguard Worker for dtype in all_claimed_but_unsupported: 1563*da0073e9SAndroid Build Coastguard Worker msg += f"{dtype} - {dtype_error[dtype]}\n" 1564*da0073e9SAndroid Build Coastguard Worker 1565*da0073e9SAndroid Build Coastguard Worker self.fail(msg) 1566*da0073e9SAndroid Build Coastguard Worker 1567*da0073e9SAndroid Build Coastguard Worker # Validates that each OpInfo that sets promotes_int_to_float=True does as it says 1568*da0073e9SAndroid Build Coastguard Worker @skipMeta 1569*da0073e9SAndroid Build Coastguard 
Worker @onlyNativeDeviceTypesAnd(["hpu"]) 1570*da0073e9SAndroid Build Coastguard Worker @ops( 1571*da0073e9SAndroid Build Coastguard Worker (op for op in op_db if op.promotes_int_to_float), 1572*da0073e9SAndroid Build Coastguard Worker allowed_dtypes=integral_types_and(torch.bool), 1573*da0073e9SAndroid Build Coastguard Worker ) 1574*da0073e9SAndroid Build Coastguard Worker def test_promotes_int_to_float(self, device, dtype, op): 1575*da0073e9SAndroid Build Coastguard Worker for sample in op.sample_inputs(device, dtype): 1576*da0073e9SAndroid Build Coastguard Worker output = op(sample.input, *sample.args, **sample.kwargs) 1577*da0073e9SAndroid Build Coastguard Worker if not output.dtype.is_floating_point: 1578*da0073e9SAndroid Build Coastguard Worker self.fail( 1579*da0073e9SAndroid Build Coastguard Worker f"The OpInfo sets `promotes_int_to_float=True`, but {dtype} was promoted to {output.dtype}." 1580*da0073e9SAndroid Build Coastguard Worker ) 1581*da0073e9SAndroid Build Coastguard Worker 1582*da0073e9SAndroid Build Coastguard Worker 1583*da0073e9SAndroid Build Coastguard Worker@unMarkDynamoStrictTest 1584*da0073e9SAndroid Build Coastguard Workerclass TestCompositeCompliance(TestCase): 1585*da0073e9SAndroid Build Coastguard Worker # Checks if the operator (if it is composite) is written to support most 1586*da0073e9SAndroid Build Coastguard Worker # backends and Tensor subclasses. 
See "CompositeImplicitAutograd Compliance" 1587*da0073e9SAndroid Build Coastguard Worker # in aten/src/ATen/native/README.md for more details 1588*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 1589*da0073e9SAndroid Build Coastguard Worker IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode" 1590*da0073e9SAndroid Build Coastguard Worker ) 1591*da0073e9SAndroid Build Coastguard Worker @ops(op_db, allowed_dtypes=(torch.float,)) 1592*da0073e9SAndroid Build Coastguard Worker def test_operator(self, device, dtype, op): 1593*da0073e9SAndroid Build Coastguard Worker samples = op.sample_inputs(device, dtype, requires_grad=False) 1594*da0073e9SAndroid Build Coastguard Worker 1595*da0073e9SAndroid Build Coastguard Worker for sample in samples: 1596*da0073e9SAndroid Build Coastguard Worker args = [sample.input] + list(sample.args) 1597*da0073e9SAndroid Build Coastguard Worker kwargs = sample.kwargs 1598*da0073e9SAndroid Build Coastguard Worker composite_compliance.check_with_mode(op, args, kwargs, self.assertEqual) 1599*da0073e9SAndroid Build Coastguard Worker composite_compliance.check_all_permutations( 1600*da0073e9SAndroid Build Coastguard Worker op, args, kwargs, self.assertEqual 1601*da0073e9SAndroid Build Coastguard Worker ) 1602*da0073e9SAndroid Build Coastguard Worker 1603*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 1604*da0073e9SAndroid Build Coastguard Worker IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode" 1605*da0073e9SAndroid Build Coastguard Worker ) 1606*da0073e9SAndroid Build Coastguard Worker @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,)) 1607*da0073e9SAndroid Build Coastguard Worker def test_backward(self, device, dtype, op): 1608*da0073e9SAndroid Build Coastguard Worker samples = op.sample_inputs(device, dtype, requires_grad=True) 1609*da0073e9SAndroid Build Coastguard Worker 1610*da0073e9SAndroid Build Coastguard Worker for sample in samples: 
1611*da0073e9SAndroid Build Coastguard Worker args = [sample.input] + list(sample.args) 1612*da0073e9SAndroid Build Coastguard Worker kwargs = sample.kwargs 1613*da0073e9SAndroid Build Coastguard Worker # We pass assertEqual so that decorators like `toleranceOverride` 1614*da0073e9SAndroid Build Coastguard Worker # actually work (otherwise they silently do nothing!) 1615*da0073e9SAndroid Build Coastguard Worker composite_compliance.check_backward_formula( 1616*da0073e9SAndroid Build Coastguard Worker op.get_op(), 1617*da0073e9SAndroid Build Coastguard Worker args, 1618*da0073e9SAndroid Build Coastguard Worker kwargs, 1619*da0073e9SAndroid Build Coastguard Worker sample.output_process_fn_grad, 1620*da0073e9SAndroid Build Coastguard Worker op.gradcheck_wrapper, 1621*da0073e9SAndroid Build Coastguard Worker self.assertEqual, 1622*da0073e9SAndroid Build Coastguard Worker ) 1623*da0073e9SAndroid Build Coastguard Worker 1624*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 1625*da0073e9SAndroid Build Coastguard Worker IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode" 1626*da0073e9SAndroid Build Coastguard Worker ) 1627*da0073e9SAndroid Build Coastguard Worker @ops(op_db, allowed_dtypes=(torch.float,)) 1628*da0073e9SAndroid Build Coastguard Worker def test_forward_ad(self, device, dtype, op): 1629*da0073e9SAndroid Build Coastguard Worker if torch.float not in op.supported_backward_dtypes(device): 1630*da0073e9SAndroid Build Coastguard Worker raise unittest.SkipTest("Does not support autograd") 1631*da0073e9SAndroid Build Coastguard Worker 1632*da0073e9SAndroid Build Coastguard Worker if not op.supports_forward_ad: 1633*da0073e9SAndroid Build Coastguard Worker raise unittest.SkipTest("Does not support forward_ad") 1634*da0073e9SAndroid Build Coastguard Worker 1635*da0073e9SAndroid Build Coastguard Worker samples = op.sample_inputs(device, dtype, requires_grad=True) 1636*da0073e9SAndroid Build Coastguard Worker 1637*da0073e9SAndroid Build 
Coastguard Worker for sample in samples: 1638*da0073e9SAndroid Build Coastguard Worker args = [sample.input] + list(sample.args) 1639*da0073e9SAndroid Build Coastguard Worker kwargs = sample.kwargs 1640*da0073e9SAndroid Build Coastguard Worker # We pass assertEqual so that decorators like `toleranceOverride` 1641*da0073e9SAndroid Build Coastguard Worker # actually work (otherwise they silently do nothing!) 1642*da0073e9SAndroid Build Coastguard Worker composite_compliance.check_forward_ad_formula( 1643*da0073e9SAndroid Build Coastguard Worker op.get_op(), args, kwargs, op.gradcheck_wrapper, self.assertEqual 1644*da0073e9SAndroid Build Coastguard Worker ) 1645*da0073e9SAndroid Build Coastguard Worker 1646*da0073e9SAndroid Build Coastguard Worker @ops(op_db, allowed_dtypes=(torch.float,)) 1647*da0073e9SAndroid Build Coastguard Worker def test_cow_input(self, device, dtype, op): 1648*da0073e9SAndroid Build Coastguard Worker samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd) 1649*da0073e9SAndroid Build Coastguard Worker 1650*da0073e9SAndroid Build Coastguard Worker def is_strided_tensor(arg): 1651*da0073e9SAndroid Build Coastguard Worker return torch.is_tensor(arg) and arg.layout == torch.strided 1652*da0073e9SAndroid Build Coastguard Worker 1653*da0073e9SAndroid Build Coastguard Worker def check_ignore_materialize(idx_or_kw, allow_list): 1654*da0073e9SAndroid Build Coastguard Worker return (allow_list is not None) and (idx_or_kw in allow_list) 1655*da0073e9SAndroid Build Coastguard Worker 1656*da0073e9SAndroid Build Coastguard Worker def check_cow_input( 1657*da0073e9SAndroid Build Coastguard Worker arg, 1658*da0073e9SAndroid Build Coastguard Worker arg_copy, 1659*da0073e9SAndroid Build Coastguard Worker idx_or_kw, 1660*da0073e9SAndroid Build Coastguard Worker backward_or_forward="forward", 1661*da0073e9SAndroid Build Coastguard Worker supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_forward, 1662*da0073e9SAndroid Build 
Coastguard Worker allow_list=op.allow_cow_input_materialize_forward, 1663*da0073e9SAndroid Build Coastguard Worker ): 1664*da0073e9SAndroid Build Coastguard Worker arg_name = ( 1665*da0073e9SAndroid Build Coastguard Worker f"Argument {idx_or_kw}" 1666*da0073e9SAndroid Build Coastguard Worker if isinstance(idx_or_kw, int) 1667*da0073e9SAndroid Build Coastguard Worker else f"Keyword argument '{idx_or_kw}'" 1668*da0073e9SAndroid Build Coastguard Worker ) + f" during {backward_or_forward} call" 1669*da0073e9SAndroid Build Coastguard Worker 1670*da0073e9SAndroid Build Coastguard Worker if is_strided_tensor(arg): 1671*da0073e9SAndroid Build Coastguard Worker is_cow = torch._C._is_cow_tensor(arg) 1672*da0073e9SAndroid Build Coastguard Worker 1673*da0073e9SAndroid Build Coastguard Worker if supports_cow_input_no_materialize and not check_ignore_materialize( 1674*da0073e9SAndroid Build Coastguard Worker idx_or_kw, allow_list 1675*da0073e9SAndroid Build Coastguard Worker ): 1676*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 1677*da0073e9SAndroid Build Coastguard Worker is_cow, 1678*da0073e9SAndroid Build Coastguard Worker msg=( 1679*da0073e9SAndroid Build Coastguard Worker f"{arg_name} unexpectedly materializes. " 1680*da0073e9SAndroid Build Coastguard Worker f"Either set `supports_cow_input_no_materialize_{backward_or_forward}=False` " 1681*da0073e9SAndroid Build Coastguard Worker "in this operation's OpInfo, add the arg to the OpInfo's " 1682*da0073e9SAndroid Build Coastguard Worker f"`allow_cow_input_materialize_{backward_or_forward}` list, or change the " 1683*da0073e9SAndroid Build Coastguard Worker "implementation to avoid materialization." 
1684*da0073e9SAndroid Build Coastguard Worker ), 1685*da0073e9SAndroid Build Coastguard Worker ) 1686*da0073e9SAndroid Build Coastguard Worker 1687*da0073e9SAndroid Build Coastguard Worker if is_cow: 1688*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 1689*da0073e9SAndroid Build Coastguard Worker torch.allclose(arg, arg_copy, rtol=0, atol=0, equal_nan=True), 1690*da0073e9SAndroid Build Coastguard Worker msg=( 1691*da0073e9SAndroid Build Coastguard Worker f"{arg_name} avoided materialization, " 1692*da0073e9SAndroid Build Coastguard Worker "but the operation mutated its data." 1693*da0073e9SAndroid Build Coastguard Worker ), 1694*da0073e9SAndroid Build Coastguard Worker ) 1695*da0073e9SAndroid Build Coastguard Worker 1696*da0073e9SAndroid Build Coastguard Worker for sample in samples: 1697*da0073e9SAndroid Build Coastguard Worker args_raw = [sample.input] + list(sample.args) 1698*da0073e9SAndroid Build Coastguard Worker kwargs_raw = sample.kwargs 1699*da0073e9SAndroid Build Coastguard Worker args_copy = [] 1700*da0073e9SAndroid Build Coastguard Worker args = [] 1701*da0073e9SAndroid Build Coastguard Worker kwargs_copy = {} 1702*da0073e9SAndroid Build Coastguard Worker kwargs = {} 1703*da0073e9SAndroid Build Coastguard Worker 1704*da0073e9SAndroid Build Coastguard Worker # Convert strided tensor inputs to COW tensors and make copies of 1705*da0073e9SAndroid Build Coastguard Worker # all inputs 1706*da0073e9SAndroid Build Coastguard Worker for idx, arg in enumerate(args_raw): 1707*da0073e9SAndroid Build Coastguard Worker if is_strided_tensor(arg): 1708*da0073e9SAndroid Build Coastguard Worker args_copy.append(arg.clone().detach()) 1709*da0073e9SAndroid Build Coastguard Worker args.append(torch._lazy_clone(arg)) 1710*da0073e9SAndroid Build Coastguard Worker else: 1711*da0073e9SAndroid Build Coastguard Worker if torch.is_tensor(arg): 1712*da0073e9SAndroid Build Coastguard Worker args_copy.append(arg.clone().detach()) 1713*da0073e9SAndroid Build Coastguard 
Worker else: 1714*da0073e9SAndroid Build Coastguard Worker args_copy.append(copy.deepcopy(arg)) 1715*da0073e9SAndroid Build Coastguard Worker args.append(arg) 1716*da0073e9SAndroid Build Coastguard Worker 1717*da0073e9SAndroid Build Coastguard Worker for kw, arg in kwargs_raw.items(): 1718*da0073e9SAndroid Build Coastguard Worker if is_strided_tensor(arg): 1719*da0073e9SAndroid Build Coastguard Worker kwargs_copy[kw] = arg.clone().detach() 1720*da0073e9SAndroid Build Coastguard Worker kwargs[kw] = torch._lazy_clone(arg) 1721*da0073e9SAndroid Build Coastguard Worker else: 1722*da0073e9SAndroid Build Coastguard Worker if torch.is_tensor(arg): 1723*da0073e9SAndroid Build Coastguard Worker kwargs_copy[kw] = arg.clone().detach() 1724*da0073e9SAndroid Build Coastguard Worker else: 1725*da0073e9SAndroid Build Coastguard Worker kwargs_copy[kw] = copy.deepcopy(arg) 1726*da0073e9SAndroid Build Coastguard Worker kwargs[kw] = arg 1727*da0073e9SAndroid Build Coastguard Worker 1728*da0073e9SAndroid Build Coastguard Worker leaf_tensors = composite_compliance.gather_leaf_tensors(args, kwargs) 1729*da0073e9SAndroid Build Coastguard Worker 1730*da0073e9SAndroid Build Coastguard Worker # Call forward op 1731*da0073e9SAndroid Build Coastguard Worker results_raw = op.get_op()(*args, **kwargs) 1732*da0073e9SAndroid Build Coastguard Worker 1733*da0073e9SAndroid Build Coastguard Worker # Check that COW inputs remain COW after the forward op is executed 1734*da0073e9SAndroid Build Coastguard Worker for idx, arg in enumerate(args): 1735*da0073e9SAndroid Build Coastguard Worker check_cow_input(arg, args_copy[idx], idx) 1736*da0073e9SAndroid Build Coastguard Worker 1737*da0073e9SAndroid Build Coastguard Worker for kw, arg in kwargs.items(): 1738*da0073e9SAndroid Build Coastguard Worker check_cow_input(arg, kwargs_copy[kw], kw) 1739*da0073e9SAndroid Build Coastguard Worker 1740*da0073e9SAndroid Build Coastguard Worker # Call backward op if it is supported. 
This part of the test is 1741*da0073e9SAndroid Build Coastguard Worker # based on `composite_compliance.check_backward_formula` 1742*da0073e9SAndroid Build Coastguard Worker if ( 1743*da0073e9SAndroid Build Coastguard Worker op.supports_autograd 1744*da0073e9SAndroid Build Coastguard Worker and len(leaf_tensors) > 0 1745*da0073e9SAndroid Build Coastguard Worker and not op.skip_cow_input_backward 1746*da0073e9SAndroid Build Coastguard Worker ): 1747*da0073e9SAndroid Build Coastguard Worker if sample.output_process_fn_grad is not None: 1748*da0073e9SAndroid Build Coastguard Worker results_raw = sample.output_process_fn_grad(results_raw) 1749*da0073e9SAndroid Build Coastguard Worker 1750*da0073e9SAndroid Build Coastguard Worker leaf_results = pytree.tree_leaves(results_raw) 1751*da0073e9SAndroid Build Coastguard Worker results = [ 1752*da0073e9SAndroid Build Coastguard Worker r 1753*da0073e9SAndroid Build Coastguard Worker for r in leaf_results 1754*da0073e9SAndroid Build Coastguard Worker if isinstance(r, torch.Tensor) and r.requires_grad 1755*da0073e9SAndroid Build Coastguard Worker ] 1756*da0073e9SAndroid Build Coastguard Worker 1757*da0073e9SAndroid Build Coastguard Worker all_results_strided = all( 1758*da0073e9SAndroid Build Coastguard Worker is_strided_tensor(result) for result in results 1759*da0073e9SAndroid Build Coastguard Worker ) 1760*da0073e9SAndroid Build Coastguard Worker 1761*da0073e9SAndroid Build Coastguard Worker # Only test backward if the results are strided tensors 1762*da0073e9SAndroid Build Coastguard Worker if all_results_strided: 1763*da0073e9SAndroid Build Coastguard Worker output_grads_raw = [ 1764*da0073e9SAndroid Build Coastguard Worker torch.ones(r.shape, device=r.device, dtype=r.dtype) 1765*da0073e9SAndroid Build Coastguard Worker for r in results 1766*da0073e9SAndroid Build Coastguard Worker ] 1767*da0073e9SAndroid Build Coastguard Worker output_grads_copy = [] 1768*da0073e9SAndroid Build Coastguard Worker output_grads = [] 
1769*da0073e9SAndroid Build Coastguard Worker 1770*da0073e9SAndroid Build Coastguard Worker # Convert output grads to COW tensors and make copies 1771*da0073e9SAndroid Build Coastguard Worker for output_grad in output_grads_raw: 1772*da0073e9SAndroid Build Coastguard Worker output_grads_copy.append(output_grad.clone().detach()) 1773*da0073e9SAndroid Build Coastguard Worker output_grads.append(torch._lazy_clone(output_grad)) 1774*da0073e9SAndroid Build Coastguard Worker 1775*da0073e9SAndroid Build Coastguard Worker input_grads = torch.autograd.grad( 1776*da0073e9SAndroid Build Coastguard Worker results, 1777*da0073e9SAndroid Build Coastguard Worker leaf_tensors, 1778*da0073e9SAndroid Build Coastguard Worker output_grads, 1779*da0073e9SAndroid Build Coastguard Worker allow_unused=True, 1780*da0073e9SAndroid Build Coastguard Worker retain_graph=True, 1781*da0073e9SAndroid Build Coastguard Worker ) 1782*da0073e9SAndroid Build Coastguard Worker 1783*da0073e9SAndroid Build Coastguard Worker # Check that COW inputs remain COW after the backward op is executed 1784*da0073e9SAndroid Build Coastguard Worker for idx, arg in enumerate(args): 1785*da0073e9SAndroid Build Coastguard Worker check_cow_input( 1786*da0073e9SAndroid Build Coastguard Worker arg, 1787*da0073e9SAndroid Build Coastguard Worker args_copy[idx], 1788*da0073e9SAndroid Build Coastguard Worker idx, 1789*da0073e9SAndroid Build Coastguard Worker backward_or_forward="backward", 1790*da0073e9SAndroid Build Coastguard Worker supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_backward, 1791*da0073e9SAndroid Build Coastguard Worker allow_list=op.allow_cow_input_materialize_backward, 1792*da0073e9SAndroid Build Coastguard Worker ) 1793*da0073e9SAndroid Build Coastguard Worker 1794*da0073e9SAndroid Build Coastguard Worker # Check that COW inputs remain COW after the backward op is executed 1795*da0073e9SAndroid Build Coastguard Worker for idx, output_grad in enumerate(output_grads): 
1796*da0073e9SAndroid Build Coastguard Worker check_cow_input( 1797*da0073e9SAndroid Build Coastguard Worker output_grad, 1798*da0073e9SAndroid Build Coastguard Worker output_grads_copy[idx], 1799*da0073e9SAndroid Build Coastguard Worker f"output grad {idx}", 1800*da0073e9SAndroid Build Coastguard Worker backward_or_forward="backward", 1801*da0073e9SAndroid Build Coastguard Worker supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_backward, 1802*da0073e9SAndroid Build Coastguard Worker allow_list=op.allow_cow_input_materialize_backward, 1803*da0073e9SAndroid Build Coastguard Worker ) 1804*da0073e9SAndroid Build Coastguard Worker 1805*da0073e9SAndroid Build Coastguard Worker @ops(op_db, allowed_dtypes=(torch.float,)) 1806*da0073e9SAndroid Build Coastguard Worker def test_view_replay(self, device, dtype, op): 1807*da0073e9SAndroid Build Coastguard Worker def _assert_match_metadata(a, b): 1808*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.size(), b.size()) 1809*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.stride(), b.stride()) 1810*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.storage_offset(), b.storage_offset()) 1811*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.device, b.device) 1812*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.dtype, b.dtype) 1813*da0073e9SAndroid Build Coastguard Worker 1814*da0073e9SAndroid Build Coastguard Worker # ensure view replay is enabled 1815*da0073e9SAndroid Build Coastguard Worker with torch.autograd._force_original_view_tracking(True): 1816*da0073e9SAndroid Build Coastguard Worker for sample in op.sample_inputs(device, dtype, requires_grad=False): 1817*da0073e9SAndroid Build Coastguard Worker inp = sample.input 1818*da0073e9SAndroid Build Coastguard Worker outs = op(inp, *sample.args, **sample.kwargs) 1819*da0073e9SAndroid Build Coastguard Worker if not isinstance(outs, (tuple, List)): 1820*da0073e9SAndroid Build Coastguard Worker outs = [outs] 
1821*da0073e9SAndroid Build Coastguard Worker 1822*da0073e9SAndroid Build Coastguard Worker # for all outputs that are views of the input, we should be able to replay the 1823*da0073e9SAndroid Build Coastguard Worker # forward and reverse views via a functioning view_func() / rev_view_func(). 1824*da0073e9SAndroid Build Coastguard Worker for out in outs: 1825*da0073e9SAndroid Build Coastguard Worker if not ( 1826*da0073e9SAndroid Build Coastguard Worker isinstance(out, torch.Tensor) 1827*da0073e9SAndroid Build Coastguard Worker and out._is_view() 1828*da0073e9SAndroid Build Coastguard Worker and out._base is inp 1829*da0073e9SAndroid Build Coastguard Worker ): 1830*da0073e9SAndroid Build Coastguard Worker continue 1831*da0073e9SAndroid Build Coastguard Worker 1832*da0073e9SAndroid Build Coastguard Worker # forward view_func 1833*da0073e9SAndroid Build Coastguard Worker new_inp = inp.clone() 1834*da0073e9SAndroid Build Coastguard Worker _assert_match_metadata(new_inp, inp) 1835*da0073e9SAndroid Build Coastguard Worker new_out = out._view_func_unsafe(new_inp) 1836*da0073e9SAndroid Build Coastguard Worker _assert_match_metadata(new_out, out) 1837*da0073e9SAndroid Build Coastguard Worker self.assertEqual(new_out, out) 1838*da0073e9SAndroid Build Coastguard Worker 1839*da0073e9SAndroid Build Coastguard Worker # reverse view_func 1840*da0073e9SAndroid Build Coastguard Worker new_out = out.detach() 1841*da0073e9SAndroid Build Coastguard Worker new_inp = out._rev_view_func_unsafe(new_out) 1842*da0073e9SAndroid Build Coastguard Worker _assert_match_metadata(new_inp, inp) 1843*da0073e9SAndroid Build Coastguard Worker self.assertTrue(new_inp._is_view()) 1844*da0073e9SAndroid Build Coastguard Worker self.assertTrue(new_inp._base is new_out) 1845*da0073e9SAndroid Build Coastguard Worker 1846*da0073e9SAndroid Build Coastguard Worker 1847*da0073e9SAndroid Build Coastguard Worker@unMarkDynamoStrictTest 1848*da0073e9SAndroid Build Coastguard Workerclass TestMathBits(TestCase): 
1849*da0073e9SAndroid Build Coastguard Worker # Tests that 1850*da0073e9SAndroid Build Coastguard Worker # 1. The operator's output for physically conjugated/negated tensors and conjugate/negative view tensors 1851*da0073e9SAndroid Build Coastguard Worker # produces the same value 1852*da0073e9SAndroid Build Coastguard Worker # 2. The gradients are same in both cases mentioned in (1) 1853*da0073e9SAndroid Build Coastguard Worker # 3. If the operator's inplace variant is supported, tests that the inplace operation 1854*da0073e9SAndroid Build Coastguard Worker # produces the correct value when called on a conjugate/negative view tensor and that the output 1855*da0073e9SAndroid Build Coastguard Worker # has its conj/neg bit set to true 1856*da0073e9SAndroid Build Coastguard Worker # This test only runs for C -> R and C -> C functions 1857*da0073e9SAndroid Build Coastguard Worker # TODO: add tests for `R->C` functions 1858*da0073e9SAndroid Build Coastguard Worker # Note: This test runs for functions that take both tensors and tensorlists as input. 
1859*da0073e9SAndroid Build Coastguard Worker def _test_math_view( 1860*da0073e9SAndroid Build Coastguard Worker self, 1861*da0073e9SAndroid Build Coastguard Worker device, 1862*da0073e9SAndroid Build Coastguard Worker dtype, 1863*da0073e9SAndroid Build Coastguard Worker op, 1864*da0073e9SAndroid Build Coastguard Worker samples, 1865*da0073e9SAndroid Build Coastguard Worker math_op_physical, 1866*da0073e9SAndroid Build Coastguard Worker math_op_view, 1867*da0073e9SAndroid Build Coastguard Worker is_bit_set, 1868*da0073e9SAndroid Build Coastguard Worker out_type, 1869*da0073e9SAndroid Build Coastguard Worker ): 1870*da0073e9SAndroid Build Coastguard Worker inplace_variant = op.inplace_variant 1871*da0073e9SAndroid Build Coastguard Worker 1872*da0073e9SAndroid Build Coastguard Worker # helper function to clone and conjugate/negate the input if its a tensor 1873*da0073e9SAndroid Build Coastguard Worker # else clone the sequence and conjugate/negate the first element in the sequence 1874*da0073e9SAndroid Build Coastguard Worker # If a requires_grad argument is provided the tensor being conjugated/negated will 1875*da0073e9SAndroid Build Coastguard Worker # have its requires_grad set to that value. 1876*da0073e9SAndroid Build Coastguard Worker def clone_and_perform_view(input, **kwargs): 1877*da0073e9SAndroid Build Coastguard Worker if isinstance(input, torch.Tensor): 1878*da0073e9SAndroid Build Coastguard Worker requires_grad = kwargs.get("requires_grad", input.requires_grad) 1879*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 1880*da0073e9SAndroid Build Coastguard Worker # Ensure view represents the original sample input 1881*da0073e9SAndroid Build Coastguard Worker input = math_op_physical(input) 1882*da0073e9SAndroid Build Coastguard Worker # Note: .conj() is not called under no_grad mode since it's not allowed to modify a 1883*da0073e9SAndroid Build Coastguard Worker # view created in no_grad mode. 
Here it's ok to do so, so as a workaround we call conj 1884*da0073e9SAndroid Build Coastguard Worker # before resetting the requires_grad field for input 1885*da0073e9SAndroid Build Coastguard Worker input = math_op_view(input) 1886*da0073e9SAndroid Build Coastguard Worker assert input.is_leaf 1887*da0073e9SAndroid Build Coastguard Worker return input.requires_grad_(requires_grad) 1888*da0073e9SAndroid Build Coastguard Worker 1889*da0073e9SAndroid Build Coastguard Worker if isinstance(input, Sequence): 1890*da0073e9SAndroid Build Coastguard Worker out = list(map(clone_input_helper, input)) 1891*da0073e9SAndroid Build Coastguard Worker out[0] = clone_and_perform_view(out[0]) 1892*da0073e9SAndroid Build Coastguard Worker return tuple(out) 1893*da0073e9SAndroid Build Coastguard Worker 1894*da0073e9SAndroid Build Coastguard Worker for sample in samples: 1895*da0073e9SAndroid Build Coastguard Worker tensor = ( 1896*da0073e9SAndroid Build Coastguard Worker sample.input 1897*da0073e9SAndroid Build Coastguard Worker if isinstance(sample.input, torch.Tensor) 1898*da0073e9SAndroid Build Coastguard Worker else sample.input[0] 1899*da0073e9SAndroid Build Coastguard Worker ) 1900*da0073e9SAndroid Build Coastguard Worker cloned1 = clone_and_perform_view(sample.input) 1901*da0073e9SAndroid Build Coastguard Worker 1902*da0073e9SAndroid Build Coastguard Worker # Computes function forward value with a physically conjugated/negated tensor and 1903*da0073e9SAndroid Build Coastguard Worker # a conj/neg view tensor and verifies that the output in both case are equal. 
1904*da0073e9SAndroid Build Coastguard Worker expected_forward = op(sample.input, *sample.args, **sample.kwargs) 1905*da0073e9SAndroid Build Coastguard Worker forward_with_mathview = op(cloned1, *sample.args, **sample.kwargs) 1906*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_forward, forward_with_mathview) 1907*da0073e9SAndroid Build Coastguard Worker 1908*da0073e9SAndroid Build Coastguard Worker # If the op has an inplace variant, and the input doesn't require broadcasting 1909*da0073e9SAndroid Build Coastguard Worker # and has the same dtype as output, verify that the inplace operation on a conjugated/negated 1910*da0073e9SAndroid Build Coastguard Worker # input produces correct output, and the output tensor has the conj/neg bit set to True 1911*da0073e9SAndroid Build Coastguard Worker if inplace_variant is not None and not sample.broadcasts_input: 1912*da0073e9SAndroid Build Coastguard Worker cloned2 = clone_and_perform_view(tensor, requires_grad=False) 1913*da0073e9SAndroid Build Coastguard Worker if ( 1914*da0073e9SAndroid Build Coastguard Worker isinstance(expected_forward, torch.Tensor) 1915*da0073e9SAndroid Build Coastguard Worker and expected_forward.dtype is tensor.dtype 1916*da0073e9SAndroid Build Coastguard Worker ): 1917*da0073e9SAndroid Build Coastguard Worker inplace_forward = inplace_variant( 1918*da0073e9SAndroid Build Coastguard Worker cloned2, *sample.args, **sample.kwargs 1919*da0073e9SAndroid Build Coastguard Worker ) 1920*da0073e9SAndroid Build Coastguard Worker self.assertTrue(is_bit_set(inplace_forward)) 1921*da0073e9SAndroid Build Coastguard Worker self.assertEqual(inplace_forward, expected_forward) 1922*da0073e9SAndroid Build Coastguard Worker 1923*da0073e9SAndroid Build Coastguard Worker # TODO: backward consistency only supported for single tensor outputs 1924*da0073e9SAndroid Build Coastguard Worker # TODO: backward consistency only checked on sample.input, not all 1925*da0073e9SAndroid Build Coastguard Worker # 
tensor inputs 1926*da0073e9SAndroid Build Coastguard Worker # TODO: update to handle checking grads of all tensor inputs as 1927*da0073e9SAndroid Build Coastguard Worker # derived from each tensor output 1928*da0073e9SAndroid Build Coastguard Worker if ( 1929*da0073e9SAndroid Build Coastguard Worker isinstance(expected_forward, torch.Tensor) 1930*da0073e9SAndroid Build Coastguard Worker and expected_forward.requires_grad 1931*da0073e9SAndroid Build Coastguard Worker ): 1932*da0073e9SAndroid Build Coastguard Worker output_process_fn_grad = sample.output_process_fn_grad or (lambda x: x) 1933*da0073e9SAndroid Build Coastguard Worker expected_forward = output_process_fn_grad(expected_forward) 1934*da0073e9SAndroid Build Coastguard Worker forward_with_mathview = output_process_fn_grad(forward_with_mathview) 1935*da0073e9SAndroid Build Coastguard Worker 1936*da0073e9SAndroid Build Coastguard Worker tensor = ( 1937*da0073e9SAndroid Build Coastguard Worker sample.input 1938*da0073e9SAndroid Build Coastguard Worker if isinstance(sample.input, torch.Tensor) 1939*da0073e9SAndroid Build Coastguard Worker else sample.input[0] 1940*da0073e9SAndroid Build Coastguard Worker ) 1941*da0073e9SAndroid Build Coastguard Worker expected_forward.sum().abs().backward(retain_graph=True) 1942*da0073e9SAndroid Build Coastguard Worker forward_with_mathview.sum().abs().backward(retain_graph=True) 1943*da0073e9SAndroid Build Coastguard Worker if tensor.grad is not None: 1944*da0073e9SAndroid Build Coastguard Worker cloned1_tensor = ( 1945*da0073e9SAndroid Build Coastguard Worker cloned1 if isinstance(cloned1, torch.Tensor) else cloned1[0] 1946*da0073e9SAndroid Build Coastguard Worker ) 1947*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.grad, cloned1_tensor.grad) 1948*da0073e9SAndroid Build Coastguard Worker 1949*da0073e9SAndroid Build Coastguard Worker tensor.grad, cloned1_tensor.grad = None, None 1950*da0073e9SAndroid Build Coastguard Worker 1951*da0073e9SAndroid Build 
Coastguard Worker # a repeat of the above test if output is not complex valued 1952*da0073e9SAndroid Build Coastguard Worker if out_type(expected_forward): 1953*da0073e9SAndroid Build Coastguard Worker grad = torch.randn_like(expected_forward) 1954*da0073e9SAndroid Build Coastguard Worker expected_forward.backward(grad) 1955*da0073e9SAndroid Build Coastguard Worker forward_with_mathview.backward( 1956*da0073e9SAndroid Build Coastguard Worker math_op_view(math_op_physical(grad)) 1957*da0073e9SAndroid Build Coastguard Worker ) 1958*da0073e9SAndroid Build Coastguard Worker 1959*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.grad, cloned1_tensor.grad) 1960*da0073e9SAndroid Build Coastguard Worker 1961*da0073e9SAndroid Build Coastguard Worker @ops(ops_and_refs, allowed_dtypes=(torch.cfloat,)) 1962*da0073e9SAndroid Build Coastguard Worker def test_conj_view(self, device, dtype, op): 1963*da0073e9SAndroid Build Coastguard Worker if not op.test_conjugated_samples: 1964*da0073e9SAndroid Build Coastguard Worker self.skipTest("Operation doesn't support conjugated inputs.") 1965*da0073e9SAndroid Build Coastguard Worker math_op_physical = torch.conj_physical 1966*da0073e9SAndroid Build Coastguard Worker math_op_view = torch.conj 1967*da0073e9SAndroid Build Coastguard Worker _requires_grad = torch.cfloat in op.supported_backward_dtypes( 1968*da0073e9SAndroid Build Coastguard Worker torch.device(device).type 1969*da0073e9SAndroid Build Coastguard Worker ) 1970*da0073e9SAndroid Build Coastguard Worker is_bit_set = torch.is_conj 1971*da0073e9SAndroid Build Coastguard Worker samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad) 1972*da0073e9SAndroid Build Coastguard Worker self._test_math_view( 1973*da0073e9SAndroid Build Coastguard Worker device, 1974*da0073e9SAndroid Build Coastguard Worker dtype, 1975*da0073e9SAndroid Build Coastguard Worker op, 1976*da0073e9SAndroid Build Coastguard Worker samples, 1977*da0073e9SAndroid Build Coastguard Worker 
math_op_physical, 1978*da0073e9SAndroid Build Coastguard Worker math_op_view, 1979*da0073e9SAndroid Build Coastguard Worker is_bit_set, 1980*da0073e9SAndroid Build Coastguard Worker torch.is_complex, 1981*da0073e9SAndroid Build Coastguard Worker ) 1982*da0073e9SAndroid Build Coastguard Worker 1983*da0073e9SAndroid Build Coastguard Worker @ops(ops_and_refs, allowed_dtypes=(torch.double,)) 1984*da0073e9SAndroid Build Coastguard Worker def test_neg_view(self, device, dtype, op): 1985*da0073e9SAndroid Build Coastguard Worker if not op.test_neg_view: 1986*da0073e9SAndroid Build Coastguard Worker self.skipTest("Operation not tested with tensors with negative bit.") 1987*da0073e9SAndroid Build Coastguard Worker math_op_physical = torch.neg 1988*da0073e9SAndroid Build Coastguard Worker math_op_view = torch._neg_view 1989*da0073e9SAndroid Build Coastguard Worker is_bit_set = torch.is_neg 1990*da0073e9SAndroid Build Coastguard Worker samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd) 1991*da0073e9SAndroid Build Coastguard Worker self._test_math_view( 1992*da0073e9SAndroid Build Coastguard Worker device, 1993*da0073e9SAndroid Build Coastguard Worker dtype, 1994*da0073e9SAndroid Build Coastguard Worker op, 1995*da0073e9SAndroid Build Coastguard Worker samples, 1996*da0073e9SAndroid Build Coastguard Worker math_op_physical, 1997*da0073e9SAndroid Build Coastguard Worker math_op_view, 1998*da0073e9SAndroid Build Coastguard Worker is_bit_set, 1999*da0073e9SAndroid Build Coastguard Worker lambda x: True, 2000*da0073e9SAndroid Build Coastguard Worker ) 2001*da0073e9SAndroid Build Coastguard Worker 2002*da0073e9SAndroid Build Coastguard Worker @ops(ops_and_refs, allowed_dtypes=(torch.cdouble,)) 2003*da0073e9SAndroid Build Coastguard Worker def test_neg_conj_view(self, device, dtype, op): 2004*da0073e9SAndroid Build Coastguard Worker if not op.test_neg_view: 2005*da0073e9SAndroid Build Coastguard Worker self.skipTest("Operation not tested with tensors with 
negative bit.") 2006*da0073e9SAndroid Build Coastguard Worker if not op.test_conjugated_samples: 2007*da0073e9SAndroid Build Coastguard Worker self.skipTest("Operation doesn't support conjugated inputs.") 2008*da0073e9SAndroid Build Coastguard Worker 2009*da0073e9SAndroid Build Coastguard Worker def math_op_physical(x): 2010*da0073e9SAndroid Build Coastguard Worker return -x.conj_physical() 2011*da0073e9SAndroid Build Coastguard Worker 2012*da0073e9SAndroid Build Coastguard Worker def math_op_view(x): 2013*da0073e9SAndroid Build Coastguard Worker return torch._neg_view(x).conj() 2014*da0073e9SAndroid Build Coastguard Worker 2015*da0073e9SAndroid Build Coastguard Worker def is_bit_set(x): 2016*da0073e9SAndroid Build Coastguard Worker return torch.is_neg(x) and torch.is_conj(x) 2017*da0073e9SAndroid Build Coastguard Worker 2018*da0073e9SAndroid Build Coastguard Worker _requires_grad = dtype in op.supported_backward_dtypes( 2019*da0073e9SAndroid Build Coastguard Worker torch.device(device).type 2020*da0073e9SAndroid Build Coastguard Worker ) 2021*da0073e9SAndroid Build Coastguard Worker samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad) 2022*da0073e9SAndroid Build Coastguard Worker # Only test one sample 2023*da0073e9SAndroid Build Coastguard Worker samples = itertools.islice(samples, 1) 2024*da0073e9SAndroid Build Coastguard Worker self._test_math_view( 2025*da0073e9SAndroid Build Coastguard Worker device, 2026*da0073e9SAndroid Build Coastguard Worker dtype, 2027*da0073e9SAndroid Build Coastguard Worker op, 2028*da0073e9SAndroid Build Coastguard Worker samples, 2029*da0073e9SAndroid Build Coastguard Worker math_op_physical, 2030*da0073e9SAndroid Build Coastguard Worker math_op_view, 2031*da0073e9SAndroid Build Coastguard Worker is_bit_set, 2032*da0073e9SAndroid Build Coastguard Worker torch.is_complex, 2033*da0073e9SAndroid Build Coastguard Worker ) 2034*da0073e9SAndroid Build Coastguard Worker 2035*da0073e9SAndroid Build Coastguard Worker 
def check_inplace_view(func, input, rs, input_size, input_strides):
    """Assert that ``func`` is tagged ``inplace_view`` if it changed its input's metadata.

    The input's size and strides may have been altered as a result of an in-place op.
    Only the case where the op returned its (first) input is checked.

    Args:
        func: the op that was called (an OpOverload or OpOverloadPacket), or None to
            skip the check entirely.
        input: the op's first argument.
        rs: the op's return value.
        input_size: ``input.size()`` captured before the op ran.
        input_strides: ``input.stride()`` captured before the op ran.

    Raises:
        AssertionError: if ``rs is input``, its size or strides changed, ``func`` is not
            ``aten.resize_.default``, and ``func`` lacks ``torch.Tag.inplace_view``.
    """
    if func is None:
        return
    # TODO: extend this test to test ops with multiple outputs and ops like native_batch_norm(_legit).out
    # which mutate not necessarily the first input.
    if not (isinstance(rs, torch.Tensor) and rs is input):
        return
    unequal_size = rs.size() != input_size
    unequal_strides = rs.stride() != input_strides
    if not (unequal_size or unequal_strides):
        return
    if isinstance(func, torch._ops.OpOverloadPacket):
        func = func.default
    # resize_ should probably have inplace_view tag. Not adding the tag since it
    # breaks some codegen logic
    # Reference: https://github.com/pytorch/pytorch/issues/78759
    if func is torch.ops.aten.resize_.default:
        return
    # TODO: use self.assertIn when we have separate tests for each tag
    # Raise explicitly (instead of a bare `assert`) so the check survives `python -O`
    # and the failure names the offending op and the metadata change.
    if torch.Tag.inplace_view not in func.tags:
        raise AssertionError(
            f"{func} changed the size/strides of its input in place "
            f"(size {tuple(input_size)} -> {tuple(rs.size())}, "
            f"strides {tuple(input_strides)} -> {tuple(rs.stride())}) "
            "but is not tagged torch.Tag.inplace_view"
        )


# A mode that when enabled runs correctness checks to ensure
# that operators have expected tags based
on their input and 2058*da0073e9SAndroid Build Coastguard Worker# output tensor properties 2059*da0073e9SAndroid Build Coastguard Workerclass TestTagsMode(TorchDispatchMode): 2060*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(self, func, types, args=(), kwargs=None): 2061*da0073e9SAndroid Build Coastguard Worker if isinstance(args[0], torch.Tensor): 2062*da0073e9SAndroid Build Coastguard Worker old_size = args[0].size() 2063*da0073e9SAndroid Build Coastguard Worker old_stride = args[0].stride() 2064*da0073e9SAndroid Build Coastguard Worker rs = func(*args, **kwargs) 2065*da0073e9SAndroid Build Coastguard Worker check_inplace_view(func, args[0], rs, old_size, old_stride) 2066*da0073e9SAndroid Build Coastguard Worker else: 2067*da0073e9SAndroid Build Coastguard Worker rs = func(*args, **kwargs) 2068*da0073e9SAndroid Build Coastguard Worker return rs 2069*da0073e9SAndroid Build Coastguard Worker 2070*da0073e9SAndroid Build Coastguard Worker 2071*da0073e9SAndroid Build Coastguard Worker# Test to verify the correctness for tags in `tags.yaml`, also available for access through `torch.Tags` 2072*da0073e9SAndroid Build Coastguard Worker@unMarkDynamoStrictTest 2073*da0073e9SAndroid Build Coastguard Workerclass TestTags(TestCase): 2074*da0073e9SAndroid Build Coastguard Worker @onlyCPU 2075*da0073e9SAndroid Build Coastguard Worker @ops(ops_and_refs, dtypes=OpDTypes.any_one) 2076*da0073e9SAndroid Build Coastguard Worker def test_tags(self, device, dtype, op): 2077*da0073e9SAndroid Build Coastguard Worker samples = op.sample_inputs(device, dtype, requires_grad=False) 2078*da0073e9SAndroid Build Coastguard Worker for sample in samples: 2079*da0073e9SAndroid Build Coastguard Worker # TODO: Test tags for ops that return a list of tensors 2080*da0073e9SAndroid Build Coastguard Worker input = sample.input 2081*da0073e9SAndroid Build Coastguard Worker if isinstance(input, torch.Tensor): 2082*da0073e9SAndroid Build Coastguard Worker old_size = input.size() 
2083*da0073e9SAndroid Build Coastguard Worker old_stride = input.stride() 2084*da0073e9SAndroid Build Coastguard Worker with TestTagsMode(): 2085*da0073e9SAndroid Build Coastguard Worker rs = op(input, *sample.args, **sample.kwargs) 2086*da0073e9SAndroid Build Coastguard Worker # TODO: add test for aliases: https://github.com/pytorch/pytorch/issues/78761 2087*da0073e9SAndroid Build Coastguard Worker aten_name = op.aten_name if op.aten_name is not None else op.name 2088*da0073e9SAndroid Build Coastguard Worker opoverloadpacket = getattr(torch.ops.aten, aten_name, None) 2089*da0073e9SAndroid Build Coastguard Worker check_inplace_view(opoverloadpacket, input, rs, old_size, old_stride) 2090*da0073e9SAndroid Build Coastguard Worker 2091*da0073e9SAndroid Build Coastguard Worker 2092*da0073e9SAndroid Build Coastguard Workerclass TestSelfKwarg(TestCase): 2093*da0073e9SAndroid Build Coastguard Worker def test_self_kwargs(self): 2094*da0073e9SAndroid Build Coastguard Worker """Verify that we can call the aten ops with all kwargs even if the 2095*da0073e9SAndroid Build Coastguard Worker argument's name is "self" 2096*da0073e9SAndroid Build Coastguard Worker """ 2097*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.reshape.default(self=torch.rand(1, 2), shape=[2]) 2098*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.min.default(self=torch.rand(100)) 2099*da0073e9SAndroid Build Coastguard Worker 2100*da0073e9SAndroid Build Coastguard Worker 2101*da0073e9SAndroid Build Coastguard Worker@unMarkDynamoStrictTest 2102*da0073e9SAndroid Build Coastguard Workerclass TestRefsOpsInfo(TestCase): 2103*da0073e9SAndroid Build Coastguard Worker import_paths = [ 2104*da0073e9SAndroid Build Coastguard Worker "_refs", 2105*da0073e9SAndroid Build Coastguard Worker "_refs.special", 2106*da0073e9SAndroid Build Coastguard Worker "_refs.nn.functional", 2107*da0073e9SAndroid Build Coastguard Worker "_refs.fft", 2108*da0073e9SAndroid Build Coastguard Worker "_refs._conversions", 
2109*da0073e9SAndroid Build Coastguard Worker ] 2110*da0073e9SAndroid Build Coastguard Worker module_alls = [ 2111*da0073e9SAndroid Build Coastguard Worker (path, import_module(f"torch.{path}").__all__) for path in import_paths 2112*da0073e9SAndroid Build Coastguard Worker ] 2113*da0073e9SAndroid Build Coastguard Worker ref_ops_names = tuple( 2114*da0073e9SAndroid Build Coastguard Worker itertools.chain.from_iterable( 2115*da0073e9SAndroid Build Coastguard Worker [f"{path}.{op}" for op in module_all] for path, module_all in module_alls 2116*da0073e9SAndroid Build Coastguard Worker ) 2117*da0073e9SAndroid Build Coastguard Worker ) 2118*da0073e9SAndroid Build Coastguard Worker ref_db_names = {ref_op.name for ref_op in python_ref_db} 2119*da0073e9SAndroid Build Coastguard Worker 2120*da0073e9SAndroid Build Coastguard Worker # TODO: References that do not have an entry in python_ref_db 2121*da0073e9SAndroid Build Coastguard Worker skip_ref_ops = { 2122*da0073e9SAndroid Build Coastguard Worker "_refs.alias", 2123*da0073e9SAndroid Build Coastguard Worker "_refs.bitwise_right_shift", 2124*da0073e9SAndroid Build Coastguard Worker "_refs.copy_to", 2125*da0073e9SAndroid Build Coastguard Worker "_refs.empty_permuted", 2126*da0073e9SAndroid Build Coastguard Worker "_refs.empty_strided", 2127*da0073e9SAndroid Build Coastguard Worker "_refs.equal", 2128*da0073e9SAndroid Build Coastguard Worker "_refs.full", 2129*da0073e9SAndroid Build Coastguard Worker "_refs.full_like", 2130*da0073e9SAndroid Build Coastguard Worker "_refs.is_complex", 2131*da0073e9SAndroid Build Coastguard Worker "_refs.to", 2132*da0073e9SAndroid Build Coastguard Worker "_refs.mvlgamma", 2133*da0073e9SAndroid Build Coastguard Worker "_refs.ones", 2134*da0073e9SAndroid Build Coastguard Worker "_refs.ones_like", 2135*da0073e9SAndroid Build Coastguard Worker "_refs.special.expit", 2136*da0073e9SAndroid Build Coastguard Worker "_refs.std_var", 2137*da0073e9SAndroid Build Coastguard Worker "_refs.swap_axes", 
2138*da0073e9SAndroid Build Coastguard Worker "_refs.uniform", 2139*da0073e9SAndroid Build Coastguard Worker "_refs.scalar_tensor", 2140*da0073e9SAndroid Build Coastguard Worker "_refs.trunc_divide", 2141*da0073e9SAndroid Build Coastguard Worker "_refs.zero", 2142*da0073e9SAndroid Build Coastguard Worker "_refs.zeros", 2143*da0073e9SAndroid Build Coastguard Worker "_refs.zeros_like", 2144*da0073e9SAndroid Build Coastguard Worker "_refs.rfloordiv", 2145*da0073e9SAndroid Build Coastguard Worker "_refs.rtruediv", 2146*da0073e9SAndroid Build Coastguard Worker "_refs.rpow", 2147*da0073e9SAndroid Build Coastguard Worker # These should be tested with their out-of-place counterparts 2148*da0073e9SAndroid Build Coastguard Worker "_refs.index_add_", 2149*da0073e9SAndroid Build Coastguard Worker "_refs.index_copy_", 2150*da0073e9SAndroid Build Coastguard Worker "_refs.index_fill_", 2151*da0073e9SAndroid Build Coastguard Worker "_refs.native_group_norm", 2152*da0073e9SAndroid Build Coastguard Worker } 2153*da0073e9SAndroid Build Coastguard Worker 2154*da0073e9SAndroid Build Coastguard Worker not_in_decomp_table = { 2155*da0073e9SAndroid Build Coastguard Worker # duplicated in _decomp and _refs 2156*da0073e9SAndroid Build Coastguard Worker "_refs.nn.functional.group_norm", 2157*da0073e9SAndroid Build Coastguard Worker "_refs.nn.functional.mse_loss", 2158*da0073e9SAndroid Build Coastguard Worker "_refs.floor_divide", 2159*da0073e9SAndroid Build Coastguard Worker # duplicated as refs do not have decent support for advanced indexing 2160*da0073e9SAndroid Build Coastguard Worker "_refs.index_copy", 2161*da0073e9SAndroid Build Coastguard Worker "_refs.index_copy_", 2162*da0073e9SAndroid Build Coastguard Worker "_refs.index_add", 2163*da0073e9SAndroid Build Coastguard Worker "_refs.index_add_", 2164*da0073e9SAndroid Build Coastguard Worker # these are not aten ops? 
2165*da0073e9SAndroid Build Coastguard Worker "_refs._conversions.bfloat16", 2166*da0073e9SAndroid Build Coastguard Worker "_refs._conversions.bool", 2167*da0073e9SAndroid Build Coastguard Worker "_refs._conversions.byte", 2168*da0073e9SAndroid Build Coastguard Worker "_refs._conversions.char", 2169*da0073e9SAndroid Build Coastguard Worker "_refs._conversions.double", 2170*da0073e9SAndroid Build Coastguard Worker "_refs._conversions.float", 2171*da0073e9SAndroid Build Coastguard Worker "_refs._conversions.half", 2172*da0073e9SAndroid Build Coastguard Worker "_refs._conversions.int", 2173*da0073e9SAndroid Build Coastguard Worker "_refs._conversions.long", 2174*da0073e9SAndroid Build Coastguard Worker "_refs._conversions.short", 2175*da0073e9SAndroid Build Coastguard Worker "_refs._conversions.chalf", 2176*da0073e9SAndroid Build Coastguard Worker "_refs._conversions.cfloat", 2177*da0073e9SAndroid Build Coastguard Worker "_refs._conversions.cdouble", 2178*da0073e9SAndroid Build Coastguard Worker "_refs.broadcast_shapes", 2179*da0073e9SAndroid Build Coastguard Worker "_refs.broadcast_tensors", 2180*da0073e9SAndroid Build Coastguard Worker "_refs.mvlgamma", 2181*da0073e9SAndroid Build Coastguard Worker "_refs.nn.functional.layer_norm", 2182*da0073e9SAndroid Build Coastguard Worker "_refs.nn.functional.tanhshrink", 2183*da0073e9SAndroid Build Coastguard Worker "_refs.nn.functional.triplet_margin_loss", 2184*da0073e9SAndroid Build Coastguard Worker "_refs.rfloordiv", 2185*da0073e9SAndroid Build Coastguard Worker "_refs.rtruediv", 2186*da0073e9SAndroid Build Coastguard Worker "_refs.rpow", 2187*da0073e9SAndroid Build Coastguard Worker # CompositeImplicitAutograd 2188*da0073e9SAndroid Build Coastguard Worker "_refs.allclose", 2189*da0073e9SAndroid Build Coastguard Worker "_refs.atleast_1d", 2190*da0073e9SAndroid Build Coastguard Worker "_refs.atleast_2d", 2191*da0073e9SAndroid Build Coastguard Worker "_refs.atleast_3d", 2192*da0073e9SAndroid Build Coastguard Worker 
"_refs.broadcast_to", 2193*da0073e9SAndroid Build Coastguard Worker "_refs.chunk", 2194*da0073e9SAndroid Build Coastguard Worker "_refs.column_stack", 2195*da0073e9SAndroid Build Coastguard Worker "_refs.contiguous", 2196*da0073e9SAndroid Build Coastguard Worker "_refs.dsplit", 2197*da0073e9SAndroid Build Coastguard Worker "_refs.dstack", 2198*da0073e9SAndroid Build Coastguard Worker "_refs.fill", 2199*da0073e9SAndroid Build Coastguard Worker "_refs.fill_", 2200*da0073e9SAndroid Build Coastguard Worker "_refs.flatten", 2201*da0073e9SAndroid Build Coastguard Worker "_refs.fliplr", 2202*da0073e9SAndroid Build Coastguard Worker "_refs.flipud", 2203*da0073e9SAndroid Build Coastguard Worker "_refs.float_power", 2204*da0073e9SAndroid Build Coastguard Worker "_refs.hsplit", 2205*da0073e9SAndroid Build Coastguard Worker "_refs.hstack", 2206*da0073e9SAndroid Build Coastguard Worker "_refs.isclose", 2207*da0073e9SAndroid Build Coastguard Worker "_refs.isfinite", 2208*da0073e9SAndroid Build Coastguard Worker "_refs.isreal", 2209*da0073e9SAndroid Build Coastguard Worker "_refs.istft", 2210*da0073e9SAndroid Build Coastguard Worker "_refs.log_softmax", 2211*da0073e9SAndroid Build Coastguard Worker "_refs.movedim", 2212*da0073e9SAndroid Build Coastguard Worker "_refs.narrow", 2213*da0073e9SAndroid Build Coastguard Worker "_refs.nn.functional.dropout", 2214*da0073e9SAndroid Build Coastguard Worker "_refs.nn.functional.l1_loss", 2215*da0073e9SAndroid Build Coastguard Worker "_refs.nn.functional.smooth_l1_loss", 2216*da0073e9SAndroid Build Coastguard Worker "_refs.nn.functional.log_softmax", 2217*da0073e9SAndroid Build Coastguard Worker "_refs.nn.functional.poisson_nll_loss", 2218*da0073e9SAndroid Build Coastguard Worker "_refs.nn.functional.softmax", 2219*da0073e9SAndroid Build Coastguard Worker "_refs.nn.functional.softmin", 2220*da0073e9SAndroid Build Coastguard Worker "_refs.positive", 2221*da0073e9SAndroid Build Coastguard Worker "_refs.ravel", 2222*da0073e9SAndroid Build 
Coastguard Worker "_refs.reshape", 2223*da0073e9SAndroid Build Coastguard Worker "_refs.softmax", 2224*da0073e9SAndroid Build Coastguard Worker "_refs.special.expit", 2225*da0073e9SAndroid Build Coastguard Worker "_refs.special.log_softmax", 2226*da0073e9SAndroid Build Coastguard Worker "_refs.special.softmax", 2227*da0073e9SAndroid Build Coastguard Worker "_refs.square", 2228*da0073e9SAndroid Build Coastguard Worker "_refs.stft", 2229*da0073e9SAndroid Build Coastguard Worker "_refs.T", 2230*da0073e9SAndroid Build Coastguard Worker "_refs.take_along_dim", 2231*da0073e9SAndroid Build Coastguard Worker "_refs.tensor_split", 2232*da0073e9SAndroid Build Coastguard Worker "_refs.to", 2233*da0073e9SAndroid Build Coastguard Worker "_refs.true_divide", 2234*da0073e9SAndroid Build Coastguard Worker "_refs.trunc_divide", 2235*da0073e9SAndroid Build Coastguard Worker "_refs.vsplit", 2236*da0073e9SAndroid Build Coastguard Worker "_refs.vstack", 2237*da0073e9SAndroid Build Coastguard Worker "_refs.linalg.matrix_norm", 2238*da0073e9SAndroid Build Coastguard Worker "_refs.linalg.norm", 2239*da0073e9SAndroid Build Coastguard Worker "_refs.linalg.svd", 2240*da0073e9SAndroid Build Coastguard Worker "_refs.linalg.svdvals", 2241*da0073e9SAndroid Build Coastguard Worker "_refs.unflatten", 2242*da0073e9SAndroid Build Coastguard Worker "_refs.sum_to_size", 2243*da0073e9SAndroid Build Coastguard Worker # ref implementation missing kwargs 2244*da0073e9SAndroid Build Coastguard Worker "_refs.full_like", # missing "layout" 2245*da0073e9SAndroid Build Coastguard Worker "_refs.scalar_tensor", # missing "layout" 2246*da0073e9SAndroid Build Coastguard Worker # other 2247*da0073e9SAndroid Build Coastguard Worker "_refs.block_diag", # only refs._block_diag_iterable is in decomposition table 2248*da0073e9SAndroid Build Coastguard Worker "_refs.empty", # intentional; direct empty is faster and has less guards 2249*da0073e9SAndroid Build Coastguard Worker "_refs.empty_permuted", # intentional; direct 
empty is faster and has less guards 2250*da0073e9SAndroid Build Coastguard Worker "_refs.expand_as", 2251*da0073e9SAndroid Build Coastguard Worker "_refs.as_strided", # _prims._as_strided_meta: "reduce() of empty sequence with no initial value" 2252*da0073e9SAndroid Build Coastguard Worker "_refs.copy_to", # torch._C._jit_get_operation: No such operator aten::copy_to 2253*da0073e9SAndroid Build Coastguard Worker "_refs.equal", # 'bool' object has no attribute 'dtype' 2254*da0073e9SAndroid Build Coastguard Worker "_refs.conj", # Calls _prims.conj 2255*da0073e9SAndroid Build Coastguard Worker "_refs.real", 2256*da0073e9SAndroid Build Coastguard Worker "_refs.imag", 2257*da0073e9SAndroid Build Coastguard Worker "_refs.reshape_as", 2258*da0073e9SAndroid Build Coastguard Worker "_refs.view_as", 2259*da0073e9SAndroid Build Coastguard Worker "_refs.view_as_complex", # TorchInductor does not support complex at the moment. 2260*da0073e9SAndroid Build Coastguard Worker # the decompositions for these ops are slightly different 2261*da0073e9SAndroid Build Coastguard Worker # because of out handling 2262*da0073e9SAndroid Build Coastguard Worker "_refs.var_mean", 2263*da0073e9SAndroid Build Coastguard Worker "_refs.std_mean", 2264*da0073e9SAndroid Build Coastguard Worker "_refs.native_layer_norm", 2265*da0073e9SAndroid Build Coastguard Worker } 2266*da0073e9SAndroid Build Coastguard Worker 2267*da0073e9SAndroid Build Coastguard Worker @parametrize("op", ref_ops_names) 2268*da0073e9SAndroid Build Coastguard Worker def test_refs_are_in_python_ref_db(self, op): 2269*da0073e9SAndroid Build Coastguard Worker inplace = op[-1] == "_" 2270*da0073e9SAndroid Build Coastguard Worker if op in self.skip_ref_ops: 2271*da0073e9SAndroid Build Coastguard Worker raise unittest.SkipTest(f"{op} does not have an entry in python_ref_db") 2272*da0073e9SAndroid Build Coastguard Worker elif inplace: 2273*da0073e9SAndroid Build Coastguard Worker self.assertNotIn( 2274*da0073e9SAndroid Build Coastguard 
Worker op, 2275*da0073e9SAndroid Build Coastguard Worker self.ref_db_names, 2276*da0073e9SAndroid Build Coastguard Worker msg=f"{op} is an in-place operation and should not have an OpInfo", 2277*da0073e9SAndroid Build Coastguard Worker ) 2278*da0073e9SAndroid Build Coastguard Worker else: 2279*da0073e9SAndroid Build Coastguard Worker # Intentionally don't use assertIn to avoid printing the 2280*da0073e9SAndroid Build Coastguard Worker # (very large) container 2281*da0073e9SAndroid Build Coastguard Worker self.assertTrue(op in self.ref_db_names, msg=f"{op} not in ref_db_names") 2282*da0073e9SAndroid Build Coastguard Worker 2283*da0073e9SAndroid Build Coastguard Worker @parametrize("op", ref_ops_names) 2284*da0073e9SAndroid Build Coastguard Worker def test_refs_are_in_decomp_table(self, op): 2285*da0073e9SAndroid Build Coastguard Worker path = op.split(".") 2286*da0073e9SAndroid Build Coastguard Worker module_path = ".".join(path[:-1]) 2287*da0073e9SAndroid Build Coastguard Worker op_name = path[-1] 2288*da0073e9SAndroid Build Coastguard Worker op_impl = getattr(import_module(f"torch.{module_path}"), op_name) 2289*da0073e9SAndroid Build Coastguard Worker 2290*da0073e9SAndroid Build Coastguard Worker if op in self.not_in_decomp_table: 2291*da0073e9SAndroid Build Coastguard Worker self.assertNotIn( 2292*da0073e9SAndroid Build Coastguard Worker op_impl, 2293*da0073e9SAndroid Build Coastguard Worker torch._decomp.decomposition_table.values(), 2294*da0073e9SAndroid Build Coastguard Worker f"Unexpectedly found {op} in torch._decomp.decomposition_table.values()", 2295*da0073e9SAndroid Build Coastguard Worker ) 2296*da0073e9SAndroid Build Coastguard Worker else: 2297*da0073e9SAndroid Build Coastguard Worker self.assertIn( 2298*da0073e9SAndroid Build Coastguard Worker op_impl, 2299*da0073e9SAndroid Build Coastguard Worker torch._decomp.decomposition_table.values(), 2300*da0073e9SAndroid Build Coastguard Worker f"Did not find {op} in 
torch._decomp.decomposition_table.values()", 2301*da0073e9SAndroid Build Coastguard Worker ) 2302*da0073e9SAndroid Build Coastguard Worker 2303*da0073e9SAndroid Build Coastguard Worker 2304*da0073e9SAndroid Build Coastguard Workerfake_skips = ( 2305*da0073e9SAndroid Build Coastguard Worker "aminmax", # failing input 2306*da0073e9SAndroid Build Coastguard Worker "cov", # aweights cannot be negtaive 2307*da0073e9SAndroid Build Coastguard Worker "istft", # window overlap add min: 0 2308*da0073e9SAndroid Build Coastguard Worker "linalg.eigvals", # The tensor has a non-zero number of elements, but its data is not allocated yet 2309*da0073e9SAndroid Build Coastguard Worker "linalg.eigvalsh", # aten::linalg_eigvalsh.out' with arguments from the 'Meta' backend 2310*da0073e9SAndroid Build Coastguard Worker "linalg.matrix_power", # Could not run 'aten::eye.m_out' with arguments from the 'Meta' backend 2311*da0073e9SAndroid Build Coastguard Worker # "linalg.pinv", # Could not run 'aten::pinv.out' with arguments from the 'Meta' backen 2312*da0073e9SAndroid Build Coastguard Worker "linalg.matrix_rank.hermitian", # Could not run 'aten::linalg_eigvalsh.out' with arguments from the 'Meta' backend 2313*da0073e9SAndroid Build Coastguard Worker "linalg.pinv.hermitian", # tensor.mH is only supported on matrices or batches of matrices. 
Got 1-D tensor 2314*da0073e9SAndroid Build Coastguard Worker "linalg.solve", # Could not run 'aten::linalg_solve' with arguments from the 'Meta' backend 2315*da0073e9SAndroid Build Coastguard Worker "linalg.tensorsolve", # Could not run 'aten::linalg_solve' with arguments from the 'Meta' 2316*da0073e9SAndroid Build Coastguard Worker "lu_solve", # MALLOC ERROR: debug 2317*da0073e9SAndroid Build Coastguard Worker "multinomial", # Could not run 'aten::multinomial' with arguments from the 'Meta' backend 2318*da0073e9SAndroid Build Coastguard Worker "mvlgamma.mvlgamma_p_1", # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend 2319*da0073e9SAndroid Build Coastguard Worker "mvlgamma.mvlgamma_p_3", # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend 2320*da0073e9SAndroid Build Coastguard Worker "mvlgamma.mvlgamma_p_5", # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend 2321*da0073e9SAndroid Build Coastguard Worker "nanmean", # logical_not() got an unexpected keyword argument 'out' 2322*da0073e9SAndroid Build Coastguard Worker "quantile", # quantile() q values must be in the range [0, 1] 2323*da0073e9SAndroid Build Coastguard Worker "nanquantile", # quantile() q values must be in the range [0, 1] 2324*da0073e9SAndroid Build Coastguard Worker "nn.functional.ctc_loss", # The tensor has a non-zero number of elements, but its data is not allocated yet 2325*da0073e9SAndroid Build Coastguard Worker "nn.functional.embedding_bag", # sometimes errors 2326*da0073e9SAndroid Build Coastguard Worker "nn.functional.nll_loss", # sometimes errors 2327*da0073e9SAndroid Build Coastguard Worker "nn.functional.max_pool1d", # The tensor has a non-zero number of elements 2328*da0073e9SAndroid Build Coastguard Worker "to_sparse", # Could not run 'aten::_to_sparse' with arguments from the 'Meta' backend 2329*da0073e9SAndroid Build Coastguard Worker "tensor_split", # The tensor has a non-zero number of 
elements, but its data is not allocated yet 2330*da0073e9SAndroid Build Coastguard Worker "repeat_interleave", # cannot repeat_interleave a meta tensor without output_size 2331*da0073e9SAndroid Build Coastguard Worker "sparse.sampled.addmm", # sparsity not supported 2332*da0073e9SAndroid Build Coastguard Worker # Can not infer total number of classes from meta. no way at present to throw DynamicOutputShapeException 2333*da0073e9SAndroid Build Coastguard Worker "nn.functional.one_hot", 2334*da0073e9SAndroid Build Coastguard Worker "narrow", # Fails only for one overload with DataDependentOutputException (hence skip). 2335*da0073e9SAndroid Build Coastguard Worker) 2336*da0073e9SAndroid Build Coastguard Worker 2337*da0073e9SAndroid Build Coastguard Workerfake_autocast_device_skips = defaultdict(dict) 2338*da0073e9SAndroid Build Coastguard Worker 2339*da0073e9SAndroid Build Coastguard Worker# TODO: investigate/fix 2340*da0073e9SAndroid Build Coastguard Workerfake_autocast_device_skips["cpu"] = {"linalg.pinv"} 2341*da0073e9SAndroid Build Coastguard Workerfake_autocast_device_skips["cuda"] = {"linalg.pinv", "pinverse"} 2342*da0073e9SAndroid Build Coastguard Worker 2343*da0073e9SAndroid Build Coastguard Worker 2344*da0073e9SAndroid Build Coastguard Workerdynamic_output_op_tests = ( 2345*da0073e9SAndroid Build Coastguard Worker "argwhere", 2346*da0073e9SAndroid Build Coastguard Worker "bincount", 2347*da0073e9SAndroid Build Coastguard Worker "combinations", 2348*da0073e9SAndroid Build Coastguard Worker "linalg.lstsq", 2349*da0073e9SAndroid Build Coastguard Worker "masked_select", 2350*da0073e9SAndroid Build Coastguard Worker "nonzero", 2351*da0073e9SAndroid Build Coastguard Worker "unique_consecutive", 2352*da0073e9SAndroid Build Coastguard Worker "unique", 2353*da0073e9SAndroid Build Coastguard Worker "linalg.lstsq.grad_oriented", 2354*da0073e9SAndroid Build Coastguard Worker) 2355*da0073e9SAndroid Build Coastguard Worker 2356*da0073e9SAndroid Build Coastguard Worker# Ops 
that have dynamic output shapes that we can handle when 2357*da0073e9SAndroid Build Coastguard Worker# allow_dynamic_shape_ops is True in fake tensor shape environment. 2358*da0073e9SAndroid Build Coastguard Workersupported_dynamic_output_op_tests = ( 2359*da0073e9SAndroid Build Coastguard Worker "nonzero", 2360*da0073e9SAndroid Build Coastguard Worker "unique", 2361*da0073e9SAndroid Build Coastguard Worker "repeat_interleave", 2362*da0073e9SAndroid Build Coastguard Worker "masked_select", 2363*da0073e9SAndroid Build Coastguard Worker) 2364*da0073e9SAndroid Build Coastguard Worker 2365*da0073e9SAndroid Build Coastguard Worker# some inputs invoke dynamic output shape operators, some do not 2366*da0073e9SAndroid Build Coastguard Workersometimes_dynamic_output_op_test = ("__getitem__", "index_select") 2367*da0073e9SAndroid Build Coastguard Worker 2368*da0073e9SAndroid Build Coastguard Workerdata_dependent_op_tests = ( 2369*da0073e9SAndroid Build Coastguard Worker "equal", 2370*da0073e9SAndroid Build Coastguard Worker "corrcoef", 2371*da0073e9SAndroid Build Coastguard Worker "nn.functional.gaussian_nll_loss", 2372*da0073e9SAndroid Build Coastguard Worker "allclose", 2373*da0073e9SAndroid Build Coastguard Worker) 2374*da0073e9SAndroid Build Coastguard Worker 2375*da0073e9SAndroid Build Coastguard Workeraliasing_failures = ("histogramdd",) 2376*da0073e9SAndroid Build Coastguard Worker 2377*da0073e9SAndroid Build Coastguard Workerfake_backward_skips = { 2378*da0073e9SAndroid Build Coastguard Worker "linalg.cond", 2379*da0073e9SAndroid Build Coastguard Worker "linalg.matrix_norm", 2380*da0073e9SAndroid Build Coastguard Worker "linalg.norm", 2381*da0073e9SAndroid Build Coastguard Worker "linalg.svd", 2382*da0073e9SAndroid Build Coastguard Worker "linalg.svdvals", 2383*da0073e9SAndroid Build Coastguard Worker "pca_lowrank", 2384*da0073e9SAndroid Build Coastguard Worker "roll", 2385*da0073e9SAndroid Build Coastguard Worker "svd_lowrank", 2386*da0073e9SAndroid Build Coastguard 
Worker "sgn", 2387*da0073e9SAndroid Build Coastguard Worker} 2388*da0073e9SAndroid Build Coastguard Worker 2389*da0073e9SAndroid Build Coastguard Workerfake_backward_xfails = {skip(s) for s in fake_backward_skips} | { 2390*da0073e9SAndroid Build Coastguard Worker xfail("fft.ihfftn"), # Mismatch in aten._conj_physical.default 2391*da0073e9SAndroid Build Coastguard Worker xfail("fft.ihfft2"), # Mismatch in aten._conj_physical.default 2392*da0073e9SAndroid Build Coastguard Worker skip("nn.functional.ctc_loss"), 2393*da0073e9SAndroid Build Coastguard Worker} 2394*da0073e9SAndroid Build Coastguard Worker 2395*da0073e9SAndroid Build Coastguard Workerfake_autocast_backward_xfails = { 2396*da0073e9SAndroid Build Coastguard Worker skip("nn.functional.binary_cross_entropy"), 2397*da0073e9SAndroid Build Coastguard Worker skip("sparse.sampled_addmm"), 2398*da0073e9SAndroid Build Coastguard Worker skip("linalg.pinv"), 2399*da0073e9SAndroid Build Coastguard Worker skip("linalg.pinv", "hermitian"), 2400*da0073e9SAndroid Build Coastguard Worker skip("linalg.pinv", "singular"), 2401*da0073e9SAndroid Build Coastguard Worker skip("pinverse"), 2402*da0073e9SAndroid Build Coastguard Worker} 2403*da0073e9SAndroid Build Coastguard Worker 2404*da0073e9SAndroid Build Coastguard Worker 2405*da0073e9SAndroid Build Coastguard Worker@unMarkDynamoStrictTest 2406*da0073e9SAndroid Build Coastguard Workerclass TestFakeTensor(TestCase): 2407*da0073e9SAndroid Build Coastguard Worker def setUp(self): 2408*da0073e9SAndroid Build Coastguard Worker # Turn on FakeTensor caching and cross-checking for these tests: 2409*da0073e9SAndroid Build Coastguard Worker cache_enabled = unittest.mock.patch( 2410*da0073e9SAndroid Build Coastguard Worker "torch._dynamo.config.fake_tensor_cache_enabled", True 2411*da0073e9SAndroid Build Coastguard Worker ) 2412*da0073e9SAndroid Build Coastguard Worker cache_enabled.start() 2413*da0073e9SAndroid Build Coastguard Worker self.addCleanup(cache_enabled.stop) 
2414*da0073e9SAndroid Build Coastguard Worker 2415*da0073e9SAndroid Build Coastguard Worker cache_crosscheck = unittest.mock.patch( 2416*da0073e9SAndroid Build Coastguard Worker "torch._dynamo.config.fake_tensor_cache_crosscheck_enabled", True 2417*da0073e9SAndroid Build Coastguard Worker ) 2418*da0073e9SAndroid Build Coastguard Worker cache_crosscheck.start() 2419*da0073e9SAndroid Build Coastguard Worker self.addCleanup(cache_crosscheck.stop) 2420*da0073e9SAndroid Build Coastguard Worker 2421*da0073e9SAndroid Build Coastguard Worker def _test_fake_helper(self, device, dtype, op, context): 2422*da0073e9SAndroid Build Coastguard Worker name = op.name 2423*da0073e9SAndroid Build Coastguard Worker if op.variant_test_name: 2424*da0073e9SAndroid Build Coastguard Worker name += "." + op.variant_test_name 2425*da0073e9SAndroid Build Coastguard Worker if name in fake_skips or "sparse" in name or "jiterator" in name: 2426*da0073e9SAndroid Build Coastguard Worker self.skipTest("Skip failing test") 2427*da0073e9SAndroid Build Coastguard Worker 2428*da0073e9SAndroid Build Coastguard Worker samples = op.sample_inputs(device, dtype, requires_grad=False) 2429*da0073e9SAndroid Build Coastguard Worker for sample in samples: 2430*da0073e9SAndroid Build Coastguard Worker mode = FakeTensorMode() 2431*da0073e9SAndroid Build Coastguard Worker 2432*da0073e9SAndroid Build Coastguard Worker from torch.fx.experimental.symbolic_shapes import ShapeEnv 2433*da0073e9SAndroid Build Coastguard Worker 2434*da0073e9SAndroid Build Coastguard Worker allow_dynamic_output_shape_shape_env = ShapeEnv( 2435*da0073e9SAndroid Build Coastguard Worker allow_dynamic_output_shape_ops=True 2436*da0073e9SAndroid Build Coastguard Worker ) 2437*da0073e9SAndroid Build Coastguard Worker 2438*da0073e9SAndroid Build Coastguard Worker allow_dynamic_output_shape_mode = FakeTensorMode( 2439*da0073e9SAndroid Build Coastguard Worker shape_env=allow_dynamic_output_shape_shape_env 2440*da0073e9SAndroid Build Coastguard Worker 
) 2441*da0073e9SAndroid Build Coastguard Worker 2442*da0073e9SAndroid Build Coastguard Worker try: 2443*da0073e9SAndroid Build Coastguard Worker with context(): 2444*da0073e9SAndroid Build Coastguard Worker res = op(sample.input, *sample.args, **sample.kwargs) 2445*da0073e9SAndroid Build Coastguard Worker except Exception: 2446*da0073e9SAndroid Build Coastguard Worker continue 2447*da0073e9SAndroid Build Coastguard Worker 2448*da0073e9SAndroid Build Coastguard Worker def run_with_fake_mode_and_verify(fake_mode, match_results=True): 2449*da0073e9SAndroid Build Coastguard Worker def map_to_fake(e): 2450*da0073e9SAndroid Build Coastguard Worker if isinstance(e, torch.Tensor): 2451*da0073e9SAndroid Build Coastguard Worker return fake_mode.from_tensor(e) 2452*da0073e9SAndroid Build Coastguard Worker else: 2453*da0073e9SAndroid Build Coastguard Worker return e 2454*da0073e9SAndroid Build Coastguard Worker 2455*da0073e9SAndroid Build Coastguard Worker input = tree_map(map_to_fake, sample.input) 2456*da0073e9SAndroid Build Coastguard Worker args = tree_map(map_to_fake, sample.args) 2457*da0073e9SAndroid Build Coastguard Worker kwargs = tree_map(map_to_fake, sample.kwargs) 2458*da0073e9SAndroid Build Coastguard Worker 2459*da0073e9SAndroid Build Coastguard Worker try: 2460*da0073e9SAndroid Build Coastguard Worker with context(): 2461*da0073e9SAndroid Build Coastguard Worker with fake_mode: 2462*da0073e9SAndroid Build Coastguard Worker res_fake = op(input, *args, **kwargs) 2463*da0073e9SAndroid Build Coastguard Worker 2464*da0073e9SAndroid Build Coastguard Worker if not match_results: 2465*da0073e9SAndroid Build Coastguard Worker return 2466*da0073e9SAndroid Build Coastguard Worker 2467*da0073e9SAndroid Build Coastguard Worker for fake_out, real_out in zip( 2468*da0073e9SAndroid Build Coastguard Worker pytree.tree_leaves(res_fake), pytree.tree_leaves(res) 2469*da0073e9SAndroid Build Coastguard Worker ): 2470*da0073e9SAndroid Build Coastguard Worker if not 
isinstance(fake_out, torch.Tensor): 2471*da0073e9SAndroid Build Coastguard Worker self.assertTrue(not isinstance(real_out, torch.Tensor)) 2472*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fake_out, real_out) 2473*da0073e9SAndroid Build Coastguard Worker continue 2474*da0073e9SAndroid Build Coastguard Worker 2475*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(fake_out, FakeTensor)) 2476*da0073e9SAndroid Build Coastguard Worker # if you see a shape exception here, you may need to add 2477*da0073e9SAndroid Build Coastguard Worker # a `dynamic_output_shape` tag to an operator 2478*da0073e9SAndroid Build Coastguard Worker 2479*da0073e9SAndroid Build Coastguard Worker if op.op not in [ 2480*da0073e9SAndroid Build Coastguard Worker torch.ops.aten._efficient_attention_forward, 2481*da0073e9SAndroid Build Coastguard Worker torch.ops.aten._flash_attention_forward, 2482*da0073e9SAndroid Build Coastguard Worker ]: 2483*da0073e9SAndroid Build Coastguard Worker # prims/decomps must correctly model strides, 2484*da0073e9SAndroid Build Coastguard Worker # see https://github.com/pytorch/pytorch/issues/78050#issuecomment-1253950325 2485*da0073e9SAndroid Build Coastguard Worker 2486*da0073e9SAndroid Build Coastguard Worker # note: the excluded ops have intentionally incorrect device; 2487*da0073e9SAndroid Build Coastguard Worker # see "Note [Seed and Offset]" (_meta_registrations.py) 2488*da0073e9SAndroid Build Coastguard Worker prims.utils.compare_tensor_meta(fake_out, real_out, True) 2489*da0073e9SAndroid Build Coastguard Worker 2490*da0073e9SAndroid Build Coastguard Worker if name not in aliasing_failures: 2491*da0073e9SAndroid Build Coastguard Worker fake_aliasing = outputs_alias_inputs( 2492*da0073e9SAndroid Build Coastguard Worker (input, args, kwargs), res_fake 2493*da0073e9SAndroid Build Coastguard Worker ) 2494*da0073e9SAndroid Build Coastguard Worker real_aliasing = outputs_alias_inputs( 2495*da0073e9SAndroid Build Coastguard Worker 
(sample.input, sample, args, sample.kwargs), res 2496*da0073e9SAndroid Build Coastguard Worker ) 2497*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fake_aliasing, real_aliasing) 2498*da0073e9SAndroid Build Coastguard Worker 2499*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 2500*da0073e9SAndroid Build Coastguard Worker name not in dynamic_output_op_tests 2501*da0073e9SAndroid Build Coastguard Worker and name not in data_dependent_op_tests 2502*da0073e9SAndroid Build Coastguard Worker ) 2503*da0073e9SAndroid Build Coastguard Worker 2504*da0073e9SAndroid Build Coastguard Worker except torch._subclasses.fake_tensor.UnsupportedFakeTensorException: 2505*da0073e9SAndroid Build Coastguard Worker pass 2506*da0073e9SAndroid Build Coastguard Worker except torch._subclasses.fake_tensor.UnsupportedOperatorException: 2507*da0073e9SAndroid Build Coastguard Worker pass 2508*da0073e9SAndroid Build Coastguard Worker except torch._subclasses.fake_tensor.DynamicOutputShapeException: 2509*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 2510*da0073e9SAndroid Build Coastguard Worker name in dynamic_output_op_tests 2511*da0073e9SAndroid Build Coastguard Worker or name in sometimes_dynamic_output_op_test 2512*da0073e9SAndroid Build Coastguard Worker ) 2513*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 2514*da0073e9SAndroid Build Coastguard Worker fake_mode.shape_env is None 2515*da0073e9SAndroid Build Coastguard Worker or not fake_mode.shape_env.allow_dynamic_output_shape_ops 2516*da0073e9SAndroid Build Coastguard Worker or name not in supported_dynamic_output_op_tests 2517*da0073e9SAndroid Build Coastguard Worker ) 2518*da0073e9SAndroid Build Coastguard Worker except torch._subclasses.fake_tensor.DataDependentOutputException: 2519*da0073e9SAndroid Build Coastguard Worker self.assertTrue(name in data_dependent_op_tests) 2520*da0073e9SAndroid Build Coastguard Worker 2521*da0073e9SAndroid Build Coastguard Worker 
run_with_fake_mode_and_verify(mode) 2522*da0073e9SAndroid Build Coastguard Worker if name in supported_dynamic_output_op_tests: 2523*da0073e9SAndroid Build Coastguard Worker run_with_fake_mode_and_verify( 2524*da0073e9SAndroid Build Coastguard Worker allow_dynamic_output_shape_mode, match_results=False 2525*da0073e9SAndroid Build Coastguard Worker ) 2526*da0073e9SAndroid Build Coastguard Worker 2527*da0073e9SAndroid Build Coastguard Worker @ops(op_db, dtypes=OpDTypes.any_one) 2528*da0073e9SAndroid Build Coastguard Worker def test_pointwise_ops(self, device, dtype, op): 2529*da0073e9SAndroid Build Coastguard Worker name = op.name 2530*da0073e9SAndroid Build Coastguard Worker if op.variant_test_name: 2531*da0073e9SAndroid Build Coastguard Worker name += "." + op.variant_test_name 2532*da0073e9SAndroid Build Coastguard Worker if name in fake_skips or "sparse" in name or "jiterator" in name: 2533*da0073e9SAndroid Build Coastguard Worker self.skipTest("Skip failing test") 2534*da0073e9SAndroid Build Coastguard Worker 2535*da0073e9SAndroid Build Coastguard Worker test_self = self 2536*da0073e9SAndroid Build Coastguard Worker 2537*da0073e9SAndroid Build Coastguard Worker class TestPointwiseMode(TorchDispatchMode): 2538*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(self, func, types, args=(), kwargs=None): 2539*da0073e9SAndroid Build Coastguard Worker kwargs = kwargs or {} 2540*da0073e9SAndroid Build Coastguard Worker 2541*da0073e9SAndroid Build Coastguard Worker out = func(*args, **kwargs) 2542*da0073e9SAndroid Build Coastguard Worker 2543*da0073e9SAndroid Build Coastguard Worker if torch.Tag.pointwise in func.tags: 2544*da0073e9SAndroid Build Coastguard Worker shapes = [] 2545*da0073e9SAndroid Build Coastguard Worker for inp in pytree.arg_tree_leaves(*args, **kwargs): 2546*da0073e9SAndroid Build Coastguard Worker if isinstance(inp, torch.Tensor): 2547*da0073e9SAndroid Build Coastguard Worker shapes.append(inp.shape) 2548*da0073e9SAndroid Build 
Coastguard Worker 2549*da0073e9SAndroid Build Coastguard Worker out_shape = torch._refs._broadcast_shapes(*shapes) 2550*da0073e9SAndroid Build Coastguard Worker 2551*da0073e9SAndroid Build Coastguard Worker for out_elem in pytree.tree_leaves(out): 2552*da0073e9SAndroid Build Coastguard Worker if isinstance(out_elem, torch.Tensor): 2553*da0073e9SAndroid Build Coastguard Worker test_self.assertEqual(out_elem.shape, out_shape) 2554*da0073e9SAndroid Build Coastguard Worker 2555*da0073e9SAndroid Build Coastguard Worker return out 2556*da0073e9SAndroid Build Coastguard Worker 2557*da0073e9SAndroid Build Coastguard Worker samples = op.sample_inputs(device, dtype, requires_grad=False) 2558*da0073e9SAndroid Build Coastguard Worker for sample in samples: 2559*da0073e9SAndroid Build Coastguard Worker mode = FakeTensorMode() 2560*da0073e9SAndroid Build Coastguard Worker 2561*da0073e9SAndroid Build Coastguard Worker def map_to_fake(e): 2562*da0073e9SAndroid Build Coastguard Worker if isinstance(e, torch.Tensor): 2563*da0073e9SAndroid Build Coastguard Worker return mode.from_tensor(e) 2564*da0073e9SAndroid Build Coastguard Worker else: 2565*da0073e9SAndroid Build Coastguard Worker return e 2566*da0073e9SAndroid Build Coastguard Worker 2567*da0073e9SAndroid Build Coastguard Worker input = tree_map(map_to_fake, sample.input) 2568*da0073e9SAndroid Build Coastguard Worker args = tree_map(map_to_fake, sample.args) 2569*da0073e9SAndroid Build Coastguard Worker kwargs = tree_map(map_to_fake, sample.kwargs) 2570*da0073e9SAndroid Build Coastguard Worker 2571*da0073e9SAndroid Build Coastguard Worker try: 2572*da0073e9SAndroid Build Coastguard Worker op(input, *args, **kwargs) 2573*da0073e9SAndroid Build Coastguard Worker except Exception as e: 2574*da0073e9SAndroid Build Coastguard Worker continue 2575*da0073e9SAndroid Build Coastguard Worker 2576*da0073e9SAndroid Build Coastguard Worker with TestPointwiseMode(): 2577*da0073e9SAndroid Build Coastguard Worker with mode: 
2578*da0073e9SAndroid Build Coastguard Worker op(input, *args, **kwargs) 2579*da0073e9SAndroid Build Coastguard Worker 2580*da0073e9SAndroid Build Coastguard Worker @ops(op_db, dtypes=OpDTypes.any_one) 2581*da0073e9SAndroid Build Coastguard Worker def test_fake(self, device, dtype, op): 2582*da0073e9SAndroid Build Coastguard Worker self._test_fake_helper(device, dtype, op, contextlib.nullcontext) 2583*da0073e9SAndroid Build Coastguard Worker 2584*da0073e9SAndroid Build Coastguard Worker @ops(op_db, dtypes=OpDTypes.any_one) 2585*da0073e9SAndroid Build Coastguard Worker def test_fake_autocast(self, device, dtype, op): 2586*da0073e9SAndroid Build Coastguard Worker device_type = torch.device(device).type 2587*da0073e9SAndroid Build Coastguard Worker if op.name in fake_autocast_device_skips[device_type]: 2588*da0073e9SAndroid Build Coastguard Worker self.skipTest("Skip failing test") 2589*da0073e9SAndroid Build Coastguard Worker 2590*da0073e9SAndroid Build Coastguard Worker def context_fn(): 2591*da0073e9SAndroid Build Coastguard Worker return torch.amp.autocast(device_type) 2592*da0073e9SAndroid Build Coastguard Worker 2593*da0073e9SAndroid Build Coastguard Worker self._test_fake_helper(device, dtype, op, context_fn) 2594*da0073e9SAndroid Build Coastguard Worker 2595*da0073e9SAndroid Build Coastguard Worker def _test_fake_crossref_helper(self, device, dtype, op, context): 2596*da0073e9SAndroid Build Coastguard Worker samples = op.sample_inputs(device, dtype, requires_grad=True) 2597*da0073e9SAndroid Build Coastguard Worker 2598*da0073e9SAndroid Build Coastguard Worker for iter, sample in enumerate(samples): 2599*da0073e9SAndroid Build Coastguard Worker args = [sample.input] + list(sample.args) 2600*da0073e9SAndroid Build Coastguard Worker kwargs = sample.kwargs 2601*da0073e9SAndroid Build Coastguard Worker 2602*da0073e9SAndroid Build Coastguard Worker # skip these to speed up tests 2603*da0073e9SAndroid Build Coastguard Worker common_skip_ops = ( 2604*da0073e9SAndroid 
Build Coastguard Worker aten.detach.default, 2605*da0073e9SAndroid Build Coastguard Worker aten.empty_strided.default, 2606*da0073e9SAndroid Build Coastguard Worker aten.copy_.default, 2607*da0073e9SAndroid Build Coastguard Worker aten.is_same_size.default, 2608*da0073e9SAndroid Build Coastguard Worker ) 2609*da0073e9SAndroid Build Coastguard Worker 2610*da0073e9SAndroid Build Coastguard Worker # TODO: enable check_aliasing, batch norm fails 2611*da0073e9SAndroid Build Coastguard Worker try: 2612*da0073e9SAndroid Build Coastguard Worker with torch._subclasses.CrossRefFakeMode( 2613*da0073e9SAndroid Build Coastguard Worker ignore_op_fn=lambda fn: fn in common_skip_ops, check_aliasing=True 2614*da0073e9SAndroid Build Coastguard Worker ): 2615*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(), context(), torch.autograd.set_multithreading_enabled( 2616*da0073e9SAndroid Build Coastguard Worker False 2617*da0073e9SAndroid Build Coastguard Worker ): 2618*da0073e9SAndroid Build Coastguard Worker composite_compliance.compute_expected_grads( 2619*da0073e9SAndroid Build Coastguard Worker op.get_op(), 2620*da0073e9SAndroid Build Coastguard Worker args, 2621*da0073e9SAndroid Build Coastguard Worker kwargs, 2622*da0073e9SAndroid Build Coastguard Worker sample.output_process_fn_grad, 2623*da0073e9SAndroid Build Coastguard Worker op.gradcheck_wrapper, 2624*da0073e9SAndroid Build Coastguard Worker ) 2625*da0073e9SAndroid Build Coastguard Worker except torch._subclasses.fake_tensor.UnsupportedOperatorException: 2626*da0073e9SAndroid Build Coastguard Worker pass 2627*da0073e9SAndroid Build Coastguard Worker 2628*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 2629*da0073e9SAndroid Build Coastguard Worker @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,)) 2630*da0073e9SAndroid Build Coastguard Worker @skipOps( 2631*da0073e9SAndroid Build Coastguard Worker "TestFakeTensor", "test_fake_crossref_backward_no_amp", 
fake_backward_xfails 2632*da0073e9SAndroid Build Coastguard Worker ) 2633*da0073e9SAndroid Build Coastguard Worker def test_fake_crossref_backward_no_amp(self, device, dtype, op): 2634*da0073e9SAndroid Build Coastguard Worker self._test_fake_crossref_helper(device, dtype, op, contextlib.nullcontext) 2635*da0073e9SAndroid Build Coastguard Worker 2636*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 2637*da0073e9SAndroid Build Coastguard Worker @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,)) 2638*da0073e9SAndroid Build Coastguard Worker @skipOps( 2639*da0073e9SAndroid Build Coastguard Worker "TestFakeTensor", 2640*da0073e9SAndroid Build Coastguard Worker "test_fake_crossref_backward_amp", 2641*da0073e9SAndroid Build Coastguard Worker fake_backward_xfails | fake_autocast_backward_xfails, 2642*da0073e9SAndroid Build Coastguard Worker ) 2643*da0073e9SAndroid Build Coastguard Worker def test_fake_crossref_backward_amp(self, device, dtype, op): 2644*da0073e9SAndroid Build Coastguard Worker self._test_fake_crossref_helper(device, dtype, op, torch.cuda.amp.autocast) 2645*da0073e9SAndroid Build Coastguard Worker 2646*da0073e9SAndroid Build Coastguard Worker @ops([op for op in ops_and_refs if op.is_factory_function]) 2647*da0073e9SAndroid Build Coastguard Worker def test_strided_layout(self, device, dtype, op): 2648*da0073e9SAndroid Build Coastguard Worker samples = op.sample_inputs(device, dtype) 2649*da0073e9SAndroid Build Coastguard Worker for sample in samples: 2650*da0073e9SAndroid Build Coastguard Worker kwargs = sample.kwargs.copy() 2651*da0073e9SAndroid Build Coastguard Worker kwargs["layout"] = torch.strided 2652*da0073e9SAndroid Build Coastguard Worker strided_result = op(sample.input, *sample.args, **kwargs) 2653*da0073e9SAndroid Build Coastguard Worker self.assertEqual(strided_result.layout, torch.strided) 2654*da0073e9SAndroid Build Coastguard Worker 2655*da0073e9SAndroid Build Coastguard Worker 2656*da0073e9SAndroid Build 
# Register device-specific variants of each generic test class into this
# module's namespace via instantiate_device_type_tests, in a fixed order.
# TestRefsOpsInfo is restricted to CPU through only_for.
for _generic_test_cls, _instantiate_kwargs in (
    (TestCommon, {}),
    (TestCompositeCompliance, {}),
    (TestMathBits, {}),
    (TestRefsOpsInfo, {"only_for": "cpu"}),
    (TestFakeTensor, {}),
    (TestTags, {}),
):
    instantiate_device_type_tests(_generic_test_cls, globals(), **_instantiate_kwargs)
# Remove the loop variables so no generic test class lingers as a module
# attribute after registration.
del _generic_test_cls, _instantiate_kwargs

if __name__ == "__main__":
    # Set TestCase's _default_dtype_check_enabled flag before running the
    # suite (presumably checks that tests restore the default dtype — confirm).
    TestCase._default_dtype_check_enabled = True
    run_tests()