# xref: /aosp_15_r20/external/pytorch/test/export/testing.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import functools
2import unittest
3from unittest.mock import patch
4
5import torch
6
7
# Shorthand for the ``torch.ops.aten`` operator namespace; every entry in the
# list below is looked up through it as ``aten.<op>.<overload>``.
aten = torch.ops.aten

# Composite aten operator overloads that, judging by the name, tests may ask
# export to keep intact rather than decompose (testing only) -- confirm the
# exact use at the call sites.
# This list is not meant to be comprehensive
# NOTE(review): the attribute lookups below run when this module is imported,
# so an overload missing from the installed torch build will make the import
# itself fail (presumably with AttributeError -- confirm).
_COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY = [
    aten.arctan2.default,
    aten.divide.Tensor,
    aten.divide.Scalar,
    aten.divide.Tensor_mode,
    aten.divide.Scalar_mode,
    aten.multiply.Tensor,
    aten.multiply.Scalar,
    aten.subtract.Tensor,
    aten.subtract.Scalar,
    aten.true_divide.Tensor,
    aten.true_divide.Scalar,
    aten.greater.Tensor,
    aten.greater.Scalar,
    aten.greater_equal.Tensor,
    aten.greater_equal.Scalar,
    aten.less_equal.Tensor,
    aten.less_equal.Scalar,
    aten.less.Tensor,
    aten.less.Scalar,
    aten.not_equal.Tensor,
    aten.not_equal.Scalar,
    aten.cat.names,
    aten.sum.dim_DimnameList,
    aten.mean.names_dim,
    aten.prod.dim_Dimname,
    aten.all.dimname,
    aten.norm.names_ScalarOpt_dim,
    aten.norm.names_ScalarOpt_dim_dtype,
    aten.var.default,
    aten.var.dim,
    aten.var.names_dim,
    aten.var.correction_names,
    aten.std.default,
    aten.std.dim,
    aten.std.names_dim,
    aten.std.correction_names,
    aten.absolute.default,
    aten.arccos.default,
    aten.arccosh.default,
    aten.arcsin.default,
    aten.arcsinh.default,
    aten.arctan.default,
    aten.arctanh.default,
    aten.clip.default,
    aten.clip.Tensor,
    aten.fix.default,
    aten.negative.default,
    aten.square.default,
    aten.size.int,
    aten.size.Dimname,
    aten.stride.int,
    aten.stride.Dimname,
    aten.repeat_interleave.self_Tensor,
    aten.repeat_interleave.self_int,
    aten.sym_size.int,
    aten.sym_stride.int,
    aten.atleast_1d.Sequence,
    aten.atleast_2d.Sequence,
    aten.atleast_3d.Sequence,
    aten.linear.default,
    aten.conv2d.default,
    aten.conv2d.padding,
    aten.mish_backward.default,
    aten.silu_backward.default,
    aten.index_add.dimname,
    aten.pad_sequence.default,
    aten.index_copy.dimname,
    aten.upsample_nearest1d.vec,
    aten.upsample_nearest2d.vec,
    aten.upsample_nearest3d.vec,
    aten._upsample_nearest_exact1d.vec,
    aten._upsample_nearest_exact2d.vec,
    aten._upsample_nearest_exact3d.vec,
    aten.rnn_tanh.input,
    aten.rnn_tanh.data,
    aten.rnn_relu.input,
    aten.rnn_relu.data,
    aten.lstm.input,
    aten.lstm.data,
    aten.gru.input,
    aten.gru.data,
    aten._upsample_bilinear2d_aa.vec,
    aten._upsample_bicubic2d_aa.vec,
    aten.upsample_bilinear2d.vec,
    aten.upsample_trilinear3d.vec,
    aten.upsample_linear1d.vec,
    aten.matmul.default,
    aten.upsample_bicubic2d.vec,
    aten.__and__.Scalar,
    aten.__and__.Tensor,
    aten.__or__.Tensor,
    aten.__or__.Scalar,
    aten.__xor__.Tensor,
    aten.__xor__.Scalar,
    aten.scatter.dimname_src,
    aten.scatter.dimname_value,
    aten.scatter_add.dimname,
    aten.is_complex.default,
    aten.logsumexp.names,
    aten.where.ScalarOther,
    aten.where.ScalarSelf,
    aten.where.Scalar,
    aten.where.default,
    aten.item.default,
    aten.any.dimname,
    aten.std_mean.default,
    aten.std_mean.dim,
    aten.std_mean.names_dim,
    aten.std_mean.correction_names,
    aten.var_mean.default,
    aten.var_mean.dim,
    aten.var_mean.names_dim,
    aten.var_mean.correction_names,
    aten.broadcast_tensors.default,
    aten.stft.default,
    aten.stft.center,
    aten.istft.default,
    aten.index_fill.Dimname_Scalar,
    aten.index_fill.Dimname_Tensor,
    aten.index_select.dimname,
    aten.diag.default,
    aten.cumsum.dimname,
    aten.cumprod.dimname,
    aten.meshgrid.default,
    aten.meshgrid.indexing,
    aten.fft_fft.default,
    aten.fft_ifft.default,
    aten.fft_rfft.default,
    aten.fft_irfft.default,
    aten.fft_hfft.default,
    aten.fft_ihfft.default,
    aten.fft_fftn.default,
    aten.fft_ifftn.default,
    aten.fft_rfftn.default,
    aten.fft_ihfftn.default,
    aten.fft_irfftn.default,
    aten.fft_hfftn.default,
    aten.fft_fft2.default,
    aten.fft_ifft2.default,
    aten.fft_rfft2.default,
    aten.fft_irfft2.default,
    aten.fft_hfft2.default,
    aten.fft_ihfft2.default,
    aten.fft_fftshift.default,
    aten.fft_ifftshift.default,
    aten.selu.default,
    aten.margin_ranking_loss.default,
    aten.hinge_embedding_loss.default,
    aten.nll_loss.default,
    aten.prelu.default,
    aten.relu6.default,
    aten.pairwise_distance.default,
    aten.pdist.default,
    aten.special_ndtr.default,
    aten.cummax.dimname,
    aten.cummin.dimname,
    aten.logcumsumexp.dimname,
    aten.max.other,
    aten.max.names_dim,
    aten.min.other,
    aten.min.names_dim,
    aten.linalg_eigvals.default,
    aten.median.names_dim,
    aten.nanmedian.names_dim,
    aten.mode.dimname,
    aten.gather.dimname,
    aten.sort.dimname,
    aten.sort.dimname_stable,
    aten.argsort.default,
    aten.argsort.dimname,
    aten.rrelu.default,
    aten.conv_transpose1d.default,
    aten.conv_transpose2d.input,
    aten.conv_transpose3d.input,
    aten.conv1d.default,
    aten.conv1d.padding,
    aten.conv3d.default,
    aten.conv3d.padding,
    aten.float_power.Tensor_Tensor,
    aten.float_power.Tensor_Scalar,
    aten.float_power.Scalar,
    aten.ldexp.Tensor,
    aten._version.default,
]
196
197
def make_test_cls_with_mocked_export(
    cls, cls_prefix, fn_suffix, mocked_export_fn, xfail_prop=None
):
    """Clone test class ``cls`` so that every test runs with ``export`` mocked.

    The clone is named ``f"{cls_prefix}{cls.__name__}"`` and is built on
    ``cls.__bases__``, not on ``cls`` itself, so members are copied over
    explicitly:

    * every callable ``test_*`` attribute is wrapped by
      ``_make_fn_with_mocked_export`` and re-registered as
      ``f"{name}{fn_suffix}"``, optionally under ``unittest.expectedFailure``;
    * non-callable ``test_*`` attributes are copied under their original name;
    * any other attribute is copied only if the clone does not already have it
      (e.g. inherited from the shared bases).

    Args:
        cls: test class to clone.
        cls_prefix: prefix prepended to ``cls.__name__`` to name the clone.
        fn_suffix: suffix appended to the name of each callable ``test_*``.
        mocked_export_fn: object passed to ``_make_fn_with_mocked_export``,
            which patches it in for ``test_export.export`` during each test.
        xfail_prop: optional attribute name; a test function that has this
            attribute is wrapped with ``unittest.expectedFailure``.

    Returns:
        The newly created class.
    """
    mocked_cls = type(f"{cls_prefix}{cls.__name__}", cls.__bases__, {})
    mocked_cls.__qualname__ = mocked_cls.__name__

    for attr_name in dir(cls):
        if not attr_name.startswith("test_"):
            # NB: whatever getattr() returns is copied verbatim, so descriptors
            # such as slots (or staticmethod/classmethod wrappers) are not
            # reproduced faithfully -- accepted as good enough here.
            if not hasattr(mocked_cls, attr_name):
                setattr(mocked_cls, attr_name, getattr(cls, attr_name))
            continue

        original = getattr(cls, attr_name)
        if not callable(original):
            # Plain data that merely happens to be named ``test_*``.
            setattr(mocked_cls, attr_name, original)
            continue

        renamed = f"{attr_name}{fn_suffix}"
        wrapped = _make_fn_with_mocked_export(original, mocked_export_fn)
        wrapped.__name__ = renamed
        wants_xfail = xfail_prop is not None and hasattr(original, xfail_prop)
        if wants_xfail:
            wrapped = unittest.expectedFailure(wrapped)
        setattr(mocked_cls, renamed, wrapped)

    return mocked_cls
