xref: /aosp_15_r20/external/pytorch/test/onnx/pytorch_test_common.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: onnx"]
2from __future__ import annotations
3
4import functools
5import os
6import random
7import sys
8import unittest
9from enum import auto, Enum
10from typing import Optional
11
12import numpy as np
13import packaging.version
14import pytest
15
16import torch
17from torch.autograd import function
18from torch.onnx._internal import diagnostics
19from torch.testing._internal import common_utils
20
21
# Make the top-level test directory (the parent of this file's directory)
# importable; ``insert(-1, ...)`` places it just before the last sys.path entry.
pytorch_test_dir = os.path.dirname(
    os.path.dirname(os.path.realpath(__file__))
)
sys.path.insert(-1, pytorch_test_dir)

# Use float32 (``torch.float``) as the process-wide default floating point dtype.
torch.set_default_dtype(torch.float)

BATCH_SIZE = 2

# RNN dimension constants.
RNN_BATCH_SIZE, RNN_SEQUENCE_LENGTH = 7, 11
RNN_INPUT_SIZE, RNN_HIDDEN_SIZE = 5, 3
33
34
class TorchModelType(Enum):
    """Which kind of model object a test hands to the ONNX exporter."""

    # A plain ``torch.nn.Module``.
    TORCH_NN_MODULE = auto()
    # A ``torch.export.ExportedProgram``.
    TORCH_EXPORT_EXPORTEDPROGRAM = auto()
38
39
def _skipper(condition, reason):
    """Build a decorator that skips the wrapped test when ``condition()`` is truthy.

    Args:
        condition: Zero-argument callable. It is evaluated on every call of the
            decorated function, not at decoration time.
        reason: Message attached to the raised ``unittest.SkipTest``.

    Returns:
        A decorator for test functions or methods.
    """

    def decorator(test_fn):
        @functools.wraps(test_fn)
        def guarded(*args, **kwargs):
            if not condition():
                return test_fn(*args, **kwargs)
            raise unittest.SkipTest(reason)

        return guarded

    return decorator
51
52
# Ready-made skip decorators built with ``_skipper``; each condition is
# evaluated when the decorated test runs, not at import time.
skipIfNoCuda = _skipper(lambda: not torch.cuda.is_available(), "CUDA is not available")

skipIfTravis = _skipper(lambda: os.getenv("TRAVIS"), "Skip In Travis")

# Check CUDA availability first: older torch releases' ``is_bf16_supported()``
# query the current device unconditionally (which can raise on CPU-only
# machines), and ROCm builds report True even when no GPU is present.
skipIfNoBFloat16Cuda = _skipper(
    lambda: not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()),
    "BFloat16 CUDA is not available",
)

skipIfQuantizationBackendQNNPack = _skipper(
    lambda: torch.backends.quantized.engine == "qnnpack",
    "Not compatible with QNNPack quantization backend",
)
65
66
def skipIfUnsupportedMinOpsetVersion(min_opset_version):
    """Skip a test method when ``self.opset_version`` is below ``min_opset_version``.

    Add this decorator to tests that must not run for opset versions smaller
    than ``min_opset_version``. Those runs raise ``unittest.SkipTest``.

    Args:
        min_opset_version: Lowest opset version the test supports.

    Returns:
        A decorator for test methods.
    """

    def skip_dec(test_method):
        @functools.wraps(test_method)
        def checked(self, *args, **kwargs):
            current = self.opset_version
            if current < min_opset_version:
                message = f"Unsupported opset_version: {current} < {min_opset_version}"
                raise unittest.SkipTest(message)
            return test_method(self, *args, **kwargs)

        return checked

    return skip_dec
