1# Owner(s): ["module: onnx"] 2from __future__ import annotations 3 4import functools 5import os 6import random 7import sys 8import unittest 9from enum import auto, Enum 10from typing import Optional 11 12import numpy as np 13import packaging.version 14import pytest 15 16import torch 17from torch.autograd import function 18from torch.onnx._internal import diagnostics 19from torch.testing._internal import common_utils 20 21 22pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 23sys.path.insert(-1, pytorch_test_dir) 24 25torch.set_default_dtype(torch.float) 26 27BATCH_SIZE = 2 28 29RNN_BATCH_SIZE = 7 30RNN_SEQUENCE_LENGTH = 11 31RNN_INPUT_SIZE = 5 32RNN_HIDDEN_SIZE = 3 33 34 35class TorchModelType(Enum): 36 TORCH_NN_MODULE = auto() 37 TORCH_EXPORT_EXPORTEDPROGRAM = auto() 38 39 40def _skipper(condition, reason): 41 def decorator(f): 42 @functools.wraps(f) 43 def wrapper(*args, **kwargs): 44 if condition(): 45 raise unittest.SkipTest(reason) 46 return f(*args, **kwargs) 47 48 return wrapper 49 50 return decorator 51 52 53skipIfNoCuda = _skipper(lambda: not torch.cuda.is_available(), "CUDA is not available") 54 55skipIfTravis = _skipper(lambda: os.getenv("TRAVIS"), "Skip In Travis") 56 57skipIfNoBFloat16Cuda = _skipper( 58 lambda: not torch.cuda.is_bf16_supported(), "BFloat16 CUDA is not available" 59) 60 61skipIfQuantizationBackendQNNPack = _skipper( 62 lambda: torch.backends.quantized.engine == "qnnpack", 63 "Not compatible with QNNPack quantization backend", 64) 65 66 67# skips tests for all versions below min_opset_version. 68# add this wrapper to prevent running the test for opset_versions 69# smaller than `min_opset_version`. 
def skipIfUnsupportedMinOpsetVersion(min_opset_version):
    """Skip the decorated test when ``self.opset_version < min_opset_version``."""
    return _make_skip_decorator(
        lambda case: case.opset_version < min_opset_version,
        lambda case: (
            f"Unsupported opset_version: {case.opset_version} < {min_opset_version}"
        ),
    )


# Skips tests whose opset version is above `max_opset_version`:
# add this wrapper to prevent running the test for opset versions
# higher than `max_opset_version`.
def skipIfUnsupportedMaxOpsetVersion(max_opset_version):
    """Skip the decorated test when ``self.opset_version > max_opset_version``."""
    return _make_skip_decorator(
        lambda case: case.opset_version > max_opset_version,
        lambda case: (
            f"Unsupported opset_version: {case.opset_version} > {max_opset_version}"
        ),
    )


def skipForAllOpsetVersions():
    """Skip the decorated test whenever ``self.opset_version`` is truthy."""
    return _make_skip_decorator(
        lambda case: case.opset_version,
        lambda case: "Skip verify test for unsupported opset_version",
    )


def skipTraceTest(skip_before_opset_version: Optional[int] = None, reason: str = ""):
    """Skip tracing test for opset version less than skip_before_opset_version.

    Every call also records the decision on ``self.skip_this_opset``. The test is
    skipped only when that flag is truthy and ``self.is_script`` is falsy.

    Args:
        skip_before_opset_version: The opset version before which to skip tracing test.
            If None, tracing test is always skipped.
        reason: The reason for skipping tracing test.

    Returns:
        A decorator for skipping tracing test.
    """
    return _make_skip_decorator(
        lambda case: _mark_skip_this_opset(case, skip_before_opset_version)
        and not case.is_script,
        lambda case: f"Skip verify test for torch trace. {reason}",
    )


def skipScriptTest(skip_before_opset_version: Optional[int] = None, reason: str = ""):
    """Skip scripting test for opset version less than skip_before_opset_version.

    Every call also records the decision on ``self.skip_this_opset``. The test is
    skipped only when that flag is truthy and ``self.is_script`` is truthy.

    Args:
        skip_before_opset_version: The opset version before which to skip scripting test.
            If None, scripting test is always skipped.
        reason: The reason for skipping scripting test.

    Returns:
        A decorator for skipping scripting test.
    """
    return _make_skip_decorator(
        lambda case: _mark_skip_this_opset(case, skip_before_opset_version)
        and case.is_script,
        lambda case: f"Skip verify test for TorchScript. {reason}",
    )