221
222
223def _make_fn_with_mocked_export(fn, mocked_export_fn):
224    @functools.wraps(fn)
225    def _fn(*args, **kwargs):
226        try:
227            from . import test_export
228        except ImportError:
229            import test_export
230
231        with patch(f"{test_export.__name__}.export", mocked_export_fn):
232            return fn(*args, **kwargs)
233
234    return _fn
235
236
# Controls tests generated in test/export/test_export_training_ir_to_run_decomp.py
def expectedFailureTrainingIRToRunDecomp(fn):
    """Tag ``fn`` with ``_expected_failure_training_ir_to_run_decomp = True``.

    ``fn`` itself is returned, so the decorator stacks with others. A test
    generator that checks for this attribute (e.g. via the ``xfail_prop``
    argument of ``make_test_cls_with_mocked_export``) can then wrap the
    generated test in ``unittest.expectedFailure``.
    """
    setattr(fn, "_expected_failure_training_ir_to_run_decomp", True)
    return fn
241
242
# Controls tests generated in test/export/test_export_training_ir_to_run_decomp.py
def expectedFailureTrainingIRToRunDecompNonStrict(fn):
    """Tag ``fn`` with ``_expected_failure_training_ir_to_run_decomp_non_strict = True``.

    Returns ``fn`` unchanged apart from the marker attribute. A generator that
    checks for the marker (e.g. through the ``xfail_prop`` argument of
    ``make_test_cls_with_mocked_export``) can treat the test as an expected
    failure.
    """
    setattr(fn, "_expected_failure_training_ir_to_run_decomp_non_strict", True)
    return fn