83
84
def skipIfUnsupportedMaxOpsetVersion(max_opset_version):
    """Skip a test method when ``self.opset_version`` is above ``max_opset_version``.

    Add this decorator to tests that must not run for opset versions higher
    than ``max_opset_version``. Those runs raise ``unittest.SkipTest``.

    Args:
        max_opset_version: Highest opset version the test supports.

    Returns:
        A decorator for test methods.
    """

    def skip_dec(test_method):
        @functools.wraps(test_method)
        def checked(self, *args, **kwargs):
            current = self.opset_version
            if current > max_opset_version:
                message = f"Unsupported opset_version: {current} > {max_opset_version}"
                raise unittest.SkipTest(message)
            return test_method(self, *args, **kwargs)

        return checked

    return skip_dec
101
102
def skipForAllOpsetVersions():
    """Skip a test method whenever the test case has a truthy ``opset_version``.

    The check is on truthiness, so if ``self.opset_version`` is falsy (for
    example ``None`` or ``0``) the test still runs.

    Returns:
        A decorator for test methods.
    """

    def skip_dec(test_method):
        @functools.wraps(test_method)
        def checked(self, *args, **kwargs):
            if not self.opset_version:
                return test_method(self, *args, **kwargs)
            raise unittest.SkipTest("Skip verify test for unsupported opset_version")

        return checked

    return skip_dec
117
118
def skipTraceTest(skip_before_opset_version: Optional[int] = None, reason: str = ""):
    """Skip the tracing (non-script) variant of a test for old opset versions.

    Args:
        skip_before_opset_version: Opset versions strictly below this value skip
            the tracing variant. If None, the tracing variant is always skipped.
        reason: Extra text appended to the skip message.

    Returns:
        A decorator for test methods. Before deciding whether to skip, the wrapper
        stores its decision on ``self.skip_this_opset``. Scripted runs
        (``self.is_script`` truthy) are never skipped by this decorator.
    """

    def skip_dec(test_method):
        @functools.wraps(test_method)
        def checked(self, *args, **kwargs):
            # ``or`` short-circuits, so ``self.opset_version`` is only read when
            # a threshold was supplied.
            self.skip_this_opset = (
                skip_before_opset_version is None
                or self.opset_version < skip_before_opset_version
            )
            if not self.skip_this_opset or self.is_script:
                return test_method(self, *args, **kwargs)
            raise unittest.SkipTest(f"Skip verify test for torch trace. {reason}")

        return checked

    return skip_dec
145
146
def skipScriptTest(skip_before_opset_version: Optional[int] = None, reason: str = ""):
    """Skip the TorchScript (scripting) variant of a test for old opset versions.

    Args:
        skip_before_opset_version: Opset versions strictly below this value skip
            the scripting variant. If None, the scripting variant is always skipped.
        reason: Extra text appended to the skip message.

    Returns:
        A decorator for test methods. Before deciding whether to skip, the wrapper
        stores its decision on ``self.skip_this_opset``. Non-scripted runs
        (``self.is_script`` falsy) are never skipped by this decorator.
    """

    def skip_dec(test_method):
        @functools.wraps(test_method)
        def checked(self, *args, **kwargs):
            # ``or`` short-circuits, so ``self.opset_version`` is only read when
            # a threshold was supplied.
            self.skip_this_opset = (
                skip_before_opset_version is None
                or self.opset_version < skip_before_opset_version
            )
            if not self.skip_this_opset or not self.is_script:
                return test_method(self, *args, **kwargs)
            raise unittest.SkipTest(f"Skip verify test for TorchScript. {reason}")

        return checked

    return skip_dec
173
174
# NOTE: This decorator is currently unused, but we may want to use it in the future when
# we have more tests that are not supported in released ORT.
def skip_min_ort_version(reason: str, version: str, dynamic_only: bool = False):
    """Skip a test when the installed ONNX Runtime is older than ``version``.

    Args:
        reason: Why the newer ONNX Runtime is needed; included in the skip message.
        version: Minimum required ONNX Runtime version string, e.g. ``"1.16.0"``.
        dynamic_only: If True, only skip when ``self.dynamic_shapes`` is truthy.
            Static-shape runs proceed even on an older runtime.

    Returns:
        A decorator for test methods. The test case must provide ``ort_version``
        (a version string) and, when ``dynamic_only`` is set, ``dynamic_shapes``.
    """

    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # Compare only the numeric release tuples, so pre/dev/local suffixes
            # (e.g. "1.17.0.dev20240101" -> (1, 17, 0)) do not affect the result.
            installed = packaging.version.parse(self.ort_version).release
            required = packaging.version.parse(version).release
            if installed < required:
                if dynamic_only and not self.dynamic_shapes:
                    return func(self, *args, **kwargs)

                # Report the installed version, not the required one twice.
                raise unittest.SkipTest(
                    f"ONNX Runtime version: {self.ort_version} is older than required version {version}. "
                    f"Reason: {reason}."
                )
            return func(self, *args, **kwargs)

        return wrapper

    return skip_dec
