1import functools 2import unittest 3from unittest.mock import patch 4 5import torch 6 7 8aten = torch.ops.aten 9 10# This list is not meant to be comprehensive 11_COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY = [ 12 aten.arctan2.default, 13 aten.divide.Tensor, 14 aten.divide.Scalar, 15 aten.divide.Tensor_mode, 16 aten.divide.Scalar_mode, 17 aten.multiply.Tensor, 18 aten.multiply.Scalar, 19 aten.subtract.Tensor, 20 aten.subtract.Scalar, 21 aten.true_divide.Tensor, 22 aten.true_divide.Scalar, 23 aten.greater.Tensor, 24 aten.greater.Scalar, 25 aten.greater_equal.Tensor, 26 aten.greater_equal.Scalar, 27 aten.less_equal.Tensor, 28 aten.less_equal.Scalar, 29 aten.less.Tensor, 30 aten.less.Scalar, 31 aten.not_equal.Tensor, 32 aten.not_equal.Scalar, 33 aten.cat.names, 34 aten.sum.dim_DimnameList, 35 aten.mean.names_dim, 36 aten.prod.dim_Dimname, 37 aten.all.dimname, 38 aten.norm.names_ScalarOpt_dim, 39 aten.norm.names_ScalarOpt_dim_dtype, 40 aten.var.default, 41 aten.var.dim, 42 aten.var.names_dim, 43 aten.var.correction_names, 44 aten.std.default, 45 aten.std.dim, 46 aten.std.names_dim, 47 aten.std.correction_names, 48 aten.absolute.default, 49 aten.arccos.default, 50 aten.arccosh.default, 51 aten.arcsin.default, 52 aten.arcsinh.default, 53 aten.arctan.default, 54 aten.arctanh.default, 55 aten.clip.default, 56 aten.clip.Tensor, 57 aten.fix.default, 58 aten.negative.default, 59 aten.square.default, 60 aten.size.int, 61 aten.size.Dimname, 62 aten.stride.int, 63 aten.stride.Dimname, 64 aten.repeat_interleave.self_Tensor, 65 aten.repeat_interleave.self_int, 66 aten.sym_size.int, 67 aten.sym_stride.int, 68 aten.atleast_1d.Sequence, 69 aten.atleast_2d.Sequence, 70 aten.atleast_3d.Sequence, 71 aten.linear.default, 72 aten.conv2d.default, 73 aten.conv2d.padding, 74 aten.mish_backward.default, 75 aten.silu_backward.default, 76 aten.index_add.dimname, 77 aten.pad_sequence.default, 78 aten.index_copy.dimname, 79 aten.upsample_nearest1d.vec, 80 aten.upsample_nearest2d.vec, 81 aten.upsample_nearest3d.vec, 82 aten._upsample_nearest_exact1d.vec, 83 aten._upsample_nearest_exact2d.vec, 84 aten._upsample_nearest_exact3d.vec, 85 aten.rnn_tanh.input, 86 aten.rnn_tanh.data, 87 aten.rnn_relu.input, 88 aten.rnn_relu.data, 89 aten.lstm.input, 90 aten.lstm.data, 91 aten.gru.input, 92 aten.gru.data, 93 aten._upsample_bilinear2d_aa.vec, 94 aten._upsample_bicubic2d_aa.vec, 95 aten.upsample_bilinear2d.vec, 96 aten.upsample_trilinear3d.vec, 97 aten.upsample_linear1d.vec, 98 aten.matmul.default, 99 aten.upsample_bicubic2d.vec, 100 aten.__and__.Scalar, 101 aten.__and__.Tensor, 102 aten.__or__.Tensor, 103 aten.__or__.Scalar, 104 aten.__xor__.Tensor, 105 aten.__xor__.Scalar, 106 aten.scatter.dimname_src, 107 aten.scatter.dimname_value, 108 aten.scatter_add.dimname, 109 aten.is_complex.default, 110 aten.logsumexp.names, 111 aten.where.ScalarOther, 112 aten.where.ScalarSelf, 113 aten.where.Scalar, 114 aten.where.default, 115 aten.item.default, 116 aten.any.dimname, 117 aten.std_mean.default, 118 aten.std_mean.dim, 119 aten.std_mean.names_dim, 120 aten.std_mean.correction_names, 121 aten.var_mean.default, 122 aten.var_mean.dim, 123 aten.var_mean.names_dim, 124 aten.var_mean.correction_names, 125 aten.broadcast_tensors.default, 126 aten.stft.default, 127 aten.stft.center, 128 aten.istft.default, 129 aten.index_fill.Dimname_Scalar, 130 aten.index_fill.Dimname_Tensor, 131 aten.index_select.dimname, 132 aten.diag.default, 133 aten.cumsum.dimname, 134 aten.cumprod.dimname, 135 aten.meshgrid.default, 136 aten.meshgrid.indexing, 137 aten.fft_fft.default, 138 aten.fft_ifft.default, 139 aten.fft_rfft.default, 140 aten.fft_irfft.default, 141 aten.fft_hfft.default, 142 aten.fft_ihfft.default, 143 aten.fft_fftn.default, 144 aten.fft_ifftn.default, 145 aten.fft_rfftn.default, 146 aten.fft_ihfftn.default, 147 aten.fft_irfftn.default, 148 aten.fft_hfftn.default, 149 aten.fft_fft2.default, 150 aten.fft_ifft2.default, 151 aten.fft_rfft2.default, 152 aten.fft_irfft2.default, 153 aten.fft_hfft2.default, 154 aten.fft_ihfft2.default, 155 aten.fft_fftshift.default, 156 aten.fft_ifftshift.default, 157 aten.selu.default, 158 aten.margin_ranking_loss.default, 159 aten.hinge_embedding_loss.default, 160 aten.nll_loss.default, 161 aten.prelu.default, 162 aten.relu6.default, 163 aten.pairwise_distance.default, 164 aten.pdist.default, 165 aten.special_ndtr.default, 166 aten.cummax.dimname, 167 aten.cummin.dimname, 168 aten.logcumsumexp.dimname, 169 aten.max.other, 170 aten.max.names_dim, 171 aten.min.other, 172 aten.min.names_dim, 173 aten.linalg_eigvals.default, 174 aten.median.names_dim, 175 aten.nanmedian.names_dim, 176 aten.mode.dimname, 177 aten.gather.dimname, 178 aten.sort.dimname, 179 aten.sort.dimname_stable, 180 aten.argsort.default, 181 aten.argsort.dimname, 182 aten.rrelu.default, 183 aten.conv_transpose1d.default, 184 aten.conv_transpose2d.input, 185 aten.conv_transpose3d.input, 186 aten.conv1d.default, 187 aten.conv1d.padding, 188 aten.conv3d.default, 189 aten.conv3d.padding, 190 aten.float_power.Tensor_Tensor, 191 aten.float_power.Tensor_Scalar, 192 aten.float_power.Scalar, 193 aten.ldexp.Tensor, 194 aten._version.default, 195] 196 197 198def make_test_cls_with_mocked_export( 199 cls, cls_prefix, fn_suffix, mocked_export_fn, xfail_prop=None 200): 201 MockedTestClass = type(f"{cls_prefix}{cls.__name__}", cls.__bases__, {}) 202 MockedTestClass.__qualname__ = MockedTestClass.__name__ 203 204 for name in dir(cls): 205 if name.startswith("test_"): 206 fn = getattr(cls, name) 207 if not callable(fn): 208 setattr(MockedTestClass, name, getattr(cls, name)) 209 continue 210 new_name = f"{name}{fn_suffix}" 211 new_fn = _make_fn_with_mocked_export(fn, mocked_export_fn) 212 new_fn.__name__ = new_name 213 if xfail_prop is not None and hasattr(fn, xfail_prop): 214 new_fn = unittest.expectedFailure(new_fn) 215 setattr(MockedTestClass, new_name, new_fn) 216 # NB: Doesn't handle slots correctly, but whatever 217 elif not hasattr(MockedTestClass, name): 218 setattr(MockedTestClass, name, getattr(cls, name)) 219 220 return MockedTestClass 221 222 223def _make_fn_with_mocked_export(fn, mocked_export_fn): 224 @functools.wraps(fn) 225 def _fn(*args, **kwargs): 226 try: 227 from . import test_export 228 except ImportError: 229 import test_export 230 231 with patch(f"{test_export.__name__}.export", mocked_export_fn): 232 return fn(*args, **kwargs) 233 234 return _fn 235 236 237# Controls tests generated in test/export/test_export_training_ir_to_run_decomp.py 238def expectedFailureTrainingIRToRunDecomp(fn): 239 fn._expected_failure_training_ir_to_run_decomp = True 240 return fn 241 242 243# Controls tests generated in test/export/test_export_training_ir_to_run_decomp.py 244def expectedFailureTrainingIRToRunDecompNonStrict(fn): 245 fn._expected_failure_training_ir_to_run_decomp_non_strict = True 246 return fn 247 248 249# Controls tests generated in test/export/test_export_nonstrict.py 250def expectedFailureNonStrict(fn): 251 fn._expected_failure_non_strict = True 252 return fn 253 254 255# Controls tests generated in test/export/test_retraceability.py 256def expectedFailureRetraceability(fn): 257 fn._expected_failure_retrace = True 258 return fn 259 260 261# Controls tests generated in test/export/test_serdes.py 262def expectedFailureSerDer(fn): 263 fn._expected_failure_serdes = True 264 return fn 265 266 267def expectedFailureSerDerPreDispatch(fn): 268 fn._expected_failure_serdes_pre_dispatch = True 269 return fn 270 271 272def expectedFailurePreDispatchRunDecomp(fn): 273 fn._expected_failure_pre_dispatch = True 274 return fn 275