247
248
# Controls tests generated in test/export/test_export_nonstrict.py
def expectedFailureNonStrict(fn):
    """Tag ``fn`` with ``_expected_failure_non_strict = True`` and return it.

    The marker lets a test generator (e.g. via the ``xfail_prop`` argument of
    ``make_test_cls_with_mocked_export``) wrap the generated test in
    ``unittest.expectedFailure``.
    """
    setattr(fn, "_expected_failure_non_strict", True)
    return fn
253
254
# Controls tests generated in test/export/test_retraceability.py
def expectedFailureRetraceability(fn):
    """Tag ``fn`` with ``_expected_failure_retrace = True`` and return it.

    The marker lets a test generator (e.g. via the ``xfail_prop`` argument of
    ``make_test_cls_with_mocked_export``) wrap the generated test in
    ``unittest.expectedFailure``.
    """
    setattr(fn, "_expected_failure_retrace", True)
    return fn
259
260
# Controls tests generated in test/export/test_serdes.py
def expectedFailureSerDer(fn):
    """Tag ``fn`` with ``_expected_failure_serdes = True`` and return it.

    The marker lets a test generator (e.g. via the ``xfail_prop`` argument of
    ``make_test_cls_with_mocked_export``) wrap the generated test in
    ``unittest.expectedFailure``.
    """
    setattr(fn, "_expected_failure_serdes", True)
    return fn
265
266
# NOTE(review): which generated test file reads this marker is not visible
# here -- presumably the pre-dispatch serdes variant; confirm.
def expectedFailureSerDerPreDispatch(fn):
    """Tag ``fn`` with ``_expected_failure_serdes_pre_dispatch = True``.

    Returns ``fn`` itself. A test generator that checks for the marker (e.g.
    via the ``xfail_prop`` argument of ``make_test_cls_with_mocked_export``)
    can wrap the generated test in ``unittest.expectedFailure``.
    """
    setattr(fn, "_expected_failure_serdes_pre_dispatch", True)
    return fn
270
271
# NOTE(review): which generated test file reads this marker is not visible
# here; also note the attribute name is ``_expected_failure_pre_dispatch``
# (no "run_decomp" suffix) -- confirm against the reader.
def expectedFailurePreDispatchRunDecomp(fn):
    """Tag ``fn`` with ``_expected_failure_pre_dispatch = True`` and return it.

    A test generator that checks for the marker (e.g. via the ``xfail_prop``
    argument of ``make_test_cls_with_mocked_export``) can wrap the generated
    test in ``unittest.expectedFailure``.
    """
    setattr(fn, "_expected_failure_pre_dispatch", True)
    return fn
275