197
198
def xfail_dynamic_fx_test(
    error_message: str,
    model_type: Optional[TorchModelType] = None,
    reason: Optional[str] = None,
):
    """Xfail the dynamic-shapes variant of an FX exporter test.

    Args:
        error_message: Text expected in the error raised by the test; forwarded
            to ``xfail``.
        model_type (TorchModelType): Only xfail when ``self.model_type`` equals
            this value. When None, the dynamic variant is xfailed for every model type.
        reason: The reason for xfailing, forwarded to ``xfail``.

    Returns:
        A decorator for test methods. When ``self.dynamic_shapes`` is falsy, or
        ``model_type`` is given and differs from ``self.model_type``, the test
        runs normally.
    """

    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if not self.dynamic_shapes:
                return func(self, *args, **kwargs)
            if model_type and self.model_type != model_type:
                return func(self, *args, **kwargs)
            return xfail(error_message, reason)(func)(self, *args, **kwargs)

        return wrapper

    return skip_dec
227
228
def skip_dynamic_fx_test(reason: str, model_type: Optional[TorchModelType] = None):
    """Skip the dynamic-shapes variant of an FX exporter test.

    Args:
        reason: The reason for skipping; appended to the skip message.
        model_type (TorchModelType): Only skip when ``self.model_type`` equals
            this value. When None, the dynamic variant is skipped for every model type.

    Returns:
        A decorator for test methods. When ``self.dynamic_shapes`` is falsy, the
        test runs normally.
    """

    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.dynamic_shapes and (
                not model_type or self.model_type == model_type
            ):
                raise unittest.SkipTest(
                    f"Skip verify dynamic shapes test for FX. {reason}"
                )
            return func(self, *args, **kwargs)

        return wrapper

    return skip_dec
255
256
def skip_in_ci(reason: str):
    """Skip a test method when the ``CI`` environment variable is set and non-empty.

    Args:
        reason: Extra text appended to the skip message.

    Returns:
        A decorator for test methods. The environment is checked on every call.
    """

    def skip_dec(test_method):
        @functools.wraps(test_method)
        def checked(self, *args, **kwargs):
            running_in_ci = os.getenv("CI")
            if not running_in_ci:
                return test_method(self, *args, **kwargs)
            raise unittest.SkipTest(f"Skip test in CI. {reason}")

        return checked

    return skip_dec
277
278
279def xfail(error_message: str, reason: Optional[str] = None):
280    """Expect failure.
281
282    Args:
283        reason: The reason for expected failure.
284
285    Returns:
286        A decorator for expecting test failure.
287    """
288
289    def wrapper(func):
290        @functools.wraps(func)
291        def inner(self, *args, **kwargs):
292            try:
293                func(self, *args, **kwargs)
294            except Exception as e:
295                if isinstance(e, torch.onnx.OnnxExporterError):
296                    # diagnostic message is in the cause of the exception
297                    assert (
298                        error_message in str(e.__cause__)
299                    ), f"Expected error message: {error_message} NOT in {str(e.__cause__)}"
300                else:
301                    assert error_message in str(
302                        e
303                    ), f"Expected error message: {error_message} NOT in {str(e)}"
304                pytest.xfail(reason if reason else f"Expected failure: {error_message}")
305            else:
306                pytest.fail("Unexpected success!")
307
308        return inner
309
310    return wrapper
311
312
def skipIfUnsupportedOpsetVersion(unsupported_opset_versions):
    """Skip a test method when ``self.opset_version`` is in ``unsupported_opset_versions``.

    Use this when a PyTorch test cannot run for specific opset versions (for
    example, an op was modified but the change is not supported in PyTorch).

    Args:
        unsupported_opset_versions: Container of opset versions to skip.
            Membership is tested with ``in`` on every call.

    Returns:
        A decorator for test methods.
    """

    def skip_dec(test_method):
        @functools.wraps(test_method)
        def checked(self, *args, **kwargs):
            if self.opset_version not in unsupported_opset_versions:
                return test_method(self, *args, **kwargs)
            raise unittest.SkipTest("Skip verify test for unsupported opset_version")

        return checked

    return skip_dec