def _make_skip_decorator(should_skip, skip_message):
    """Build a decorator for test methods that may raise ``unittest.SkipTest``.

    Args:
        should_skip: Called with the test instance on every call of the decorated
            method; a truthy result skips the test.
        skip_message: Called with the test instance to build the skip message; only
            invoked when the test is skipped.

    Returns:
        A decorator for methods whose first positional argument is ``self``.
    """

    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if should_skip(self):
                raise unittest.SkipTest(skip_message(self))
            return func(self, *args, **kwargs)

        return wrapper

    return skip_dec


def _mark_skip_this_opset(case, skip_before_opset_version):
    """Store on ``case.skip_this_opset`` whether this opset is skipped, and return it.

    With ``skip_before_opset_version`` set to None, every opset is marked as skipped.
    """
    if skip_before_opset_version is None:
        case.skip_this_opset = True
    else:
        case.skip_this_opset = case.opset_version < skip_before_opset_version
    return case.skip_this_opset


# NOTE: This decorator is currently unused, but we may want to use it in the future when
# we have more tests that are not supported in released ORT.
def skip_min_ort_version(reason: str, version: str, dynamic_only: bool = False):
    """Skip a test when the ONNX Runtime version under test is older than ``version``.

    Only the release segments (``packaging.version.Version.release``) of
    ``self.ort_version`` and ``version`` are compared.

    Args:
        reason: The reason for skipping; appended to the skip message.
        version: The minimum required ONNX Runtime version.
        dynamic_only: If True, an older ONNX Runtime only skips runs where
            ``self.dynamic_shapes`` is truthy; other runs still execute.

    Returns:
        A decorator for skipping tests on older ONNX Runtime versions.
    """

    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if (
                packaging.version.parse(self.ort_version).release
                < packaging.version.parse(version).release
            ):
                if dynamic_only and not self.dynamic_shapes:
                    return func(self, *args, **kwargs)

                # Report the installed version alongside the required one.
                raise unittest.SkipTest(
                    f"ONNX Runtime version: {self.ort_version} is older than required "
                    f"version {version}. Reason: {reason}."
                )
            return func(self, *args, **kwargs)

        return wrapper

    return skip_dec


def xfail_dynamic_fx_test(
    error_message: str,
    model_type: Optional[TorchModelType] = None,
    reason: Optional[str] = None,
):
    """Xfail dynamic exporting test.

    Args:
        error_message: Substring expected in the error raised by the test
            (checked by ``xfail``).
        model_type (TorchModelType): The model type to xfail dynamic exporting test for.
            When None, model type is not used to xfail dynamic tests.
        reason: The reason for xfailing dynamic exporting test.

    Returns:
        A decorator for xfailing dynamic exporting test. Runs where
        ``self.dynamic_shapes`` is falsy, or where ``self.model_type`` differs from
        ``model_type``, execute normally.
    """

    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.dynamic_shapes and (
                not model_type or self.model_type == model_type
            ):
                return xfail(error_message, reason)(func)(self, *args, **kwargs)
            return func(self, *args, **kwargs)

        return wrapper

    return skip_dec


def skip_dynamic_fx_test(reason: str, model_type: Optional[TorchModelType] = None):
    """Skip dynamic exporting test.

    Args:
        reason: The reason for skipping dynamic exporting test.
        model_type (TorchModelType): The model type to skip dynamic exporting test for.
            When None, model type is not used to skip dynamic tests.

    Returns:
        A decorator for skipping dynamic exporting test.
    """

    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.dynamic_shapes and (
                not model_type or self.model_type == model_type
            ):
                raise unittest.SkipTest(
                    f"Skip verify dynamic shapes test for FX. {reason}"
                )
            return func(self, *args, **kwargs)

        return wrapper

    return skip_dec


def skip_in_ci(reason: str):
    """Skip test in CI.

    The check happens on every call: the test is skipped when the ``CI`` environment
    variable is set to a non-empty value.

    Args:
        reason: The reason for skipping test in CI.

    Returns:
        A decorator for skipping test in CI.
    """

    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if os.getenv("CI"):
                raise unittest.SkipTest(f"Skip test in CI. {reason}")
            return func(self, *args, **kwargs)

        return wrapper

    return skip_dec


def xfail(error_message: str, reason: Optional[str] = None):
    """Expect failure.

    Args:
        error_message: Substring that must appear in the error raised by the test.
            For ``torch.onnx.OnnxExporterError``, the message of its ``__cause__`` is
            checked instead, because that is where the diagnostic message lives.
        reason: The reason for expected failure. Defaults to a message built from
            ``error_message``.

    Returns:
        A decorator for expecting test failure. The decorated test:

        - calls ``pytest.xfail`` when it raises an error containing ``error_message``;
        - raises ``AssertionError`` when the error message does not match;
        - calls ``pytest.fail`` when it unexpectedly succeeds;
        - lets ``unittest.SkipTest`` propagate, so a skip is still reported as a skip.
    """

    def wrapper(func):
        @functools.wraps(func)
        def inner(self, *args, **kwargs):
            try:
                func(self, *args, **kwargs)
            except unittest.SkipTest:
                # A skip (e.g. from a stacked skip decorator) is not a failure.
                raise
            except Exception as e:
                if isinstance(e, torch.onnx.OnnxExporterError):
                    # diagnostic message is in the cause of the exception
                    actual_message = str(e.__cause__)
                else:
                    actual_message = str(e)
                # Explicit raise instead of `assert` so the check survives `python -O`.
                if error_message not in actual_message:
                    raise AssertionError(
                        f"Expected error message: {error_message} NOT in {actual_message}"
                    ) from e
                pytest.xfail(reason if reason else f"Expected failure: {error_message}")
            else:
                pytest.fail("Unexpected success!")

        return inner

    return wrapper


# skips tests for opset_versions listed in unsupported_opset_versions.
# if the PyTorch test cannot be run for a specific version, add this wrapper
# (for example, an op was modified but the change is not supported in PyTorch)
def skipIfUnsupportedOpsetVersion(unsupported_opset_versions):
    """Skip the decorated test when ``self.opset_version`` is in ``unsupported_opset_versions``.

    Args:
        unsupported_opset_versions: A container of opset versions. Membership is
            checked with ``in`` on every call of the decorated test.

    Returns:
        A decorator for test methods.
    """

    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.opset_version not in unsupported_opset_versions:
                return func(self, *args, **kwargs)
            raise unittest.SkipTest("Skip verify test for unsupported opset_version")

        return wrapper

    return skip_dec


def _clear_flag_then_call(flag_name, func):
    """Wrap test method ``func`` so ``self.<flag_name>`` is set to False before it runs."""

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        setattr(self, flag_name, False)
        return func(self, *args, **kwargs)

    return wrapper


def skipShapeChecking(func):
    """Run the decorated test with ``self.check_shape`` set to False."""
    return _clear_flag_then_call("check_shape", func)


def skipDtypeChecking(func):
    """Run the decorated test with ``self.check_dtype`` set to False."""
    return _clear_flag_then_call("check_dtype", func)


def _xfail_when_model_type(matches, error_message, reason):
    """Build a decorator that applies ``xfail`` depending on ``self.model_type``.

    ``xfail(error_message, reason)`` is applied when ``matches(self.model_type)`` is
    truthy; otherwise the test runs unchanged.
    """

    def xfail_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if not matches(self.model_type):
                return func(self, *args, **kwargs)
            return xfail(error_message, reason)(func)(self, *args, **kwargs)

        return wrapper

    return xfail_dec


def xfail_if_model_type_is_exportedprogram(
    error_message: str, reason: Optional[str] = None
):
    """xfail test with models using ExportedProgram as input.

    Args:
        error_message: Substring expected in the error raised by the test.
        reason: The reason for xfail the ONNX export test.

    Returns:
        A decorator for xfail tests.
    """
    return _xfail_when_model_type(
        lambda model_type: model_type == TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
        error_message,
        reason,
    )


def xfail_if_model_type_is_not_exportedprogram(
    error_message: str, reason: Optional[str] = None
):
    """xfail test without models using ExportedProgram as input.

    Args:
        error_message: Substring expected in the error raised by the test.
        reason: The reason for xfail the ONNX export test.

    Returns:
        A decorator for xfail tests.
    """
    return _xfail_when_model_type(
        lambda model_type: model_type != TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
        error_message,
        reason,
    )


def flatten(x):
    """Return a tuple of the ``torch.Tensor`` objects found in ``x``.

    Collection is delegated to ``torch.autograd.function._iter_filter``.
    """
    iter_tensors = function._iter_filter(lambda obj: isinstance(obj, torch.Tensor))
    return tuple(iter_tensors(x))


def set_rng_seed(seed):
    """Seed the torch, ``random`` and NumPy global RNGs with ``seed``."""
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(seed)


class ExportTestCase(common_utils.TestCase):
    """Test case for ONNX export.

    Any test case that tests functionalities under torch.onnx should inherit from this class.
    """

    def setUp(self):
        """Reseed the RNGs with 0 and clear the ONNX diagnostics engine before each test."""
        super().setUp()
        seed = 0
        # TODO(#88264): Flaky test failures after changing seed.
        set_rng_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        diagnostics.engine.clear()