329
330
def skipShapeChecking(func):
    """Set ``self.check_shape = False`` before running the wrapped test method."""

    @functools.wraps(func)
    def without_shape_check(self, *args, **kwargs):
        # Set before the test body runs (presumably read by the test's
        # verification helpers).
        setattr(self, "check_shape", False)
        return func(self, *args, **kwargs)

    return without_shape_check
338
339
def skipDtypeChecking(func):
    """Set ``self.check_dtype = False`` before running the wrapped test method."""

    @functools.wraps(func)
    def without_dtype_check(self, *args, **kwargs):
        # Set before the test body runs (presumably read by the test's
        # verification helpers).
        setattr(self, "check_dtype", False)
        return func(self, *args, **kwargs)

    return without_dtype_check
347
348
def xfail_if_model_type_is_exportedprogram(
    error_message: str, reason: Optional[str] = None
):
    """Xfail a test only when its model is a ``torch.export.ExportedProgram``.

    Args:
        error_message: Text expected in the error raised by the test; forwarded
            to ``xfail``.
        reason: The reason for xfailing the ONNX export test, forwarded to ``xfail``.

    Returns:
        A decorator for test methods. When ``self.model_type`` is any other model
        type, the test runs normally.
    """

    def xfail_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.model_type != TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM:
                return func(self, *args, **kwargs)
            return xfail(error_message, reason)(func)(self, *args, **kwargs)

        return wrapper

    return xfail_dec
372
373
def xfail_if_model_type_is_not_exportedprogram(
    error_message: str, reason: Optional[str] = None
):
    """Xfail a test unless its model is a ``torch.export.ExportedProgram``.

    Args:
        error_message: Text expected in the error raised by the test; forwarded
            to ``xfail``.
        reason: The reason for xfailing the ONNX export test, forwarded to ``xfail``.

    Returns:
        A decorator for test methods. When ``self.model_type`` is
        ``TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM``, the test runs normally.
    """

    def xfail_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.model_type == TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM:
                return func(self, *args, **kwargs)
            return xfail(error_message, reason)(func)(self, *args, **kwargs)

        return wrapper

    return xfail_dec
396
397
def flatten(x):
    """Return the ``torch.Tensor`` objects found in ``x`` as a tuple.

    Nested containers are walked by the private
    ``torch.autograd.function._iter_filter`` helper; which container types it
    recurses into is defined there.
    """
    collect_tensors = function._iter_filter(lambda obj: isinstance(obj, torch.Tensor))
    return tuple(collect_tensors(x))
400
401
def set_rng_seed(seed):
    """Seed the torch, Python ``random`` and NumPy global RNGs with ``seed``."""
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)
406
407
class ExportTestCase(common_utils.TestCase):
    """Base test case for ONNX export.

    Any test case that tests functionalities under torch.onnx should inherit from this class.
    ``setUp`` reseeds the RNGs with a fixed seed and clears the ONNX diagnostics engine.
    """

    def setUp(self):
        super().setUp()
        # TODO(#88264): Flaky test failures after changing seed.
        seed = 0
        set_rng_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        # Reset diagnostics state (presumably so diagnostics from earlier tests
        # do not leak into this one).
        diagnostics.engine.clear()
421