# mypy: ignore-errors

import torch
import unittest
from copy import deepcopy
from enum import Enum
from functools import wraps, partial
from itertools import chain, product
import itertools
import math
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import TEST_CUDNN
from torch.testing._internal.common_dtype import (
    floating_types, floating_and_complex_types_and, get_all_fp_dtypes)
from torch.testing._internal.common_device_type import (
    _TestParametrizer, _update_param_kwargs, expectedFailureMPS, toleranceOverride, tol,
    skipCUDAIfCudnnVersionLessThan, skipCUDAIfRocm, precisionOverride, skipMeta, skipMPS,
    skipCUDAVersionIn)
from torch.testing._internal.common_methods_invocations import DecorateInfo
from torch.testing._internal.common_nn import (
    cosineembeddingloss_reference, cross_entropy_loss_reference, ctcloss_reference,
    hingeembeddingloss_reference, huberloss_reference, kldivloss_reference,
    marginrankingloss_reference, multimarginloss_reference, multilabelmarginloss_reference,
    nllloss_reference, nlllossNd_reference, smoothl1loss_reference, softmarginloss_reference, get_reduction)
from torch.testing._internal.common_utils import (
    freeze_rng_state, skipIfMps, GRADCHECK_NONDET_TOL, TEST_WITH_ROCM, IS_WINDOWS,
    skipIfTorchDynamo)
from types import ModuleType
from typing import List, Tuple, Type, Set, Dict
import operator

# List of all namespaces containing modules to test.
MODULE_NAMESPACES: List[ModuleType] = [
    torch.nn.modules,
    torch.ao.nn.qat.modules,
    torch.ao.nn.quantizable.modules,
    torch.ao.nn.quantized.modules,
]

# Modules that shouldn't be tested for one reason or another.
MODULES_TO_SKIP: Set[Type] = {
    torch.nn.Module,  # abstract base class
    torch.nn.Container,  # deprecated
    torch.nn.NLLLoss2d,  # deprecated
    torch.ao.nn.quantized.MaxPool2d,  # aliases to nn.MaxPool2d
}

# List of all module classes to test.
MODULE_CLASSES: List[Type] = list(chain(*[
    [getattr(namespace, module_name) for module_name in namespace.__all__]  # type: ignore[attr-defined]
    for namespace in MODULE_NAMESPACES]))
MODULE_CLASSES = [cls for cls in MODULE_CLASSES if cls not in MODULES_TO_SKIP]

# Dict of module class -> common name. Useful for making test names more intuitive.
# Example: torch.nn.modules.linear.Linear -> "nn.Linear"
MODULE_CLASS_NAMES: Dict[Type, str] = {}
for namespace in MODULE_NAMESPACES:
    for module_name in namespace.__all__:  # type: ignore[attr-defined]
        module_cls = getattr(namespace, module_name)
        namespace_name = namespace.__name__.replace('torch.', '').replace('.modules', '')

        # Deal with any aliases by preferring earlier names.
        if module_cls not in MODULE_CLASS_NAMES:
            MODULE_CLASS_NAMES[module_cls] = f'{namespace_name}.{module_name}'


# Specifies the modes (i.e. train, eval) to test over.
TrainEvalMode = Enum('TrainEvalMode', ('train_only', 'eval_only', 'train_and_eval'))


class modules(_TestParametrizer):
    """ PROTOTYPE: Decorator for specifying a list of modules over which to run a test. """

    def __init__(self, module_info_iterable, allowed_dtypes=None,
                 train_eval_mode=TrainEvalMode.train_and_eval, skip_if_dynamo=True):
        self.module_info_list = list(module_info_iterable)
        self.allowed_dtypes = set(allowed_dtypes) if allowed_dtypes is not None else None
        self.train_eval_mode = train_eval_mode
        self.skip_if_dynamo = skip_if_dynamo

    def _get_training_flags(self, module_info):
        training_flags = []
        if (self.train_eval_mode == TrainEvalMode.train_only or
                self.train_eval_mode == TrainEvalMode.train_and_eval):
            training_flags.append(True)

        if (self.train_eval_mode == TrainEvalMode.eval_only or
                self.train_eval_mode == TrainEvalMode.train_and_eval):
            training_flags.append(False)

        # If train and eval modes don't differ for the module, don't bother using more than one.
        if not module_info.train_and_eval_differ:
            training_flags = training_flags[:1]

        return training_flags

    def _parametrize_test(self, test, generic_cls, device_cls):
        if device_cls is None:
            raise RuntimeError('The @modules decorator is only intended to be used in a device-specific '
                               'context; use it with instantiate_device_type_tests() instead of '
                               'instantiate_parametrized_tests()')

        for module_info in self.module_info_list:
            dtypes = set(module_info.supported_dtypes(device_cls.device_type))
            if self.allowed_dtypes is not None:
                dtypes = dtypes.intersection(self.allowed_dtypes)

            training_flags = self._get_training_flags(module_info)
            for (training, dtype) in product(training_flags, dtypes):
                # Construct the test name; device / dtype parts are handled outside.
                # See [Note: device and dtype suffix placement]
                test_name = module_info.formatted_name
                if len(training_flags) > 1:
                    test_name += f"_{'train_mode' if training else 'eval_mode'}"

                # Construct parameter kwargs to pass to the test.
                param_kwargs = {'module_info': module_info}
                _update_param_kwargs(param_kwargs, 'dtype', dtype)
                _update_param_kwargs(param_kwargs, 'training', training)

                try:

                    @wraps(test)
                    def test_wrapper(*args, **kwargs):
                        return test(*args, **kwargs)

                    if self.skip_if_dynamo and not torch.testing._internal.common_utils.TEST_WITH_TORCHINDUCTOR:
                        test_wrapper = skipIfTorchDynamo("Policy: we don't run ModuleInfo tests w/ Dynamo")(test_wrapper)

                    decorator_fn = partial(module_info.get_decorators, generic_cls.__name__,
                                           test.__name__, device_cls.device_type, dtype)

                    yield (test_wrapper, test_name, param_kwargs, decorator_fn)
                except Exception as ex:
                    # Provides an error message for debugging before rethrowing the exception
                    print(f"Failed to instantiate {test_name} for module {module_info.name}!")
                    raise ex


def get_module_common_name(module_cls):
    if module_cls in MODULE_CLASS_NAMES:
        # Example: "nn.Linear"
        return MODULE_CLASS_NAMES[module_cls]
    else:
        return module_cls.__name__


class FunctionInput:
    """ Contains args and kwargs to pass as input to a function. """
    __slots__ = ['args', 'kwargs']

    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs


class ModuleInput:
    """ Contains args / kwargs for module instantiation + forward pass. """
    __slots__ = ['constructor_input', 'forward_input', 'desc', 'reference_fn']

    def __init__(self, constructor_input, forward_input=None, desc='', reference_fn=None):
        self.constructor_input = constructor_input  # Inputs to pass during construction
        self.forward_input = forward_input  # Inputs to pass to forward()
        self.desc = desc  # Description for this set of inputs
        self.reference_fn = reference_fn  # Reference with signature: reference_fn(module, parameters, *args, **kwargs)

        if reference_fn is not None:

            @wraps(reference_fn)
            def copy_reference_fn(m, *args, **kwargs):
                # Copy inputs to avoid undesired side effects from calling the reference.
                args, kwargs = deepcopy(args), deepcopy(kwargs)

                # Note that module parameters are passed in for convenience.
                return reference_fn(m, list(m.parameters()), *args, **kwargs)

            self.reference_fn = copy_reference_fn


class ModuleErrorEnum(Enum):
    """ Enumerates when error is raised when testing modules. """
    CONSTRUCTION_ERROR = 0
    FORWARD_ERROR = 1


class ErrorModuleInput:
    """
    A ModuleInput that will cause the operation to throw an error plus information
    about the resulting error.
    """

    __slots__ = ["module_error_input", "error_on", "error_type", "error_regex"]

    def __init__(self,
                 module_error_input,
                 *,
                 error_on=ModuleErrorEnum.CONSTRUCTION_ERROR,
                 error_type=RuntimeError,
                 error_regex):
        self.module_error_input = module_error_input
        self.error_on = error_on
        self.error_type = error_type
        self.error_regex = error_regex


class ModuleInfo:
    """ Module information to be used in testing. """

    def __init__(self,
                 module_cls,  # Class object for the module under test
                 *,
                 module_inputs_func,  # Function to generate module inputs
                 skips=(),  # Indicates which tests to skip
                 decorators=None,  # Additional decorators to apply to generated tests
                 dtypes=floating_types(),  # dtypes this function is expected to work with
                 dtypesIfMPS=(torch.float16, torch.float32,),  # dtypes this function is expected to work with on MPS
                 dtypesIfHpu=(torch.bfloat16, torch.float32,),
                 supports_gradgrad=True,  # whether the op supports second order gradients
                 gradcheck_nondet_tol=0.0,  # tolerance for nondeterminism while performing gradcheck
                 module_memformat_affects_out=False,  # whether converting module to channels last will generate
                                                      # channels last output
                 train_and_eval_differ=False,  # whether the module has differing behavior between train and eval
                 module_error_inputs_func=None,  # Function to generate module inputs that error
                 ):
        self.module_cls = module_cls
        self.module_inputs_func = module_inputs_func
        self.decorators = (*(decorators if decorators else []), *(skips if skips else []))
        self.dtypes = dtypes
        self.dtypesIfMPS = dtypesIfMPS
        self.dtypesIfHpu = dtypesIfHpu
        self.supports_gradgrad = supports_gradgrad
        self.gradcheck_nondet_tol = gradcheck_nondet_tol
        self.module_memformat_affects_out = module_memformat_affects_out
        self.train_and_eval_differ = train_and_eval_differ
        self.module_error_inputs_func = module_error_inputs_func
        self.is_lazy = issubclass(module_cls, torch.nn.modules.lazy.LazyModuleMixin)

    def get_decorators(self, test_class, test_name, device, dtype, param_kwargs):
        result = []
        for decorator in self.decorators:
            if isinstance(decorator, DecorateInfo):
                if decorator.is_active(test_class, test_name, device, dtype, param_kwargs):
                    result.extend(decorator.decorators)
            else:
                result.append(decorator)
        return result

    def supported_dtypes(self, device_type):
        if device_type == 'mps':
            return self.dtypesIfMPS
        elif device_type == 'hpu':
            return self.dtypesIfHpu
        else:
            return self.dtypes

    @property
    def name(self):
        return get_module_common_name(self.module_cls)

    @property
    def formatted_name(self):
        return self.name.replace('.', '_')


# Start of module inputs functions.

def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    module_inputs = [
        ModuleInput(constructor_input=FunctionInput(10, 8),
                    forward_input=FunctionInput(input=make_input((4, 10))),
                    reference_fn=lambda m, p, input: torch.mm(input, p[0].t()) + p[1].view(1, -1).expand(4, 8)),
        ModuleInput(constructor_input=FunctionInput(10, 8, bias=False),
                    forward_input=FunctionInput(make_input((4, 10))),
                    desc='no_bias',
                    reference_fn=lambda m, p, i: torch.mm(i, p[0].t())),
        ModuleInput(constructor_input=FunctionInput(3, 5),
                    forward_input=FunctionInput(make_input(3)),
                    desc='no_batch_dim',
                    reference_fn=lambda m, p, i: torch.mm(i.view(1, -1), p[0].t()).view(-1) + p[1])
    ]

    return module_inputs
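
# Illustrative sketch (assumes the module_db registry defined later in this file): a
# module_inputs_* generator such as the one above is attached to its module through a
# ModuleInfo entry, roughly
#
#   ModuleInfo(torch.nn.Linear,
#              module_inputs_func=module_inputs_torch_nn_Linear)
#
# Each ModuleInput's reference_fn is then called as reference_fn(module, parameters, *forward_args).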


def module_inputs_torch_nn_Bilinear(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    def bilinear_reference_fn(m, p, x1, x2, bias=True):
        result = torch.einsum('bn,anm,bm->ba', x1, p[0], x2)
        if bias:
            if x1.shape[0] == 1:
                result = result.view(-1) + p[1]
            else:
                result = result + p[1].view(1, -1).expand(x1.shape[0], p[0].shape[0])
        return result

    module_inputs = [
        ModuleInput(constructor_input=FunctionInput(2, 3, 4),
                    forward_input=FunctionInput(make_input((8, 2)), make_input((8, 3))),
                    reference_fn=bilinear_reference_fn),
        ModuleInput(constructor_input=FunctionInput(2, 3, 4, bias=False),
                    forward_input=FunctionInput(make_input((8, 2)), make_input((8, 3))),
                    desc='no_bias',
                    reference_fn=lambda m, p, x1, x2: bilinear_reference_fn(m, p, x1, x2, bias=False)),
        ModuleInput(constructor_input=FunctionInput(2, 3, 4),
                    forward_input=FunctionInput(make_input(2), make_input(3)),
                    desc='no_batch_dim',
                    reference_fn=lambda m, p, x1, x2: bilinear_reference_fn(m, p, x1.view(1, -1), x2.view(1, -1))),
    ]

    return module_inputs


def module_inputs_torch_nn_KLDivLoss(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    cases: List[Tuple[str, dict]] = [
        ('', {}),
        ('reduction_sum', {'reduction': 'sum'}),
        ('reduction_batchmean', {'reduction': 'batchmean'}),
        ('reduction_none', {'reduction': 'none'}),
        ('log_target', {'log_target': True})
    ]

    module_inputs = []
    for desc, constructor_kwargs in cases:
        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
            return kldivloss_reference(i, t, **constructor_kwargs)

        input = make_input((10, 10)).log()
        target = make_input((10, 10)) if kwargs.get('log_target', False) else make_input((10, 10)).log()
        module_inputs.append(
            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
                        forward_input=FunctionInput(input, target),
                        desc=desc,
                        reference_fn=reference_fn)
        )

        scalar_input = make_input(()).log()
        scalar_target = make_input(()) if kwargs.get('log_target', False) else make_input(()).log()
        module_inputs.append(
            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
                        forward_input=FunctionInput(scalar_input, scalar_target),
                        desc='scalar_' + desc,
                        reference_fn=reference_fn)
        )

    return module_inputs


def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, training, **kwargs):
    def make_input(shape, device=device, dtype=dtype, requires_grad=requires_grad):
        return make_tensor(shape, device=device, dtype=dtype,
                           requires_grad=False).log_softmax(dim=1).requires_grad_(requires_grad)
    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)

    cases: List[Tuple[str, dict]] = [
        ('', {}),
        ('reduction_sum', {'reduction': 'sum'}),
        ('reduction_none', {'reduction': 'none'}),
        ('ignore_index', {'ignore_index': 2}),
        ('weights', {'weight': make_weight(4).abs()}),
        ('weights_ignore_index', {'weight': make_weight(4).abs(), 'ignore_index': 2}),
        ('weights_ignore_index_neg', {'weight': make_weight(4).abs(), 'ignore_index': -1})
    ]

    # TODO: Uncomment when negative weights is supported.
    # negative_weight = make_weight(10)
    # negative_weight[0] = -1
    # cases.append(('weights_negative', {'weight': negative_weight}))
    module_inputs = []
    for desc, constructor_kwargs in cases:

        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
            return nllloss_reference(i, t, **constructor_kwargs)

        module_inputs.append(
            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
                        forward_input=FunctionInput(make_input((15, 4)),
                                                    torch.empty(15, device=device).uniform_().mul(4).floor().long()),
                        desc=desc,
                        reference_fn=reference_fn)
        )

        def nd_reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
            return nlllossNd_reference(i, t, **constructor_kwargs)

        module_inputs.append(
            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
                        forward_input=FunctionInput(
                            make_input((2, 4, 5, 5)),
                            torch.empty(2, 5, 5, device=device).uniform_().mul(4).floor().long()),
                        desc=f"nd_{desc}",
                        reference_fn=nd_reference_fn)
        )

        module_inputs.append(
            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
                        forward_input=FunctionInput(
                            make_input((2, 4, 5, 5, 2, 2)),
                            torch.empty(2, 5, 5, 2, 2, device=device).uniform_().mul(4).floor().long()),
                        desc=f"higher_dim_{desc}",
                        reference_fn=nd_reference_fn)
        )

        module_inputs.append(
            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
                        forward_input=FunctionInput(
                            make_input((2, 4, 5)),
                            torch.empty(2, 5, device=device).uniform_().mul(4).floor().long()),
                        desc=f"3d_{desc}",
                        reference_fn=nd_reference_fn)
        )

    return module_inputs
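
# Note on the pattern above (descriptive comment only): the reference_fn closures bind
# `constructor_kwargs=constructor_kwargs` as a default argument so each closure captures the
# per-case kwargs at definition time; capturing the loop variable directly would make every
# closure see the last case's kwargs when it is eventually called.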


def module_inputs_torch_nn_GaussianNLLLoss(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)

    cases: List[Tuple[str, dict]] = [
        ('', {}),
        ('reduction_sum', {'reduction': 'sum'}),
        ('reduction_mean', {'reduction': 'mean'}),
        ('reduction_none', {'reduction': 'none'}),
    ]

    module_inputs = []
    for desc, constructor_kwargs in cases:
        module_inputs.append(
            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
                        forward_input=FunctionInput(make_input(3),
                                                    make_target(3),
                                                    make_input(1).abs()),
                        desc=desc,
                        reference_fn=no_batch_dim_reference_fn)
        )

    return module_inputs


def module_inputs_torch_nn_PoissonNLLLoss(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)

    cases: List[Tuple[str, dict]] = [
        ('', {}),
        ('reduction_sum', {'reduction': 'sum'}),
        ('reduction_mean', {'reduction': 'mean'}),
        ('reduction_none', {'reduction': 'none'}),
        ('full', {'full': True}),
        ('no_log_input', {'log_input': False}),
        ('full_no_log_input', {'full': True, 'log_input': False}),
    ]

    def poissonnllloss_reference_fn(i, t, log_input=True, full=False, reduction='mean', eps=1e-8):
        if log_input:
            result = i.exp() - t.mul(i)
        else:
            result = i - t.mul((i + eps).log())

        if full:
            result += (t.mul(t.log()) - t + 0.5 * (2. * math.pi * t).log()).masked_fill(t <= 1, 0)

        if reduction == 'none':
            return result
        elif reduction == 'mean':
            return result.sum() / i.numel()
        else:
            return result.sum()

    module_inputs = []
    for desc, constructor_kwargs in cases:
        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
            return poissonnllloss_reference_fn(i, t, **constructor_kwargs)

        log_input = constructor_kwargs.get('log_input', True)
        input = make_input((2, 3, 4, 5)) if log_input else make_input((2, 3, 4, 5)).abs().add(0.001)
        module_inputs.append(
            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
                        forward_input=FunctionInput(input,
                                                    make_target((2, 3, 4, 5)).floor_().abs_()),
                        desc=desc,
                        reference_fn=reference_fn)
        )

    return module_inputs


def module_inputs_torch_nn_MSELoss(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)

    cases: List[Tuple[str, dict]] = [
        ('', {}),
        ('reduction_sum', {'reduction': 'sum'}),
        ('reduction_mean', {'reduction': 'mean'}),
        ('reduction_none', {'reduction': 'none'}),
    ]

    def mse_loss_reference_fn(m, p, i, t, reduction='mean'):
        if reduction == 'none':
            return (i - t).pow(2)
        elif reduction == 'mean':
            return (i - t).pow(2).sum() / i.numel()
        else:
            return (i - t).pow(2).sum()

    module_inputs = []
    for desc, constructor_kwargs in cases:
        module_inputs.append(
            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
                        forward_input=FunctionInput(make_input((2, 3, 4, 5)),
                                                    make_target((2, 3, 4, 5))),
                        desc=desc,
                        reference_fn=partial(mse_loss_reference_fn, **constructor_kwargs))
        )
        module_inputs.append(
            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
                        forward_input=FunctionInput(make_input(()),
                                                    make_target(())),
                        desc=f'{desc}_scalar',
                        reference_fn=partial(mse_loss_reference_fn, **constructor_kwargs))
        )

    return module_inputs


def no_batch_dim_reference_fn(m, p, *args, **kwargs):
    """Reference function for modules supporting no batch dimensions.

    Unbatched inputs are unsqueezed to form a
    single batch input before passing them to the module.
    The output is squeezed to compare with the
    output of unbatched input to the module.

    Currently it only supports modules which return a single Tensor as output.
    You can bind the following kwargs.
    Kwargs:
        batch_first[bool] : If True, all the Tensors in `args` will be unsqueezed at dim `0`
                            and the output will be squeezed at dim `0`; otherwise dim `1` is used for both.
        kwargs_to_batchify[dict] : Dictionary specifying the name of the argument and dimension to unsqueeze.
                                   Useful if there are a few arguments whose batch dimension differs
                                   from the one selected by `batch_first`.
        is_criterion[bool] : Specify if the module is a criterion and handle the reduction for output accordingly.
    """
    def get_and_pop(key, default):
        v = kwargs.get(key, default)
        if key in kwargs:
            kwargs.pop(key)
        return v

    batch_dim = 0 if get_and_pop('batch_first', True) else 1
    kwargs_to_batchify = get_and_pop('kwargs_to_batchify', None)
    is_criterion = get_and_pop('is_criterion', False)

    if kwargs_to_batchify is not None:
        assert isinstance(kwargs_to_batchify, dict)
        for k, v in kwargs.items():
            if k in kwargs_to_batchify and v is not None:
                bdim = kwargs_to_batchify[k]
                kwargs[k] = v.unsqueeze(bdim)

    single_batch_input_args = [input.unsqueeze(batch_dim) for input in args]
    with freeze_rng_state():
        output = m(*single_batch_input_args, **kwargs).squeeze(batch_dim)

    if is_criterion:
        reduction = get_reduction(m)
        if reduction == 'none':
            return output.squeeze(0)
    return output
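
# Illustrative sketch of the contract above (not part of the test data): for a module `m`
# that accepts unbatched input with batch_first semantics (the default), the reference is
# roughly equivalent to running a single-element batch and squeezing the result, i.e.
#
#   no_batch_dim_reference_fn(m, list(m.parameters()), x)  ~  m(x.unsqueeze(0)).squeeze(0)
#
# plus the `is_criterion` / `kwargs_to_batchify` handling shown in the body.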


def no_batch_dim_reference_mha(m, p, *args, **kwargs):
    """Reference function for MultiheadAttention supporting no batch dimensions.

    Unbatched inputs are unsqueezed to form a
    single batch input before passing them to the module.
    The output is squeezed to compare with the
    output of unbatched input to the module.
    """
    batch_dim = 0 if kwargs.get('batch_first', True) else 1
    if 'batch_first' in kwargs:
        kwargs.pop('batch_first')
    if 'key_padding_mask' in kwargs and kwargs['key_padding_mask'] is not None:
        kwargs['key_padding_mask'] = kwargs['key_padding_mask'].unsqueeze(0)
    single_batch_input_args = [input.unsqueeze(batch_dim) for input in args]
    with freeze_rng_state():
        output = m(*single_batch_input_args, **kwargs)
        return (output[0].squeeze(batch_dim), output[1].squeeze(0))


def no_batch_dim_reference_rnn_gru(m, p, *args, **kwargs):
    """Reference function for RNN and GRU supporting no batch dimensions.

    Unbatched inputs are unsqueezed to form a
    single batch input before passing them to the module.
    The output is squeezed to compare with the
    output of unbatched input to the module.
    """
    if len(args) == 1:
        inp, = args
        h = None
    elif len(args) == 2:
        inp, h = args
        h = h.unsqueeze(1)

    batch_dim = 0 if kwargs['batch_first'] else 1
    kwargs.pop('batch_first')
    inp = inp.unsqueeze(batch_dim)
    single_batch_input_args = (inp, h)
    with freeze_rng_state():
        output = m(*single_batch_input_args, **kwargs)
        return (output[0].squeeze(batch_dim), output[1].squeeze(1))


def no_batch_dim_reference_lstm(m, p, *args, **kwargs):
    """Reference function for LSTM supporting no batch dimensions.

    Unbatched inputs are unsqueezed to form a
    single batch input before passing them to the module.
    The output is squeezed to compare with the
    output of unbatched input to the module.
    """
    if len(args) == 1:
        inp, = args
        h = None
    elif len(args) == 2:
        inp, h = args
        h = (h[0].unsqueeze(1), h[1].unsqueeze(1))

    batch_dim = 0 if kwargs['batch_first'] else 1
    kwargs.pop('batch_first')
    inp = inp.unsqueeze(batch_dim)
    single_batch_input_args = (inp, h)
    with freeze_rng_state():
        output = m(*single_batch_input_args, **kwargs)
        return (output[0].squeeze(batch_dim), (output[1][0].squeeze(1), output[1][1].squeeze(1)))


def no_batch_dim_reference_lstmcell(m, p, *args, **kwargs):
    """Reference function for LSTMCell supporting no batch dimensions.

    The module is passed the input and target in batched form with a single item.
    The output is squeezed to compare with the no-batch input.
    """
    inp, (h, c) = args
    single_batch_input_args = (inp.unsqueeze(0), (h.unsqueeze(0), c.unsqueeze(0)))
    with freeze_rng_state():
        output = m(*single_batch_input_args, **kwargs)
        return (output[0].squeeze(0), output[1].squeeze(0))


def generate_regression_criterion_inputs(make_input):
    return [
        ModuleInput(
            constructor_input=FunctionInput(reduction=reduction),
            forward_input=FunctionInput(make_input((4, )), make_input(4,)),
            reference_fn=partial(no_batch_dim_reference_fn, is_criterion=True),
            desc=f'no_batch_dim_{reduction}'
        ) for reduction in ['none', 'mean', 'sum']]


def module_inputs_torch_nn_AvgPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(kernel_size=2),
                    forward_input=FunctionInput(make_input((3, 6))),
                    desc='no_batch_dim',
                    reference_fn=no_batch_dim_reference_fn),
        ModuleInput(constructor_input=FunctionInput(2),
                    forward_input=FunctionInput(make_input((2, 3, 6)))),
        ModuleInput(constructor_input=FunctionInput((2,), (2,)),
                    forward_input=FunctionInput(make_input((2, 3, 6))),
                    desc='stride'),
        ModuleInput(constructor_input=FunctionInput(2, 2, 1),
                    forward_input=FunctionInput(make_input((2, 3, 6))),
                    desc='stride_pad')]


def module_inputs_torch_nn_AvgPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput((2, 2)),
                    forward_input=FunctionInput(make_input((3, 6, 6))),
                    desc='no_batch_dim',
                    reference_fn=no_batch_dim_reference_fn),
        ModuleInput(constructor_input=FunctionInput((2, 2)),
                    forward_input=FunctionInput(make_input((2, 3, 6, 6)))),
        ModuleInput(constructor_input=FunctionInput((2, 2), (2, 2)),
                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
                    desc='stride'),
        ModuleInput(constructor_input=FunctionInput((2, 2), (2, 2), (1, 1)),
                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
                    desc='stride_pad'),
        ModuleInput(constructor_input=FunctionInput((2, 2), divisor_override=1),
                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
                    desc='divisor'),
        ModuleInput(constructor_input=FunctionInput((2, 2), (2, 2), divisor_override=1),
                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
                    desc='divisor_stride'),
        ModuleInput(constructor_input=FunctionInput((2, 2), (2, 2), (1, 1), divisor_override=1),
                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
                    desc='divisor_stride_pad')]


def module_inputs_torch_nn_AvgPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput((2, 2, 2)),
                    forward_input=FunctionInput(make_input((3, 4, 4, 4))),
                    desc='no_batch_dim',
                    reference_fn=no_batch_dim_reference_fn),
        ModuleInput(constructor_input=FunctionInput((2, 2, 2)),
                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4)))),
        ModuleInput(constructor_input=FunctionInput(2, (2, 2, 2)),
                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
                    desc='stride'),
        ModuleInput(constructor_input=FunctionInput(2, 2, (1, 1, 1)),
                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
                    desc='stride_pad'),
        ModuleInput(constructor_input=FunctionInput(4, 2, (1, 2, 1)),
                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
                    desc='stride_pad_gpu_fixedkw_output'),
        ModuleInput(constructor_input=FunctionInput((2, 4, 8), 1, (1, 1, 2)),
                    forward_input=FunctionInput(make_input((2, 3, 2, 4, 8))),
                    desc='stride_pad_gpu_general_output'),
        ModuleInput(constructor_input=FunctionInput(3, 1, 0),
                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
                    desc='stride1_pad0_gpu_input'),
        ModuleInput(constructor_input=FunctionInput(2, 2, (1, 1, 1)),
                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
                    desc='stride_pad_gpu_input_nooverlap'),
        ModuleInput(constructor_input=FunctionInput((2, 2, 2), divisor_override=1),
                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
                    desc='divisor'),
        ModuleInput(constructor_input=FunctionInput(2, (2, 2, 2), divisor_override=1),
                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
                    desc='divisor_stride'),
        ModuleInput(constructor_input=FunctionInput(2, 2, (1, 1, 1), divisor_override=1),
                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
                    desc='divisor_stride_pad'),
        ModuleInput(constructor_input=FunctionInput(4, 2, (1, 2, 1), divisor_override=1),
                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
                    desc='divisor_stride_pad_gpu_fixedkw_output'),
        ModuleInput(constructor_input=FunctionInput((2, 4, 8), 1, (1, 1, 2), divisor_override=1),
                    forward_input=FunctionInput(make_input((2, 3, 2, 4, 8))),
                    desc='divisor_stride_pad_gpu_general_output'),
        ModuleInput(constructor_input=FunctionInput(3, 1, 0, divisor_override=1),
                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
                    desc='divisor_stride1_pad0_gpu_input'),
        ModuleInput(constructor_input=FunctionInput(2, 2, (1, 1, 1), divisor_override=1),
                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
                    desc='divisor_stride_pad_gpu_input_nooverlap')]


def module_inputs_torch_nn_AdaptiveAvgPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(3,),
                    forward_input=FunctionInput(make_input((1, 3, 5))),
                    desc='single'),
        ModuleInput(constructor_input=FunctionInput(3,),
                    forward_input=FunctionInput(make_input((3, 5))),
                    reference_fn=no_batch_dim_reference_fn,
                    desc='no_batch_dim'),
        ModuleInput(constructor_input=FunctionInput(1,),
                    forward_input=FunctionInput(make_input((1, 3, 5))),
                    desc='one_output')]


def module_inputs_torch_nn_AdaptiveAvgPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(3,),
                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
                    desc='single'),
        ModuleInput(constructor_input=FunctionInput(3,),
                    forward_input=FunctionInput(make_input((3, 5, 6))),
                    reference_fn=no_batch_dim_reference_fn,
                    desc='no_batch_dim'),
        ModuleInput(constructor_input=FunctionInput(1,),
                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
                    desc='single_1x1output'),
        ModuleInput(constructor_input=FunctionInput((3, 4)),
                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
                    desc='tuple'),
        ModuleInput(constructor_input=FunctionInput((3, None)),
                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
                    desc='tuple_none')]


def module_inputs_torch_nn_AdaptiveAvgPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(3,),
                    forward_input=FunctionInput(make_input((2, 3, 5, 2, 7))),
                    desc='single'),
        ModuleInput(constructor_input=FunctionInput(3,),
                    forward_input=FunctionInput(make_input((3, 5, 2, 7))),
                    reference_fn=no_batch_dim_reference_fn,
                    desc='no_batch_dim'),
        ModuleInput(constructor_input=FunctionInput((3, 4, 5)),
                    forward_input=FunctionInput(make_input((2, 3, 5, 3, 7))),
                    desc='tuple'),
        ModuleInput(constructor_input=FunctionInput((None, 4, 5)),
                    forward_input=FunctionInput(make_input((2, 3, 5, 3, 7))),
                    desc='tuple_none'),
        ModuleInput(constructor_input=FunctionInput((3, 2, 2)),
                    forward_input=FunctionInput(make_input((1, 1, 3, 2, 6))),
                    desc='last_dim')]


def module_inputs_torch_nn_AdaptiveMaxPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(3,),
                    forward_input=FunctionInput(make_input((1, 3, 5))),
                    desc='single'),
        ModuleInput(constructor_input=FunctionInput(3,),
                    forward_input=FunctionInput(make_input((3, 5))),
                    reference_fn=no_batch_dim_reference_fn,
                    desc='no_batch_dim')]


def module_inputs_torch_nn_AdaptiveMaxPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(3,),
                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
                    desc='single'),
        ModuleInput(constructor_input=FunctionInput(3,),
                    forward_input=FunctionInput(make_input((3, 5, 6))),
                    reference_fn=no_batch_dim_reference_fn,
                    desc='no_batch_dim'),
        ModuleInput(constructor_input=FunctionInput((3, 4)),
                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
                    desc='tuple'),
        ModuleInput(constructor_input=FunctionInput((3, None)),
                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
                    desc='tuple_none')]


def module_inputs_torch_nn_AdaptiveMaxPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(3,),
                    forward_input=FunctionInput(make_input((2, 3, 5, 6, 7))),
                    desc='single'),
        ModuleInput(constructor_input=FunctionInput(3,),
                    forward_input=FunctionInput(make_input((3, 5, 6, 7))),
                    reference_fn=no_batch_dim_reference_fn,
                    desc='no_batch_dim'),
        ModuleInput(constructor_input=FunctionInput((3, 4, 5)),
                    forward_input=FunctionInput(make_input((2, 3, 5, 6, 7))),
                    desc='tuple'),
        ModuleInput(constructor_input=FunctionInput((3, None, 5)),
                    forward_input=FunctionInput(make_input((2, 3, 5, 6, 7))),
                    desc='tuple_none'),
        ModuleInput(constructor_input=FunctionInput(3),
                    forward_input=FunctionInput(make_input((2, 3, 12, 9, 3))),
                    desc='single_nonatomic'),
        ModuleInput(constructor_input=FunctionInput((3, 4, 5)),
                    forward_input=FunctionInput(make_input((2, 3, 6, 4, 10))),
                    desc='tuple_nonatomic')]


def module_inputs_torch_nn_BatchNorm1d(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(10,),
                    forward_input=FunctionInput(make_input((4, 10))),
                    desc='affine'),
        ModuleInput(constructor_input=FunctionInput(5,),
                    forward_input=FunctionInput(make_input((4, 5, 3))),
                    desc='3d_input'),
        ModuleInput(constructor_input=FunctionInput(10, 1e-3, None),
                    forward_input=FunctionInput(make_input((4, 10))),
                    desc='affine_simple_average'),
        ModuleInput(constructor_input=FunctionInput(10, 1e-3, 0.3, False),
                    forward_input=FunctionInput(make_input((4, 10))),
                    desc='not_affine'),
        ModuleInput(constructor_input=FunctionInput(10, 1e-3, 0.3, True, False),
                    forward_input=FunctionInput(make_input((4, 10))),
                    desc='not_tracking_stats'),
        ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False),
                    forward_input=FunctionInput(make_input((4, 5, 3))),
                    desc='3d_input_not_affine'),
        ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False),
                    forward_input=FunctionInput(make_input((0, 5, 9))),
                    desc='zero_batch')]


def module_inputs_torch_nn_BatchNorm2d(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(3,),
                    forward_input=FunctionInput(make_input((2, 3, 6, 6)))),
        ModuleInput(constructor_input=FunctionInput(3, 1e-3, None),
                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
                    desc='2d_simple_average'),
        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.8),
                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
                    desc='momentum'),
        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.8, False),
                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
                    desc='not_affine'),
        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.8, True, False),
                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
                    desc='not_tracking_stats'),
        ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False),
                    forward_input=FunctionInput(make_input((0, 5, 2, 2))),
                    desc='zero_batch')]


def module_inputs_torch_nn_BatchNorm3d(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(3,),
                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4)))),
        ModuleInput(constructor_input=FunctionInput(3, 1e-3, None),
                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
                    desc='3d_simple_average'),
        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.7),
                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
                    desc='momentum'),
        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.7, False),
                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
                    desc='not_affine'),
        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.7, True, False),
                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
                    desc='not_tracking_stats'),
        ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False),
                    forward_input=FunctionInput(make_input((0, 5, 2, 2, 2))),
                    desc='zero_batch')]


def module_inputs_torch_nn_ConvNd(module_info, device, dtype, requires_grad, training, **kwargs):
    N = kwargs['N']
    lazy = kwargs.get('lazy', False)
    transposed = kwargs.get('transposed', False)
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    conv_kwargs_list = [{}] if transposed else [{}, {'padding': 'same'}]
    kernel_size, C_in, C_out = 3, 4, 5
    input_no_batch_shape = (C_in,) + tuple(i + 3 for i in range(N))
    input_batch_shape = (2,) + input_no_batch_shape
    return [
        ModuleInput(constructor_input=(FunctionInput(C_out, kernel_size, **conv_kwargs) if lazy else
                                       FunctionInput(C_in, C_out, kernel_size, **conv_kwargs)),
                    forward_input=FunctionInput(make_input(
                        input_batch_shape if with_batch else input_no_batch_shape)),
                    desc=('' if with_batch else 'no_batch_dim'),
                    reference_fn=(None if with_batch else no_batch_dim_reference_fn))
        for with_batch, conv_kwargs in itertools.product([True, False], conv_kwargs_list)
    ]
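
# Illustrative note (assumes the module_db entries later in this file): this generic ConvNd
# helper is typically specialized per conv flavor with functools.partial, e.g. roughly
#
#   module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=False)
#
# so one generator covers Conv1d/2d/3d, lazy variants, and transposed convolutions.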


def module_inputs_torch_nn_CosineEmbeddingLoss(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)

    cases: List[Tuple[str, dict]] = [
        ('', {}),
        ('reduction_sum', {'reduction': 'sum'}),
        ('reduction_mean', {'reduction': 'mean'}),
        ('reduction_none', {'reduction': 'none'}),
        ('margin', {'margin': 0.7})
    ]

    module_inputs = []
    for desc, constructor_kwargs in cases:
        def reference_fn(m, p, i1, i2, t, constructor_kwargs=constructor_kwargs):
            return cosineembeddingloss_reference(i1, i2, t, **constructor_kwargs)

        module_inputs.append(
            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
                        forward_input=FunctionInput(make_input((15, 10)), make_input((15, 10)),
                                                    make_target((15,)).sign()),
                        desc=desc,
                        reference_fn=reference_fn)
        )

    return module_inputs


def module_inputs_torch_nn_ELU(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(alpha=2.),
                    forward_input=FunctionInput(make_input((3, 2, 5))),
                    reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2 * (i.exp() - 1))),
        ModuleInput(constructor_input=FunctionInput(alpha=2.),
                    forward_input=FunctionInput(make_input(())),
                    desc='scalar'),
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input((3,))),
                    desc='no_batch_dim',
                    reference_fn=no_batch_dim_reference_fn),
        ModuleInput(constructor_input=FunctionInput(alpha=2.),
                    forward_input=FunctionInput(make_input((2, 3, 2, 5))),
                    desc='4d_input')]


def module_inputs_torch_nn_CELU(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(alpha=2.),
                    forward_input=FunctionInput(make_input((3, 2, 5))),
                    reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2. * ((.5 * i).exp() - 1))),
        ModuleInput(constructor_input=FunctionInput(alpha=2.),
                    forward_input=FunctionInput(make_input(())),
                    reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2. * ((.5 * i).exp() - 1)),
                    desc='scalar'),
        ModuleInput(constructor_input=FunctionInput(alpha=2.),
                    forward_input=FunctionInput(make_input((3,))),
                    desc='no_batch_dim',
                    reference_fn=no_batch_dim_reference_fn)]


def module_inputs_torch_nn_GLU(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input((5, 6)))),
        ModuleInput(constructor_input=FunctionInput(1),
                    forward_input=FunctionInput(make_input((5, 6, 7))),
                    desc='dim'),
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input((4,))),
                    desc='no_batch_dim',
                    reference_fn=no_batch_dim_reference_fn)]


def module_inputs_torch_nn_GELU(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput('none'),
                    forward_input=FunctionInput(make_input(())),
                    reference_fn=lambda m, p, x, *_: x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))),
                    desc='scalar'),
        ModuleInput(constructor_input=FunctionInput('none'),
                    forward_input=FunctionInput(make_input((3, 2, 5))),
                    reference_fn=lambda m, p, x, *_: x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))),
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input((3,))),
                    desc='no_batch_dim',
                    reference_fn=no_batch_dim_reference_fn)]


def module_inputs_torch_nn_ReLU(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input(())),
                    desc='scalar'),
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input(4)),
                    reference_fn=no_batch_dim_reference_fn,
                    desc='no_batch_dim'),
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
                    desc='channels_last_mem_format'),
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input((2, 3, 3, 4, 5))),
                    desc='channels_last_3d_mem_format')]


def module_inputs_torch_nn_ReLU6(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input(())),
                    desc='scalar'),
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input(4)),
                    reference_fn=no_batch_dim_reference_fn,
                    desc='no_batch_dim'),
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
                    desc='channels_last_mem_format'),
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input((2, 3, 3, 4, 5))),
                    desc='channels_last_3d_mem_format')]


def module_inputs_torch_nn_LeakyReLU(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input((3, 2, 5)))),
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input(4)),
                    reference_fn=no_batch_dim_reference_fn,
                    desc='no_batch_dim'),
        ModuleInput(constructor_input=FunctionInput(0.5),
                    forward_input=FunctionInput(make_input((3, 2, 5))),
                    desc='with_negval'),
        ModuleInput(constructor_input=FunctionInput(0.0),
                    forward_input=FunctionInput(make_input((10, 10))),
                    desc='with_zero_negval'),
        ModuleInput(constructor_input=FunctionInput(0.5),
                    forward_input=FunctionInput(make_input(())),
                    desc='with_negval_scalar')]


def module_inputs_torch_nn_PReLU(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input(())),
                    desc='scalar'),
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input(4)),
                    reference_fn=no_batch_dim_reference_fn,
                    desc='no_batch_dim'),
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input((2, 3, 4))),
                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
                    desc='1d'),
        ModuleInput(constructor_input=FunctionInput(3),
                    forward_input=FunctionInput(make_input((2, 3, 4))),
                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
                    desc='1d_multiparam'),
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
                    desc='2d'),
        ModuleInput(constructor_input=FunctionInput(3),
                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
                    desc='2d_multiparam'),
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input((2, 3, 4, 5, 6))),
                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
                    desc='3d'),
        ModuleInput(constructor_input=FunctionInput(3),
                    forward_input=FunctionInput(make_input((2, 3, 4, 5, 6))),
                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
                    desc='3d_multiparam')]


def module_inputs_torch_nn_SELU(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input((3, 2, 5)))),
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input(4)),
                    reference_fn=no_batch_dim_reference_fn,
                    desc='no_batch_dim'),
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input(())),
                    desc='scalar')]


def module_inputs_torch_nn_SiLU(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input(())),
                    reference_fn=lambda m, p, x, *_: x * torch.sigmoid(x),
                    desc='scalar'),
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input(4)),
                    reference_fn=no_batch_dim_reference_fn,
                    desc='no_batch_dim'),
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input((5, 6, 7))),
                    reference_fn=lambda m, p, x, *_: x * torch.sigmoid(x))]


def module_inputs_torch_nn_Softmax(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(1),
                    forward_input=FunctionInput(make_input((10, 20))),
                    reference_fn=lambda m, p, i: torch.exp(i).div(torch.exp(i).sum(1, True).expand(10, 20))),
        ModuleInput(constructor_input=FunctionInput(0),
                    forward_input=FunctionInput(make_input(())),
                    reference_fn=lambda m, p, i: torch.exp(i).div(torch.exp(i).sum(0, True)),
                    desc='scalar'),
        ModuleInput(constructor_input=FunctionInput(-1),
                    forward_input=FunctionInput(make_input((4, 5))),
                    reference_fn=no_batch_dim_reference_fn,
                    desc='no_batch_dim')]


def module_inputs_torch_nn_Softmax2d(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input((1, 3, 10, 20))),
                    reference_fn=lambda m, p, i: torch.exp(i).div(torch.exp(i).sum(1, False))),
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input((3, 4, 5))),
                    reference_fn=no_batch_dim_reference_fn,
                    desc='no_batch_dim')]


def module_inputs_torch_nn_LogSoftmax(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(1),
                    forward_input=FunctionInput(make_input((10, 20))),
                    reference_fn=lambda m, p, i: torch.exp(i).div_(torch.exp(i).sum(1, True).expand(10, 20)).log_()),
        ModuleInput(constructor_input=FunctionInput(1),
                    forward_input=FunctionInput(make_input((1, 3, 10, 20))),
                    reference_fn=lambda m, p, i: torch.exp(i).div_(torch.exp(i).sum(1, False)).log_(),
                    desc='multiparam'),
        ModuleInput(constructor_input=FunctionInput(0),
                    forward_input=FunctionInput(make_input(())),
                    reference_fn=lambda m, p, i: torch.exp(i).div_(torch.exp(i).sum(0, False)).log_(),
                    desc='multiparam_scalar'),
        ModuleInput(constructor_input=FunctionInput(-1),
                    forward_input=FunctionInput(make_input((4, 5))),
                    reference_fn=no_batch_dim_reference_fn,
                    desc='no_batch_dim')]


def module_inputs_torch_nn_Softmin(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(1),
                    forward_input=FunctionInput(make_input((10, 20)))),
        ModuleInput(constructor_input=FunctionInput(1),
                    forward_input=FunctionInput(make_input((2, 3, 5, 10))),
                    desc='multidim'),
        ModuleInput(constructor_input=FunctionInput(0),
                    forward_input=FunctionInput(make_input(())),
                    desc='scalar'),
        ModuleInput(constructor_input=FunctionInput(-1),
                    forward_input=FunctionInput(make_input((3, 4, 10))),
                    reference_fn=no_batch_dim_reference_fn,
                    desc='no_batch_dim')]


def module_inputs_torch_nn_Softplus(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input((10, 20))),
                    reference_fn=lambda m, p, i: torch.log(1 + torch.exp(i))),
        ModuleInput(constructor_input=FunctionInput(2),
                    forward_input=FunctionInput(make_input((10, 20))),
                    reference_fn=lambda m, p, i: 1. / 2. * torch.log(1 + torch.exp(2 * i)),
                    desc='beta'),
        ModuleInput(constructor_input=FunctionInput(2, -100),
                    forward_input=FunctionInput(make_input((10, 20))),
                    reference_fn=(
                        lambda m, p, i: ((i * 2) > -100).type_as(i) * i
                        + ((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log(1 + torch.exp(2 * i))),
                    desc='beta_threshold'),
        ModuleInput(constructor_input=FunctionInput(2, -100),
                    forward_input=FunctionInput(make_input(())),
                    reference_fn=(
                        lambda m, p, i: ((i * 2) > -100).type_as(i) * i
                        + ((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log(1 + torch.exp(2 * i))),
                    desc='beta_threshold_scalar'),
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input(4)),
                    reference_fn=no_batch_dim_reference_fn,
                    desc='no_batch_dim')]


def module_inputs_torch_nn_Softshrink(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input((3, 2, 5)))),
        ModuleInput(constructor_input=FunctionInput(1,),
                    forward_input=FunctionInput(make_input((3, 2, 5))),
                    desc='lambda'),
        ModuleInput(constructor_input=FunctionInput(1,),
                    forward_input=FunctionInput(make_input(())),
                    desc='lambda_scalar'),
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input(4)),
                    reference_fn=no_batch_dim_reference_fn,
                    desc='no_batch_dim')]


def module_inputs_torch_nn_Softsign(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input((3, 2, 5))),
                    reference_fn=lambda m, p, i: i.div(1 + torch.abs(i))),
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input(())),
                    reference_fn=lambda m, p, i: i.div(1 + torch.abs(i)),
                    desc='scalar'),
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input(4)),
                    reference_fn=no_batch_dim_reference_fn,
                    desc='no_batch_dim')]


def module_inputs_torch_nn_Tanh(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input((2, 3, 4, 5)))),
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input(())),
                    desc='scalar'),
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input(4)),
                    reference_fn=no_batch_dim_reference_fn,
                    desc='no_batch_dim')]


def module_inputs_torch_nn_Tanhshrink(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input((2, 3, 4, 5)))),
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input(())),
                    desc='scalar'),
        ModuleInput(constructor_input=FunctionInput(),
                    forward_input=FunctionInput(make_input(4)),
                    reference_fn=no_batch_dim_reference_fn,
                    desc='no_batch_dim')]


def module_inputs_torch_nn_Threshold(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(constructor_input=FunctionInput(2., 1.),
                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
                    desc='threshold_value'),
        ModuleInput(constructor_input=FunctionInput(2., 10.),
                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
1373 desc='large_value'), 1374 ModuleInput(constructor_input=FunctionInput(2., 1.), 1375 forward_input=FunctionInput(make_input(())), 1376 desc='threshold_value_scalar'), 1377 ModuleInput(constructor_input=FunctionInput(2., 1.), 1378 forward_input=FunctionInput(make_input(4)), 1379 reference_fn=no_batch_dim_reference_fn, 1380 desc='no_batch_dim')] 1381 1382 1383def module_inputs_torch_nn_Mish(module_info, device, dtype, requires_grad, training, **kwargs): 1384 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 1385 1386 return [ 1387 ModuleInput(constructor_input=FunctionInput(), 1388 forward_input=FunctionInput(make_input((5, 6, 7))), 1389 reference_fn=lambda m, p, i: i * torch.tanh(F.softplus(i))), 1390 ModuleInput(constructor_input=FunctionInput(), 1391 forward_input=FunctionInput(make_input(())), 1392 reference_fn=lambda m, p, i: i * torch.tanh(F.softplus(i)), 1393 desc='scalar'), 1394 ModuleInput(constructor_input=FunctionInput(), 1395 forward_input=FunctionInput(make_input(4)), 1396 reference_fn=no_batch_dim_reference_fn, 1397 desc='no_batch_dim')] 1398 1399 1400def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, training, **kwargs): 1401 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 1402 1403 return [ 1404 ModuleInput(constructor_input=FunctionInput(), 1405 forward_input=FunctionInput(make_input((2, 3, 4)), 1406 make_input((2, 3, 4))), 1407 reference_fn=lambda m, p, i, t: 1. / i.numel() * sum((a - b).abs().sum() 1408 for a, b in zip(i, t))), 1409 ModuleInput(constructor_input=FunctionInput(), 1410 forward_input=FunctionInput(make_input(()), make_input(())), 1411 reference_fn=lambda m, p, i, t: 1. / i.numel() * (i - t).abs().sum(), 1412 desc='scalar')] + generate_regression_criterion_inputs(make_input) 1413 1414 1415def module_inputs_torch_nn_SmoothL1Loss(module_info, device, dtype, requires_grad, training, **kwargs): 1416 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 1417 1418 1419 cases: List[Tuple[str, dict]] = [ 1420 ('', {}), 1421 ('reduction_sum', {'reduction': 'sum'}), 1422 ('reduction_mean', {'reduction': 'mean'}), 1423 ('reduction_none', {'reduction': 'none'}), 1424 ] 1425 1426 module_inputs = [] 1427 for desc, constructor_kwargs in cases: 1428 def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs): 1429 return smoothl1loss_reference(i, t, **constructor_kwargs) 1430 1431 module_inputs.append( 1432 ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), 1433 forward_input=FunctionInput(make_input((5, 10)), 1434 make_input((5, 10))), 1435 desc=desc, 1436 reference_fn=reference_fn) 1437 ) 1438 module_inputs.append( 1439 ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), 1440 forward_input=FunctionInput(make_input(()), 1441 make_input(())), 1442 desc=f'scalar_{desc}', 1443 reference_fn=reference_fn) 1444 ) 1445 1446 return module_inputs 1447 1448 1449 1450def module_inputs_torch_nn_BCELoss(module_info, device, dtype, requires_grad, training, **kwargs): 1451 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 1452 make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) 1453 make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) 1454 1455 cases: List[Tuple[str, dict]] = [ 1456 ('', {}), 1457 ('reduction_sum', {'reduction': 'sum'}), 1458 ('reduction_mean', {'reduction': 'mean'}), 1459 
('reduction_none', {'reduction': 'none'}), 1460 ('weights', {'weight': make_weight((10,))}), 1461 ] 1462 1463 def bce_loss_reference_fn(m, p, i, t, reduction='mean', weight=None): 1464 result = -(t * i.log() + (1 - t) * (1 - i).log()) 1465 1466 if weight is not None: 1467 result = result * weight 1468 1469 if reduction == 'none': 1470 return result 1471 elif reduction == 'mean': 1472 return result.sum() / i.numel() 1473 else: 1474 return result.sum() 1475 1476 module_inputs = [] 1477 for desc, constructor_kwargs in cases: 1478 module_inputs.append( 1479 ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), 1480 forward_input=FunctionInput(make_input((15, 10), low=1e-2, high=1 - 1e-2), 1481 make_target((15, 10)).gt(0).to(dtype)), 1482 desc=desc, 1483 reference_fn=partial(bce_loss_reference_fn, **constructor_kwargs)) 1484 ) 1485 1486 scalar_weight = make_weight(()) 1487 module_inputs.append( 1488 ModuleInput(constructor_input=FunctionInput(weight=scalar_weight), 1489 forward_input=FunctionInput(make_input((), low=1e-2, high=1 - 1e-2), 1490 make_target(()).gt(0).to(dtype)), 1491 desc='scalar_weight', 1492 reference_fn=partial(bce_loss_reference_fn, weight=scalar_weight)) 1493 ) 1494 1495 return module_inputs 1496 1497 1498def module_inputs_torch_nn_BCEWithLogitsLoss(module_info, device, dtype, requires_grad, training, **kwargs): 1499 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 1500 make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) 1501 make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) 1502 1503 cases: List[Tuple[str, dict]] = [ 1504 ('', {}), 1505 ('reduction_sum', {'reduction': 'sum'}), 1506 ('reduction_mean', {'reduction': 'mean'}), 1507 ('reduction_none', {'reduction': 'none'}), 1508 ('weights', {'weight': make_weight((10,))}), 1509 ('scalar_weights', {'weight': make_weight(())}) 1510 ] 1511 1512 def bce_withlogitsloss_reference_fn(m, p, i, t, reduction='mean', weight=None): 1513 # TODO: add pos_weight to the definition here and corresponding SampleInputs 1514 max_val = (-i).clamp(min=0) 1515 result = (1 - t).mul_(i).add_(max_val).add_((-max_val).exp_().add_((-i - max_val).exp_()).log_()) 1516 1517 if weight is not None: 1518 result = result * weight 1519 1520 if reduction == 'none': 1521 return result 1522 elif reduction == 'mean': 1523 return result.sum() / i.numel() 1524 else: 1525 return result.sum() 1526 1527 module_inputs = [] 1528 for desc, constructor_kwargs in cases: 1529 module_inputs.append( 1530 ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), 1531 forward_input=FunctionInput(make_input((15, 10), low=1e-2, high=1 - 1e-2), 1532 make_target((15, 10)).gt(0).to(dtype)), 1533 desc=desc, 1534 reference_fn=partial(bce_withlogitsloss_reference_fn, **constructor_kwargs)) 1535 ) 1536 1537 return module_inputs 1538 1539 1540def module_inputs_torch_nn_CrossEntropyLoss(module_info, device, dtype, requires_grad, training, **kwargs): 1541 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 1542 make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False) 1543 make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) 1544 1545 reductions: List[str] = ['mean', 'sum', 'none'] 1546 cases: List[Tuple[str, dict]] = [ 1547 ('', {}), 1548 ('weights', {'weight': make_weight((3,))}), 1549 ('ignore_index', {'ignore_index': 1}), 1550 ('label_smoothing', 
{'label_smoothing': 0.15}), 1551 ('ignore_index_label_smoothing', {'ignore_index': 1, 'label_smoothing': 0.15}) 1552 ] 1553 1554 module_inputs = [] 1555 for reduction, (desc, constructor_kwargs) in product(reductions, cases): 1556 def reference_fn(m, p, i, t, reduction=reduction, constructor_kwargs=constructor_kwargs): 1557 return cross_entropy_loss_reference(i, t, reduction=reduction, **constructor_kwargs) 1558 1559 module_inputs.append( 1560 ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs), 1561 forward_input=FunctionInput(make_input((2, 3, 5, 5)), 1562 make_target((2, 5, 5), low=0, high=3)), 1563 desc=f"4d_{desc}_{reduction}", 1564 reference_fn=reference_fn) 1565 ) 1566 module_inputs.append( 1567 ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs), 1568 forward_input=FunctionInput(make_input((2, 3, 5)), 1569 make_target((2, 5), low=0, high=3)), 1570 desc=f"3d_{desc}_{reduction}", 1571 reference_fn=reference_fn) 1572 ) 1573 module_inputs.append( 1574 ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs), 1575 forward_input=FunctionInput(make_input((2, 3)), 1576 make_target((2), low=0, high=3)), 1577 desc=f"2d_{desc}_{reduction}", 1578 reference_fn=reference_fn) 1579 ) 1580 module_inputs.append( 1581 ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs), 1582 forward_input=FunctionInput(make_input((2, 3, 5, 5, 2, 2)), 1583 make_target((2, 5, 5, 2, 2), low=0, high=3)), 1584 desc=f"higher_dim_{desc}_{reduction}", 1585 reference_fn=reference_fn) 1586 ) 1587 1588 if constructor_kwargs.get('ignore_index', None) is None: 1589 module_inputs.append( 1590 ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs), 1591 forward_input=FunctionInput(make_input((5, 3, 4, 2)), 1592 make_input((5, 3, 4, 2)).softmax(dim=1)), 1593 desc=f"4d_prob_target_{desc}_{reduction}", 1594 reference_fn=reference_fn) 1595 ) 1596 module_inputs.append( 1597 ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs), 1598 forward_input=FunctionInput(make_input((5, 3, 4)), 1599 make_input((5, 3, 4)).softmax(dim=1)), 1600 desc=f"3d_prob_target_{desc}_{reduction}", 1601 reference_fn=reference_fn) 1602 ) 1603 module_inputs.append( 1604 ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs), 1605 forward_input=FunctionInput(make_input((5, 3)), 1606 make_input((5, 3)).softmax(dim=1)), 1607 desc=f"2d_prob_target_{desc}_{reduction}", 1608 reference_fn=reference_fn) 1609 ) 1610 module_inputs.append( 1611 ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs), 1612 forward_input=FunctionInput(make_input((2, 3, 5, 5, 2, 2)), 1613 make_input((2, 3, 5, 5, 2, 2)).softmax(dim=1)), 1614 desc=f"higher_dim_prob_target_{desc}_{reduction}", 1615 reference_fn=reference_fn) 1616 ) 1617 module_inputs.append( 1618 ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs), 1619 forward_input=FunctionInput(make_input((3,)), 1620 make_target((), low=0, high=3)), 1621 desc=f"no_batch_dim_{desc}_{reduction}", 1622 reference_fn=partial(no_batch_dim_reference_fn, is_criterion=True)) 1623 ) 1624 1625 return module_inputs 1626 1627 1628 1629def module_inputs_torch_nn_CTCLoss(module_info, device, dtype, requires_grad, training, **kwargs): 1630 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 1631 make_target = partial(make_tensor, 
device=device, requires_grad=False) 1632 1633 cases: List[Tuple[str, dict]] = [ 1634 ('', {}), 1635 ('reduction_sum', {'reduction': 'sum'}), 1636 ('reduction_mean', {'reduction': 'mean'}), 1637 ('reduction_none', {'reduction': 'none'}), 1638 ('blank', {'blank': 14}) 1639 ] 1640 target_dtypes = [torch.int, torch.long] 1641 1642 module_inputs = [] 1643 for target_dtype, (desc, constructor_kwargs) in product(target_dtypes, cases): 1644 def reference_fn(m, p, i, t, il, tl, constructor_kwargs=constructor_kwargs): 1645 return ctcloss_reference(i, t, il, tl, **constructor_kwargs) 1646 1647 blank = constructor_kwargs.get('blank', 0) 1648 low = 0 if blank == 14 else 1 1649 high = 14 if blank == 14 else 15 1650 1651 module_inputs.append( 1652 ModuleInput( 1653 constructor_input=FunctionInput(**constructor_kwargs), 1654 forward_input=FunctionInput(make_input((50, 3, 15)).log_softmax(2), 1655 make_target((3, 30), dtype=target_dtype, low=low, high=high), 1656 (50, 50, 50), (30, 25, 20)), 1657 desc=f'{desc}_lengths_intlists', 1658 reference_fn=reference_fn) 1659 ) 1660 module_inputs.append( 1661 ModuleInput( 1662 constructor_input=FunctionInput(**constructor_kwargs), 1663 forward_input=FunctionInput(make_input((50, 3, 15)).log_softmax(2), 1664 make_target((3, 30), dtype=target_dtype, low=low, high=high), 1665 torch.tensor((50, 50, 50), device=device), 1666 torch.tensor((30, 25, 20), device=device)), 1667 desc=f'{desc}_lengths_tensors', 1668 reference_fn=reference_fn) 1669 ) 1670 module_inputs.append( 1671 ModuleInput( 1672 constructor_input=FunctionInput(**constructor_kwargs), 1673 forward_input=FunctionInput(make_input((50, 3, 15)).log_softmax(2), 1674 make_target((30 + 25 + 20,), dtype=target_dtype, low=low, high=high), 1675 (50, 50, 50), (30, 25, 20)), 1676 desc=f'{desc}_1d_target_lengths_intlists', 1677 reference_fn=reference_fn) 1678 ) 1679 module_inputs.append( 1680 ModuleInput( 1681 constructor_input=FunctionInput(**constructor_kwargs), 1682 forward_input=FunctionInput(make_input((50, 3, 15)).log_softmax(2), 1683 make_target((30 + 25 + 20,), dtype=target_dtype, low=low, high=high), 1684 torch.tensor((50, 50, 50), device=device), 1685 torch.tensor((30, 25, 20), device=device)), 1686 desc=f'{desc}_1d_target_lengths_tensors', 1687 reference_fn=reference_fn) 1688 ) 1689 1690 return module_inputs 1691 1692 1693def module_inputs_torch_nn_GroupNorm(module_info, device, dtype, requires_grad, training, **kwargs): 1694 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 1695 1696 return [ 1697 ModuleInput( 1698 constructor_input=FunctionInput(3, 6, 1e-3), 1699 forward_input=FunctionInput(make_input((4, 6, 5))), 1700 desc='1d_affine'), 1701 ModuleInput( 1702 constructor_input=FunctionInput(3, 12, 1e-3), 1703 forward_input=FunctionInput(make_input((4, 12))), 1704 desc='1d_affine_GN'), 1705 ModuleInput( 1706 constructor_input=FunctionInput(1, 6, 1e-3), 1707 forward_input=FunctionInput(make_input((150, 6))), 1708 desc='1d_affine_large_batch'), 1709 ModuleInput( 1710 constructor_input=FunctionInput(5, 5, 1e-3, False), 1711 forward_input=FunctionInput(make_input((4, 5, 5))), 1712 desc='1d_no_affine_IN'), 1713 ModuleInput( 1714 constructor_input=FunctionInput(1, 10, 1e-3, False), 1715 forward_input=FunctionInput(make_input((4, 10))), 1716 desc='1d_no_affine_LN'), 1717 ModuleInput( 1718 constructor_input=FunctionInput(3, 6, 1e-3), 1719 forward_input=FunctionInput(make_input((4, 6, 2, 3))), 1720 desc='2d_affine'), 1721 ModuleInput( 1722 constructor_input=FunctionInput(3, 3, 
1e-3, False), 1723 forward_input=FunctionInput(make_input((4, 3, 2, 3))), 1724 desc='2d_no_affine_IN'), 1725 ModuleInput( 1726 constructor_input=FunctionInput(1, 3, 1e-3, False), 1727 forward_input=FunctionInput(make_input((4, 3, 2, 3))), 1728 desc='2d_no_affine_LN'), 1729 ] 1730 1731 1732def module_inputs_torch_nn_Hardshrink(module_info, device, dtype, requires_grad, training, **kwargs): 1733 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 1734 1735 return [ 1736 ModuleInput( 1737 constructor_input=FunctionInput(2.), 1738 forward_input=FunctionInput(make_input((4, 3, 2, 4))), 1739 ), 1740 ModuleInput( 1741 constructor_input=FunctionInput(2.), 1742 forward_input=FunctionInput(make_input(())), 1743 desc='scalar', 1744 ), 1745 ModuleInput( 1746 constructor_input=FunctionInput(), 1747 forward_input=FunctionInput(make_input(4)), 1748 reference_fn=no_batch_dim_reference_fn, 1749 desc='no_batch_dim', 1750 ) 1751 ] 1752 1753 1754def module_inputs_torch_nn_Hardswish(module_info, device, dtype, requires_grad, training, **kwargs): 1755 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 1756 1757 return [ 1758 ModuleInput( 1759 constructor_input=FunctionInput(), 1760 forward_input=FunctionInput(make_input(4)), 1761 reference_fn=no_batch_dim_reference_fn, 1762 desc='no_batch_dim', 1763 ), 1764 ModuleInput( 1765 constructor_input=FunctionInput(), 1766 forward_input=FunctionInput(make_input((2, 3, 2, 5))), 1767 desc='4d_input') 1768 ] 1769 1770 1771def module_inputs_torch_nn_Hardtanh(module_info, device, dtype, requires_grad, training, **kwargs): 1772 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 1773 1774 return [ 1775 ModuleInput( 1776 constructor_input=FunctionInput(), 1777 forward_input=FunctionInput(make_input((3, 2, 5))), 1778 reference_fn=lambda m, p, i: i.clamp(-1, 1), 1779 ), 1780 ModuleInput( 1781 constructor_input=FunctionInput(), 1782 forward_input=FunctionInput(make_input(())), 1783 reference_fn=lambda m, p, i: i.clamp(-1, 1), 1784 desc='scalar', 1785 ), 1786 ModuleInput( 1787 constructor_input=FunctionInput(), 1788 forward_input=FunctionInput(make_input(4)), 1789 reference_fn=no_batch_dim_reference_fn, 1790 desc='no_batch_dim', 1791 ) 1792 ] 1793 1794 1795def module_inputs_torch_nn_HingeEmbeddingLoss(module_info, device, dtype, requires_grad, training, **kwargs): 1796 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 1797 make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) 1798 1799 cases: List[Tuple[str, dict]] = [ 1800 ('', {}), 1801 ('reduction_sum', {'reduction': 'sum'}), 1802 ('reduction_mean', {'reduction': 'mean'}), 1803 ('reduction_none', {'reduction': 'none'}), 1804 ('margin', {'margin': 0.5}) 1805 ] 1806 1807 module_inputs = [] 1808 for desc, constructor_kwargs in cases: 1809 def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs): 1810 return hingeembeddingloss_reference(i, t, **constructor_kwargs) 1811 1812 module_inputs.append( 1813 ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), 1814 forward_input=FunctionInput(make_input((10,)), 1815 make_target((10,)).gt(0).to(dtype).mul_(2).sub_(1)), 1816 desc=desc, 1817 reference_fn=reference_fn) 1818 ) 1819 module_inputs.append( 1820 ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), 1821 forward_input=FunctionInput(make_input(()), 1822 make_target(()).gt(0).to(dtype).mul_(2).sub_(1)), 1823 
desc=f'scalar_{desc}', 1824 reference_fn=reference_fn) 1825 ) 1826 1827 return module_inputs 1828 1829 1830def module_inputs_torch_nn_HuberLoss(module_info, device, dtype, requires_grad, training, **kwargs): 1831 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 1832 1833 cases: List[Tuple[str, dict]] = [ 1834 ('', {}), 1835 ('reduction_sum', {'reduction': 'sum'}), 1836 ('reduction_mean', {'reduction': 'mean'}), 1837 ('reduction_none', {'reduction': 'none'}), 1838 ] 1839 1840 module_inputs = [] 1841 for desc, constructor_kwargs in cases: 1842 def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs): 1843 return huberloss_reference(i, t, **constructor_kwargs) 1844 1845 module_inputs.append( 1846 ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), 1847 forward_input=FunctionInput(make_input((5, 10)), 1848 make_input((5, 10))), 1849 desc=desc, 1850 reference_fn=reference_fn) 1851 ) 1852 1853 return module_inputs 1854 1855 1856def module_inputs_torch_nn_InstanceNormNd(module_info, device, dtype, requires_grad, training, **kwargs): 1857 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 1858 lazy = kwargs.get('lazy', False) 1859 N = kwargs['N'] 1860 num_features, eps, momentum, affine, track_running_stats = 3, 1e-3, 0.3, False, True 1861 input_no_batch_shape_dict = {1: (3, 15), 2: (3, 6, 6), 3: (3, 4, 4, 4)} 1862 input_no_batch_shape = input_no_batch_shape_dict[N] 1863 input_batch_shape = (4,) + input_no_batch_shape 1864 1865 return [ 1866 ModuleInput( 1867 constructor_input=( 1868 FunctionInput(eps, momentum) if lazy else FunctionInput(num_features, eps, momentum) 1869 ), 1870 forward_input=FunctionInput(make_input(input_batch_shape))), 1871 ModuleInput( 1872 constructor_input=( 1873 FunctionInput(eps, momentum, affine, track_running_stats) if lazy else 1874 FunctionInput(num_features, eps, momentum, affine, track_running_stats) 1875 ), 1876 forward_input=FunctionInput(make_input(input_batch_shape)), 1877 desc='tracking_stats'), 1878 ModuleInput( 1879 constructor_input=( 1880 FunctionInput(eps, momentum) if lazy else FunctionInput(num_features, eps, momentum) 1881 ), 1882 forward_input=FunctionInput(make_input(input_no_batch_shape)), 1883 reference_fn=no_batch_dim_reference_fn, 1884 desc='tracking_stats_no_batch_dim'), 1885 ModuleInput( 1886 constructor_input=( 1887 FunctionInput(eps, momentum, affine, track_running_stats) if lazy else 1888 FunctionInput(num_features, eps, momentum, affine, track_running_stats) 1889 ), 1890 forward_input=FunctionInput(make_input(input_no_batch_shape)), 1891 reference_fn=no_batch_dim_reference_fn, 1892 desc='no_batch_dim') 1893 ] 1894 1895def module_inputs_torch_nn_LayerNorm(module_info, device, dtype, requires_grad, training, **kwargs): 1896 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 1897 1898 return [ 1899 ModuleInput( 1900 constructor_input=FunctionInput([5], 1e-3), 1901 forward_input=FunctionInput(make_input((4, 5, 5))), 1902 desc='1d_elementwise_affine'), 1903 ModuleInput( 1904 constructor_input=FunctionInput([5], 1e-3), 1905 forward_input=FunctionInput(make_input((128, 5, 5))), 1906 desc='1d_elementwise_affine_large_batch'), 1907 ModuleInput( 1908 constructor_input=FunctionInput([5], 1e-3, False), 1909 forward_input=FunctionInput(make_input((4, 5, 5))), 1910 desc='1d_no_elementwise_affine'), 1911 ModuleInput( 1912 constructor_input=FunctionInput([2, 2, 5], 1e-3), 1913 
forward_input=FunctionInput(make_input((4, 2, 2, 5))),
            desc='3d_elementwise_affine'),
        ModuleInput(
            constructor_input=FunctionInput([2, 2, 5], 1e-3, False),
            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
            desc='3d_no_elementwise_affine'),
        ModuleInput(
            constructor_input=FunctionInput([5], 1e-3),
            forward_input=FunctionInput(make_input((0, 5))),
            desc='1d_empty_elementwise_affine'),
        ModuleInput(
            constructor_input=FunctionInput([2, 2, 5], 1e-3, elementwise_affine=True, bias=False),
            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
            desc='3d_elementwise_affine_no_bias'),
    ]

def module_inputs_torch_nn_RMSNorm(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # Reference: y = x * rsqrt(mean(x**2 over the normalized dims) + eps), then scaled by the
    # learned weight when elementwise affine is enabled.
    def rms_norm_reference_fn(m, p, i):
        eps = m.eps
        if eps is None:
            eps = torch.finfo(i.dtype).eps
        ndim = i.ndim
        normalized_shape = m.normalized_shape
        weight = m.weight
        dims = [ndim - i - 1 for i in range(len(normalized_shape))]
        result = i * torch.rsqrt(i.pow(2).mean(dim=dims, keepdim=True) + eps)
        if weight is not None:
            result *= weight
        return result

    return [
        ModuleInput(
            constructor_input=FunctionInput([5], 1e-3),
            forward_input=FunctionInput(make_input((4, 5, 5))),
            desc='1d_elementwise_affine',
            reference_fn=rms_norm_reference_fn),
        ModuleInput(
            constructor_input=FunctionInput([5], 1e-3),
            forward_input=FunctionInput(make_input((128, 5, 5))),
            desc='1d_elementwise_affine_large_batch',
            reference_fn=rms_norm_reference_fn),
        ModuleInput(
            constructor_input=FunctionInput([5], 1e-3, False),
            forward_input=FunctionInput(make_input((4, 5, 5))),
            desc='1d_no_elementwise_affine',
            reference_fn=rms_norm_reference_fn),
        ModuleInput(
            constructor_input=FunctionInput([2, 2, 5], 1e-3),
            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
            desc='3d_elementwise_affine',
            reference_fn=rms_norm_reference_fn),
        ModuleInput(
            constructor_input=FunctionInput([2, 2, 5], 1e-3, False),
            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
            desc='3d_no_elementwise_affine',
            reference_fn=rms_norm_reference_fn),
        ModuleInput(
            constructor_input=FunctionInput([5], 1e-3),
            forward_input=FunctionInput(make_input((0, 5))),
            desc='1d_empty_elementwise_affine',
            reference_fn=rms_norm_reference_fn),
    ]


def module_inputs_torch_nn_LocalResponseNorm(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(
            constructor_input=FunctionInput(3,),
            forward_input=FunctionInput(make_input((1, 5, 7))),
            desc='1d'),
        ModuleInput(
            constructor_input=FunctionInput(2,),
            forward_input=FunctionInput(make_input((1, 5, 7, 7))),
            desc='2d_uneven_pad'),
        ModuleInput(
            constructor_input=FunctionInput(1, 1., 0.5, 2.),
            forward_input=FunctionInput(make_input((1, 5, 7, 7, 7))),
            desc='3d_custom_params'),
    ]


def module_inputs_torch_nn_LPPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    return [
        ModuleInput(
            constructor_input=FunctionInput(1.5, 2),
forward_input=FunctionInput(make_input((1, 3, 7))), 2005 desc='norm'), 2006 ModuleInput( 2007 constructor_input=FunctionInput(2, 2, 3), 2008 forward_input=FunctionInput(make_input((1, 3, 7)))), 2009 ModuleInput( 2010 constructor_input=FunctionInput(2, 2, 3), 2011 forward_input=FunctionInput(make_input((3, 7))), 2012 reference_fn=no_batch_dim_reference_fn, 2013 desc='no_batch_dim'), 2014 ] 2015 2016 2017 2018def module_inputs_torch_nn_LPPool2d(module_info, device, dtype, requires_grad, training, **kwargs): 2019 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2020 2021 return [ 2022 ModuleInput( 2023 constructor_input=FunctionInput(2, 2, 2), 2024 forward_input=FunctionInput(make_input((1, 3, 7, 7)))), 2025 ModuleInput( 2026 constructor_input=FunctionInput(2, 2, 2), 2027 forward_input=FunctionInput(make_input((3, 7, 7))), 2028 reference_fn=no_batch_dim_reference_fn, 2029 desc='no_batch_dim'), 2030 ModuleInput( 2031 constructor_input=FunctionInput(1.5, 2), 2032 forward_input=FunctionInput(make_input((1, 3, 7, 7))), 2033 desc='norm'), 2034 ] 2035 2036 2037def module_inputs_torch_nn_LPPool3d(module_info, device, dtype, requires_grad, training, **kwargs): 2038 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2039 2040 return [ 2041 ModuleInput( 2042 constructor_input=FunctionInput(2, 2, 2), 2043 forward_input=FunctionInput(make_input((1, 3, 7, 7, 7)))), 2044 ModuleInput( 2045 constructor_input=FunctionInput(2, 2, 2), 2046 forward_input=FunctionInput(make_input((3, 7, 7, 7))), 2047 reference_fn=no_batch_dim_reference_fn, 2048 desc='no_batch_dim'), 2049 ModuleInput( 2050 constructor_input=FunctionInput(1.5, 2), 2051 forward_input=FunctionInput(make_input((1, 3, 7, 7, 7))), 2052 desc='norm'), 2053 ] 2054 2055 2056def module_inputs_torch_nn_MaxPool1d(module_info, device, dtype, requires_grad, training, **kwargs): 2057 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2058 2059 return [ 2060 ModuleInput( 2061 constructor_input=FunctionInput(4), 2062 forward_input=FunctionInput(make_input((2, 10, 4))), 2063 desc='3d_input'), 2064 ModuleInput( 2065 constructor_input=FunctionInput(4, 4), 2066 forward_input=FunctionInput(make_input((2, 10, 4))), 2067 desc='stride'), 2068 ModuleInput( 2069 constructor_input=FunctionInput(4, return_indices=True), 2070 forward_input=FunctionInput(make_input((2, 10, 4))), 2071 desc='return_indices'), 2072 ] 2073 2074 2075def module_inputs_torch_nn_MaxPool2d(module_info, device, dtype, requires_grad, training, **kwargs): 2076 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2077 2078 return [ 2079 ModuleInput( 2080 constructor_input=FunctionInput((3, 3), (2, 2), (1, 1)), 2081 forward_input=FunctionInput(make_input((3, 7, 7))), 2082 desc='3d_input'), 2083 ModuleInput( 2084 constructor_input=FunctionInput((3, 3), (2, 2), (1, 1)), 2085 forward_input=FunctionInput(make_input((1, 3, 7, 7))), 2086 desc='4d_input'), 2087 ModuleInput( 2088 constructor_input=FunctionInput((3, 3), (2, 2), (1, 1), return_indices=True), 2089 forward_input=FunctionInput(make_input((1, 3, 7, 7))), 2090 desc='return_indices'), 2091 ] 2092 2093def module_inputs_torch_nn_MaxPool3d(module_info, device, dtype, requires_grad, training, **kwargs): 2094 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2095 2096 return [ 2097 ModuleInput( 2098 constructor_input=FunctionInput((2, 2, 2)), 2099 
forward_input=FunctionInput(make_input((2, 3, 5, 5, 5)))), 2100 ModuleInput( 2101 constructor_input=FunctionInput(2, (2, 2, 2)), 2102 forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))), 2103 desc='stride'), 2104 ModuleInput( 2105 constructor_input=FunctionInput(2, 2, (1, 1, 1)), 2106 forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))), 2107 desc='stride_padding'), 2108 ModuleInput( 2109 constructor_input=FunctionInput(2, 2, (1, 1, 1), return_indices=True), 2110 forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))), 2111 desc='return_indices'), 2112 ] 2113 2114 2115def module_inputs_torch_nn_FractionalMaxPool2d(module_info, device, dtype, requires_grad, training, **kwargs): 2116 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2117 2118 def make_random_samples(): 2119 return torch.empty((1, 3, 2), dtype=torch.double, device=device).uniform_() 2120 2121 return [ 2122 ModuleInput( 2123 constructor_input=FunctionInput(2, output_ratio=0.5, _random_samples=make_random_samples()), 2124 forward_input=FunctionInput(make_input((1, 3, 5, 7))), 2125 desc='ratio'), 2126 ModuleInput( 2127 constructor_input=FunctionInput((2, 3), output_size=(4, 3), _random_samples=make_random_samples()), 2128 forward_input=FunctionInput(make_input((1, 3, 7, 6))), 2129 desc='size'), 2130 ModuleInput( 2131 constructor_input=FunctionInput( 2132 2, output_ratio=0.5, _random_samples=make_random_samples(), return_indices=True 2133 ), 2134 forward_input=FunctionInput(make_input((1, 3, 5, 7))), 2135 desc='ratio_return_indices'), 2136 ModuleInput( 2137 constructor_input=FunctionInput(2, output_ratio=0.5, _random_samples=make_random_samples()), 2138 forward_input=FunctionInput(make_input((3, 5, 7))), 2139 reference_fn=no_batch_dim_reference_fn, 2140 desc='ratio_no_batch_dim'), 2141 ModuleInput( 2142 constructor_input=FunctionInput((2, 3), output_size=(4, 3), _random_samples=make_random_samples()), 2143 forward_input=FunctionInput(make_input((3, 7, 6))), 2144 reference_fn=no_batch_dim_reference_fn, 2145 desc='size_no_batch_dim'), 2146 ] 2147 2148 2149def module_inputs_torch_nn_FractionalMaxPool3d(module_info, device, dtype, requires_grad, training, **kwargs): 2150 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2151 2152 def make_random_samples(): 2153 return torch.empty((2, 4, 3), dtype=torch.double, device=device).uniform_() 2154 2155 return [ 2156 ModuleInput( 2157 constructor_input=FunctionInput(2, output_ratio=0.5, _random_samples=make_random_samples()), 2158 forward_input=FunctionInput(make_input((2, 4, 5, 5, 5))), 2159 desc='ratio'), 2160 ModuleInput( 2161 constructor_input=FunctionInput((2, 2, 2), output_size=(4, 4, 4), _random_samples=make_random_samples()), 2162 forward_input=FunctionInput(make_input((2, 4, 7, 7, 7))), 2163 desc='size'), 2164 ModuleInput( 2165 constructor_input=FunctionInput((4, 2, 3), output_size=(10, 3, 2), _random_samples=make_random_samples()), 2166 forward_input=FunctionInput(make_input((2, 4, 16, 7, 5))), 2167 desc='asymsize'), 2168 ModuleInput( 2169 constructor_input=FunctionInput( 2170 2, output_ratio=0.5, _random_samples=make_random_samples(), return_indices=True 2171 ), 2172 forward_input=FunctionInput(make_input((2, 4, 5, 5, 5))), 2173 desc='ratio_return_indices'), 2174 ModuleInput( 2175 constructor_input=FunctionInput(2, output_ratio=0.5, _random_samples=make_random_samples()), 2176 forward_input=FunctionInput(make_input((4, 5, 5, 5))), 2177 reference_fn=no_batch_dim_reference_fn, 2178 
desc='ratio_no_batch_dim'), 2179 ModuleInput( 2180 constructor_input=FunctionInput((2, 2, 2), output_size=(4, 4, 4), _random_samples=make_random_samples()), 2181 forward_input=FunctionInput(make_input((4, 7, 7, 7))), 2182 reference_fn=no_batch_dim_reference_fn, 2183 desc='size_no_batch_dim'), 2184 ] 2185 2186 2187def module_inputs_torch_nn_Sigmoid(module_info, device, dtype, requires_grad, training, **kwargs): 2188 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2189 2190 return [ 2191 ModuleInput( 2192 constructor_input=FunctionInput(), 2193 forward_input=FunctionInput(make_input(())), 2194 desc='scalar' 2195 ), 2196 ModuleInput( 2197 constructor_input=FunctionInput(), 2198 forward_input=FunctionInput(make_input(4)), 2199 reference_fn=no_batch_dim_reference_fn, 2200 desc='no_batch_dim', 2201 ), 2202 ModuleInput( 2203 constructor_input=FunctionInput(), 2204 forward_input=FunctionInput(make_input((2, 3, 4, 5))), 2205 desc='channels_last_mem_format' 2206 ), 2207 ModuleInput( 2208 constructor_input=FunctionInput(), 2209 forward_input=FunctionInput(make_input((2, 3, 3, 4, 5))), 2210 desc='channels_last_3d_mem_format' 2211 ) 2212 ] 2213 2214 2215def module_inputs_torch_nn_LogSigmoid(module_info, device, dtype, requires_grad, training, **kwargs): 2216 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2217 2218 return [ 2219 ModuleInput( 2220 constructor_input=FunctionInput(), 2221 forward_input=FunctionInput(make_input(())), 2222 reference_fn=lambda m, p, i: i.sigmoid().log(), 2223 desc='scalar' 2224 ), 2225 ModuleInput( 2226 constructor_input=FunctionInput(), 2227 forward_input=FunctionInput(make_input((2, 3, 4))), 2228 reference_fn=lambda m, p, i: i.sigmoid().log(), 2229 ), 2230 ModuleInput( 2231 constructor_input=FunctionInput(), 2232 forward_input=FunctionInput(make_input(4)), 2233 reference_fn=no_batch_dim_reference_fn, 2234 desc='no_batch_dim', 2235 ), 2236 ] 2237 2238 2239def module_inputs_torch_nn_MarginRankingLoss(module_info, device, dtype, requires_grad, training, **kwargs): 2240 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2241 make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False) 2242 2243 cases: List[Tuple[str, dict]] = [ 2244 ('', {}), 2245 ('reduction_sum', {'reduction': 'sum'}), 2246 ('reduction_mean', {'reduction': 'mean'}), 2247 ('reduction_none', {'reduction': 'none'}), 2248 ('margin', {'margin': 0.5}) 2249 ] 2250 2251 module_inputs = [] 2252 for desc, constructor_kwargs in cases: 2253 def reference_fn(m, p, i1, i2, t, constructor_kwargs=constructor_kwargs): 2254 return marginrankingloss_reference(i1, i2, t, **constructor_kwargs) 2255 2256 module_inputs.append( 2257 ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), 2258 forward_input=FunctionInput(make_input((50,)), make_input((50,)), 2259 make_target((50,)).sign()), 2260 desc=desc, 2261 reference_fn=reference_fn) 2262 ) 2263 2264 return module_inputs 2265 2266 2267def module_inputs_torch_nn_MultiLabelMarginLoss(module_info, device, dtype, requires_grad, training, **kwargs): 2268 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2269 make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False) 2270 2271 cases: List[Tuple[str, dict]] = [ 2272 ('', {}), 2273 ('reduction_sum', {'reduction': 'sum'}), 2274 ('reduction_mean', {'reduction': 'mean'}), 2275 ('reduction_none', 
{'reduction': 'none'}), 2276 ] 2277 2278 module_inputs = [] 2279 for desc, constructor_kwargs in cases: 2280 def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs): 2281 return multilabelmarginloss_reference(i, t, **constructor_kwargs) 2282 2283 module_inputs.append( 2284 ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), 2285 forward_input=FunctionInput(make_input((10,)), 2286 make_target((10), low=0, high=10)), 2287 desc=f'1d_{desc}', 2288 reference_fn=reference_fn) 2289 ) 2290 2291 module_inputs.append( 2292 ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), 2293 forward_input=FunctionInput(make_input((5, 10)), 2294 make_target((5, 10), low=0, high=10)), 2295 desc=desc, 2296 reference_fn=reference_fn) 2297 ) 2298 2299 return module_inputs 2300 2301 2302def module_inputs_torch_nn_MultiMarginLoss(module_info, device, dtype, requires_grad, training, **kwargs): 2303 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2304 make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False) 2305 make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) 2306 2307 cases: List[Tuple[str, dict]] = [ 2308 ('', {}), 2309 ('reduction_sum', {'reduction': 'sum'}), 2310 ('reduction_mean', {'reduction': 'mean'}), 2311 ('reduction_none', {'reduction': 'none'}), 2312 ('p', {'p': 2}), 2313 ('margin', {'margin': 0.5}), 2314 ('weights', {'weight': make_weight(10)}) 2315 ] 2316 2317 module_inputs = [] 2318 for desc, constructor_kwargs in cases: 2319 def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs): 2320 return multimarginloss_reference(i, t, **constructor_kwargs) 2321 2322 module_inputs.append( 2323 ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), 2324 forward_input=FunctionInput(make_input((5, 10)), 2325 make_target((5), low=0, high=10)), 2326 desc=desc, 2327 reference_fn=reference_fn) 2328 ) 2329 2330 return module_inputs 2331 2332 2333def module_inputs_torch_nn_MultiLabelSoftMarginLoss(module_info, device, dtype, requires_grad, training, **kwargs): 2334 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2335 make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False) 2336 make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) 2337 2338 cases: List[Tuple[str, dict]] = [ 2339 ('', {}), 2340 ('reduction_sum', {'reduction': 'sum'}), 2341 ('reduction_mean', {'reduction': 'mean'}), 2342 ('reduction_none', {'reduction': 'none'}), 2343 ('weight', {'weight': make_weight(10)}), 2344 ] 2345 2346 def multilabelsoftmargin_loss_reference_fn(m, p, i, t, reduction='mean', weight=None): 2347 result = t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log() 2348 if weight is not None: 2349 result *= weight 2350 result = (-result).sum(i.dim() - 1) / i.size(-1) 2351 2352 if reduction == 'none': 2353 return result 2354 elif reduction == 'mean': 2355 return result.mean() 2356 else: 2357 return result.sum() 2358 2359 module_inputs = [] 2360 for desc, constructor_kwargs in cases: 2361 module_inputs.append( 2362 ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), 2363 forward_input=FunctionInput(make_input((5, 10)), 2364 make_target((5, 10), low=0, high=2)), 2365 desc=desc, 2366 reference_fn=partial(multilabelsoftmargin_loss_reference_fn, **constructor_kwargs)) 2367 ) 2368 2369 return module_inputs 2370 2371 2372def 
module_inputs_torch_nn_SoftMarginLoss(module_info, device, dtype, requires_grad, training, **kwargs): 2373 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2374 make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) 2375 2376 cases: List[Tuple[str, dict]] = [ 2377 ('', {}), 2378 ('reduction_sum', {'reduction': 'sum'}), 2379 ('reduction_mean', {'reduction': 'mean'}), 2380 ('reduction_none', {'reduction': 'none'}), 2381 ] 2382 2383 module_inputs = [] 2384 for desc, constructor_kwargs in cases: 2385 def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs): 2386 return softmarginloss_reference(i, t, **constructor_kwargs) 2387 2388 module_inputs.append( 2389 ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), 2390 forward_input=FunctionInput(make_input((5, 5)), 2391 make_target((5, 5)).sign()), 2392 desc=desc, 2393 reference_fn=reference_fn) 2394 ) 2395 2396 return module_inputs 2397 2398 2399def module_inputs_torch_nn_TransformerEncoder(module_info, device, dtype, requires_grad, training, **kwargs): 2400 # Reuse the TransformerEncoderLayer samples since the forward args are nearly the same. 2401 samples = [] 2402 for layer_module_input in module_inputs_torch_nn_TransformerEncoderLayer( 2403 None, device, dtype, requires_grad, training): 2404 # Construct a TransformerEncoderLayer object to pass to TransformerEncoder. 2405 l_args, l_kwargs = (layer_module_input.constructor_input.args, 2406 layer_module_input.constructor_input.kwargs) 2407 l_kwargs['device'] = device 2408 l_kwargs['dtype'] = dtype 2409 encoder_layer = torch.nn.TransformerEncoderLayer(*l_args, **l_kwargs) 2410 num_layers = 2 2411 # Note: TransformerEncoderLayer takes a "src_mask" while 2412 # TransformerEncoder takes a "mask"; rename kwarg appropriately. 2413 forward_input = layer_module_input.forward_input 2414 if 'src_mask' in forward_input.kwargs: 2415 forward_input.kwargs['mask'] = forward_input.kwargs['src_mask'] 2416 del forward_input.kwargs['src_mask'] 2417 samples.append(ModuleInput( 2418 constructor_input=FunctionInput(encoder_layer, num_layers), 2419 forward_input=forward_input, 2420 desc=layer_module_input.desc 2421 )) 2422 return samples 2423 2424def module_inputs_torch_nn_TransformerEncoderLayer(module_info, device, dtype, requires_grad, training, **kwargs): 2425 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2426 2427 samples = [ 2428 ModuleInput( 2429 constructor_input=FunctionInput(4, 2, 16, 0.0), 2430 forward_input=FunctionInput( 2431 make_input((2, 3, 4)) 2432 ), 2433 desc='relu_activation' 2434 ), 2435 ModuleInput( 2436 constructor_input=FunctionInput(4, 2, 8, 0.0, F.gelu), 2437 forward_input=FunctionInput( 2438 make_input((2, 3, 4)) 2439 ), 2440 desc='gelu_activation' 2441 ), 2442 ModuleInput( 2443 constructor_input=FunctionInput(4, 2, 8, 0.0, bias=False), 2444 forward_input=FunctionInput( 2445 make_input((2, 3, 4)) 2446 ), 2447 desc='no_bias' 2448 ), ] 2449 2450 # Samples below are for validating the no-batch-dim support. 
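    # Sketch of the no-batch-dim check (no_batch_dim_reference_fn is defined earlier in this
    # file; summarized here for readability): each sample below passes an unbatched (3, 4) src,
    # and the reference presumably adds a batch dim (respecting batch_first), batches any kwarg
    # named in kwargs_to_batchify (src_key_padding_mask along dim 0), runs the module, and
    # removes the batch dim again so the result can be compared with the batched path.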
2451 key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool)) 2452 attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3))) 2453 for src_mask, src_key_padding_mask, norm_first, batch_first, bias in \ 2454 itertools.product(attn_masks, key_padding_masks, (True, False), (True, False), (True, False)): 2455 samples.append( 2456 ModuleInput( 2457 constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8, 2458 dropout=0.0, batch_first=batch_first, 2459 norm_first=norm_first, bias=bias), 2460 forward_input=FunctionInput( 2461 make_input((3, 4)), src_mask=src_mask, src_key_padding_mask=src_key_padding_mask 2462 ), 2463 reference_fn=partial(no_batch_dim_reference_fn, 2464 batch_first=batch_first, kwargs_to_batchify={'src_key_padding_mask': 0}), 2465 desc=f'no_batch_dim_batch_first_{batch_first}' 2466 )) 2467 2468 # Samples below where we pass reference_fn are for validating the fast path, 2469 # since the fast path requires no_grad mode, we run the fast path in .eval() 2470 # and no_grad() in the reference_fn and verify that against the results in train mode. 2471 def fast_path_reference_fn(module, parameters, *args, **kwargs): 2472 assert module.training 2473 module.train(False) 2474 with torch.no_grad(): 2475 output = module(*args, **kwargs) 2476 module.train(True) 2477 return output 2478 2479 if training: 2480 for norm_first, bias in itertools.product((True, False), (True, False)): 2481 samples.append( 2482 ModuleInput( 2483 constructor_input=FunctionInput( 2484 4, 2, 8, dropout=0.0, batch_first=True, norm_first=norm_first, bias=bias 2485 ), 2486 forward_input=FunctionInput( 2487 make_input((2, 3, 4)), 2488 ), 2489 # fastpath doesn't run when bias=False 2490 reference_fn=fast_path_reference_fn if bias else None, 2491 desc=f'fastpath_{bias}_norm_first_{norm_first}' 2492 ) 2493 ) 2494 2495 return samples 2496 2497 2498def module_inputs_torch_nn_TransformerDecoderLayer(module_info, device, dtype, requires_grad, training, **kwargs): 2499 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2500 2501 samples = [ 2502 ModuleInput( 2503 constructor_input=FunctionInput(4, 2, 16, 0.0), 2504 forward_input=FunctionInput( 2505 make_input((2, 3, 4)), make_input((2, 3, 4)) 2506 ), 2507 desc='relu_activation' 2508 ), 2509 ModuleInput( 2510 constructor_input=FunctionInput(4, 2, 8, 0.0, F.gelu), 2511 forward_input=FunctionInput( 2512 make_input((2, 3, 4)), make_input((2, 3, 4)) 2513 ), 2514 desc='gelu_activation' 2515 ), 2516 ModuleInput( 2517 constructor_input=FunctionInput(4, 2, 8, 0.0, bias=False), 2518 forward_input=FunctionInput( 2519 make_input((2, 3, 4)), make_input((2, 3, 4)) 2520 ), 2521 desc='no_bias' 2522 ), ] 2523 2524 key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool)) 2525 attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3))) 2526 for tgt_mask, tgt_key_padding_mask, norm_first, bias, batch_first in \ 2527 itertools.product(attn_masks, key_padding_masks, (True, False), (True, False), (True, False)): 2528 # Using same mask for tgt and memory 2529 memory_mask = tgt_mask 2530 memory_key_padding_mask = tgt_key_padding_mask 2531 samples.append( 2532 ModuleInput( 2533 constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8, 2534 dropout=0.0, batch_first=batch_first, 2535 norm_first=norm_first, bias=bias), 2536 forward_input=FunctionInput( 2537 make_input((3, 4)), 
make_input((3, 4)), tgt_mask=tgt_mask, memory_mask=memory_mask, 2538 tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask 2539 ), 2540 reference_fn=partial(no_batch_dim_reference_fn, 2541 batch_first=batch_first, 2542 kwargs_to_batchify={'tgt_key_padding_mask': 0, 'memory_key_padding_mask': 0}), 2543 desc=f'no_batch_dim_batch_first_{batch_first}' 2544 )) 2545 src, tgt = make_input((2, 3, 4)), make_input((2, 3, 4)) 2546 if not batch_first: 2547 src, tgt = src.transpose(0, 1), tgt.transpose(0, 1) 2548 if tgt_key_padding_mask is not None: 2549 memory_key_padding_mask, tgt_key_padding_mask = (tgt_key_padding_mask.expand(2, 3),) * 2 2550 samples.append( 2551 ModuleInput( 2552 constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8, 2553 dropout=0.0, batch_first=batch_first, 2554 norm_first=norm_first, bias=bias), 2555 forward_input=FunctionInput( 2556 src, tgt, tgt_mask=tgt_mask, memory_mask=memory_mask, 2557 tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask 2558 ), 2559 desc=f'norm_first_{norm_first}_batch_first_{batch_first}_bias_{bias}' 2560 )) 2561 2562 return samples 2563 2564 2565def module_inputs_torch_nn_Transformer(module_info, device, dtype, requires_grad, training, **kwargs): 2566 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2567 samples = [] 2568 # Samples below are for validating the no-batch-dim support. 2569 key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool)) 2570 attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3))) 2571 for mask, key_padding_mask, norm_first, bias, batch_first in \ 2572 itertools.product(attn_masks, key_padding_masks, (True, False), (True, False), (True, False)): 2573 # Using same mask for tgt and memory 2574 src_mask , tgt_mask = (mask,) * 2 2575 src_key_padding_mask, tgt_key_padding_mask = (key_padding_mask,) * 2 2576 samples.append( 2577 ModuleInput( 2578 constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8, 2579 num_encoder_layers=1, num_decoder_layers=1, 2580 dropout=0.0, batch_first=batch_first, norm_first=norm_first, bias=bias), 2581 forward_input=FunctionInput( 2582 make_input((3, 4)), make_input((3, 4)), tgt_mask=tgt_mask, src_mask=src_mask, 2583 tgt_key_padding_mask=tgt_key_padding_mask, src_key_padding_mask=src_key_padding_mask 2584 ), 2585 reference_fn=partial(no_batch_dim_reference_fn, 2586 batch_first=batch_first, 2587 kwargs_to_batchify={'tgt_key_padding_mask': 0, 'src_key_padding_mask': 0}), 2588 desc=f'no_batch_dim_batch_first_{batch_first}' 2589 )) 2590 2591 src, tgt = make_input((2, 3, 4)), make_input((2, 3, 4)) 2592 if not batch_first: 2593 src = src.transpose(0, 1) 2594 tgt = tgt.transpose(0, 1) 2595 if key_padding_mask is not None: 2596 src_key_padding_mask, tgt_key_padding_mask = (key_padding_mask.expand(2, 3),) * 2 2597 2598 samples.append( 2599 ModuleInput( 2600 constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8, 2601 num_encoder_layers=1, num_decoder_layers=1, 2602 dropout=0.0, batch_first=batch_first, norm_first=norm_first, bias=bias), 2603 forward_input=FunctionInput( 2604 src, tgt, tgt_mask=tgt_mask, src_mask=src_mask, 2605 tgt_key_padding_mask=tgt_key_padding_mask, src_key_padding_mask=src_key_padding_mask 2606 ), 2607 )) 2608 return samples 2609 2610 2611def module_inputs_torch_nn_Embedding(module_info, device, dtype, requires_grad, training, **kwargs): 2612 
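    # Index tensors are built with torch.empty(...).random_(4), so every index lies in
    # [0, num_embeddings). The 'discontiguous' sample expands a (1, 512) index tensor to
    # (7, 512), exercising Embedding with a non-contiguous input.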
make_empty = partial(torch.empty, device=device, dtype=torch.long, requires_grad=False) 2613 return [ 2614 ModuleInput( 2615 constructor_input=FunctionInput(num_embeddings=4, embedding_dim=3), 2616 forward_input=FunctionInput(make_empty(2, 3).random_(4)) 2617 ), 2618 ModuleInput( 2619 constructor_input=FunctionInput(num_embeddings=4, embedding_dim=3), 2620 forward_input=FunctionInput(make_empty(1, 512).random_(4).expand(7, 512)), 2621 desc='discontiguous' 2622 ), 2623 ] 2624 2625 2626def module_inputs_torch_nn_MultiheadAttention(module_info, device, dtype, requires_grad, training, **kwargs): 2627 # Currently all samples below are for validating the no-batch-dim support. 2628 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2629 samples = [] 2630 bool_vals = (True, False) 2631 key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool)) 2632 attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3, 3))) 2633 products = itertools.product(bool_vals, bool_vals, bool_vals, key_padding_masks, attn_masks) 2634 for bias, add_bias_kv, add_zero_attn, key_padding_mask, attn_mask in products: 2635 samples.append( 2636 ModuleInput( 2637 constructor_input=FunctionInput(embed_dim=3, num_heads=3, batch_first=True, 2638 bias=bias, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn), 2639 forward_input=FunctionInput(make_input((3, 3)), make_input((3, 3)), make_input((3, 3)), 2640 key_padding_mask=key_padding_mask, attn_mask=attn_mask), 2641 reference_fn=no_batch_dim_reference_mha, 2642 ) 2643 ) 2644 samples.append( 2645 ModuleInput( 2646 constructor_input=FunctionInput(embed_dim=3, num_heads=3, batch_first=False, 2647 bias=bias, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn), 2648 forward_input=FunctionInput(make_input((3, 3)), make_input((3, 3)), make_input((3, 3)), 2649 key_padding_mask=key_padding_mask, attn_mask=attn_mask), 2650 reference_fn=partial(no_batch_dim_reference_mha, batch_first=False), 2651 ) 2652 ) 2653 2654 return samples 2655 2656 2657def module_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_grad, training, **kwargs): 2658 # Currently all samples below are for validating the no-batch-dim support. 2659 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2660 samples = [ 2661 ModuleInput( 2662 constructor_input=FunctionInput(5, 10), 2663 forward_input=FunctionInput(make_input(5), make_input(10)), 2664 reference_fn=no_batch_dim_reference_fn, 2665 ), 2666 ModuleInput( 2667 constructor_input=FunctionInput(5, 10, bias=True), 2668 forward_input=FunctionInput(make_input(5), make_input(10)), 2669 reference_fn=no_batch_dim_reference_fn, 2670 ) 2671 ] 2672 2673 is_rnn = kwargs.get('is_rnn', False) 2674 if is_rnn: 2675 # RNN also supports `nonlinearity` argument. 2676 # `tanh` is the default, so we check with `relu` 2677 samples.append( 2678 ModuleInput( 2679 constructor_input=FunctionInput(5, 10, bias=True, nonlinearity='relu'), 2680 forward_input=FunctionInput(make_input(5), make_input(10)), 2681 reference_fn=no_batch_dim_reference_fn, 2682 ) 2683 ) 2684 2685 return samples 2686 2687 2688def module_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_grad, training, **kwargs): 2689 # Currently all samples below are for validating the no-batch-dim support. 
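    # Each sample calls the cell as cell(input, (h_0, c_0)) with unbatched tensors: an input of
    # size 5 (input_size) and a hidden-state tuple of two size-10 tensors (hidden_size).
    # no_batch_dim_reference_lstmcell (defined earlier) is expected to batch both the input and
    # the hidden tuple before comparing against the unbatched result.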
2690 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2691 samples = ( 2692 ModuleInput( 2693 constructor_input=FunctionInput(5, 10), 2694 forward_input=FunctionInput(make_input(5), (make_input(10), make_input(10))), 2695 reference_fn=no_batch_dim_reference_lstmcell, 2696 ), 2697 ModuleInput( 2698 constructor_input=FunctionInput(5, 10, bias=True), 2699 forward_input=FunctionInput(make_input(5), (make_input(10), make_input(10))), 2700 reference_fn=no_batch_dim_reference_lstmcell, 2701 ), 2702 ) 2703 2704 return samples 2705 2706def make_packed_sequence(inp, batch_sizes): 2707 required_grad = inp.requires_grad 2708 inp.requires_grad_(False) # user won't have access to inp so won't be able to get its grads 2709 seq = pack_padded_sequence(inp, batch_sizes) 2710 seq.data.requires_grad_(required_grad) 2711 return seq 2712 2713 2714def module_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, training, with_packed_sequence=False, **kwargs): 2715 # Currently all samples below are for validating the no-batch-dim support. 2716 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2717 is_rnn = kwargs['is_rnn'] 2718 nonlinearity = ('relu', 'tanh') 2719 bias = (False, True) 2720 batch_first = (False, True) 2721 bidirectional = (False, True) 2722 2723 samples = [] 2724 if is_rnn: 2725 prod_gen = product(nonlinearity, bias, batch_first, bidirectional) 2726 else: 2727 prod_gen = product(bias, batch_first, bidirectional) 2728 2729 for args in prod_gen: 2730 if is_rnn: 2731 nl, b, b_f, bidir = args 2732 else: 2733 b, b_f, bidir = args 2734 2735 cons_args = {'input_size': 2, 'hidden_size': 2, 'num_layers': 2, 2736 'batch_first': b_f, 'bias': b, 'bidirectional': bidir} 2737 cons_args_hidden = {'input_size': 2, 'hidden_size': 3, 'num_layers': 2, 2738 'batch_first': b_f, 'bias': b, 'bidirectional': bidir} 2739 2740 if is_rnn: 2741 cons_args['nonlinearity'] = nl 2742 cons_args_hidden['nonlinearity'] = nl 2743 samples.append( 2744 ModuleInput( 2745 constructor_input=FunctionInput(**cons_args), 2746 forward_input=FunctionInput(make_input((3, 2))), 2747 reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f), 2748 ) 2749 ) 2750 samples.append( 2751 ModuleInput( 2752 constructor_input=FunctionInput(**cons_args_hidden), 2753 forward_input=FunctionInput(make_input((3, 2)), make_input((4 if bidir else 2, 3))), 2754 reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f), 2755 ) 2756 ) 2757 if with_packed_sequence: 2758 samples.append( 2759 ModuleInput( 2760 constructor_input=FunctionInput(**cons_args), 2761 forward_input=FunctionInput(make_packed_sequence(make_input((5, 2, 2)), torch.tensor([5, 3]))), 2762 reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f), 2763 ) 2764 ) 2765 samples.append( 2766 ModuleInput( 2767 constructor_input=FunctionInput(**cons_args), 2768 forward_input=FunctionInput(make_packed_sequence(make_input((5, 5, 2)), torch.tensor([5, 3, 3, 2, 2]))), 2769 reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f), 2770 ) 2771 ) 2772 2773 return samples 2774 2775 2776def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, training, **kwargs): 2777 # Currently all samples below are for validating the no-batch-dim support. 
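    # Samples sweep bias x batch_first x bidirectional x proj_size. When an initial hidden state
    # is supplied, h_0 has feature size proj_size if proj_size > 0 else hidden_size, c_0 always
    # has feature size hidden_size, and both use a leading dim of num_layers * num_directions
    # (4 if bidirectional else 2).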
2778 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2779 bias = (False, True) 2780 batch_first = (False, True) 2781 bidirectional = (False, True) 2782 proj_sizes = (0, 2) 2783 2784 samples = [] 2785 prod_gen = product(bias, batch_first, bidirectional, proj_sizes) 2786 2787 for args in prod_gen: 2788 b, b_f, bidir, proj_size = args 2789 hidden_size = 3 2790 cons_args = {'input_size': 2, 'hidden_size': hidden_size, 'num_layers': 2, 'proj_size': proj_size, 2791 'batch_first': b_f, 'bias': b, 'bidirectional': bidir} 2792 cons_args_hidden = {'input_size': 2, 'hidden_size': hidden_size, 'num_layers': 2, 'proj_size': proj_size, 2793 'batch_first': b_f, 'bias': b, 'bidirectional': bidir} 2794 2795 samples.append( 2796 ModuleInput( 2797 constructor_input=FunctionInput(**cons_args), 2798 forward_input=FunctionInput(make_input((2, 2))), 2799 reference_fn=partial(no_batch_dim_reference_lstm, batch_first=b_f), 2800 ) 2801 ) 2802 2803 h_out = proj_size if proj_size > 0 else hidden_size 2804 hx = (make_input((4 if bidir else 2, h_out)), make_input((4 if bidir else 2, hidden_size))) 2805 samples.append( 2806 ModuleInput( 2807 constructor_input=FunctionInput(**cons_args_hidden), 2808 forward_input=FunctionInput(make_input((3, 2)), hx), 2809 reference_fn=partial(no_batch_dim_reference_lstm, batch_first=b_f), 2810 ) 2811 ) 2812 2813 2814 return samples 2815 2816 2817 2818def module_inputs_torch_nn_ReflectionPad1d(module_info, device, dtype, requires_grad, training, **kwargs): 2819 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2820 2821 return [ 2822 ModuleInput( 2823 constructor_input=FunctionInput(1), 2824 forward_input=FunctionInput(make_input((2, 3))), 2825 reference_fn=no_batch_dim_reference_fn, 2826 ), 2827 ModuleInput( 2828 constructor_input=FunctionInput((1, 2)), 2829 forward_input=FunctionInput(make_input((2, 3, 4))), 2830 ), 2831 ] 2832 2833def module_inputs_torch_nn_ReflectionPad2d(module_info, device, dtype, requires_grad, training, **kwargs): 2834 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2835 2836 return [ 2837 ModuleInput( 2838 constructor_input=FunctionInput(1), 2839 forward_input=FunctionInput(make_input((3, 4, 5))), 2840 reference_fn=no_batch_dim_reference_fn, 2841 ), 2842 ModuleInput( 2843 constructor_input=FunctionInput((1, 2, 3, 4)), 2844 forward_input=FunctionInput(make_input((3, 4, 5, 6))), 2845 ), 2846 ] 2847 2848def module_inputs_torch_nn_ReflectionPad3d(module_info, device, dtype, requires_grad, training, **kwargs): 2849 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2850 2851 return [ 2852 ModuleInput( 2853 constructor_input=FunctionInput(1), 2854 forward_input=FunctionInput(make_input((2, 3, 4, 5))), 2855 reference_fn=no_batch_dim_reference_fn 2856 ), 2857 ModuleInput( 2858 constructor_input=FunctionInput((1, 2, 1, 2, 1, 2)), 2859 forward_input=FunctionInput(make_input((3, 3, 3, 3, 3))), 2860 ), 2861 ] 2862 2863def module_inputs_torch_nn_ReplicationPad1d(module_info, device, dtype, requires_grad, training, **kwargs): 2864 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2865 2866 return [ 2867 ModuleInput( 2868 constructor_input=FunctionInput(1), 2869 forward_input=FunctionInput(make_input((3, 4))), 2870 reference_fn=no_batch_dim_reference_fn 2871 ), 2872 ModuleInput( 2873 constructor_input=FunctionInput((1, 2)), 2874 forward_input=FunctionInput(make_input((3, 
4, 5))), 2875 ), 2876 ] 2877 2878def module_inputs_torch_nn_ReplicationPad2d(module_info, device, dtype, requires_grad, training, **kwargs): 2879 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2880 2881 return [ 2882 ModuleInput( 2883 constructor_input=FunctionInput(1), 2884 forward_input=FunctionInput(make_input((3, 4, 5))), 2885 reference_fn=no_batch_dim_reference_fn, 2886 ), 2887 ModuleInput( 2888 constructor_input=FunctionInput((1, 2, 3, 4)), 2889 forward_input=FunctionInput(make_input((3, 4, 5, 6))), 2890 ), 2891 ] 2892 2893def module_inputs_torch_nn_ReplicationPad3d(module_info, device, dtype, requires_grad, training, **kwargs): 2894 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2895 2896 return [ 2897 ModuleInput( 2898 constructor_input=FunctionInput(1), 2899 forward_input=FunctionInput(make_input((3, 4, 5, 6))), 2900 reference_fn=no_batch_dim_reference_fn, 2901 ), 2902 ModuleInput( 2903 constructor_input=FunctionInput((1, 2, 3, 4, 5, 6)), 2904 forward_input=FunctionInput(make_input((3, 4, 5, 6, 7))), 2905 ), 2906 ] 2907 2908def module_inputs_torch_nn_ZeroPad1d(module_info, device, dtype, requires_grad, training, **kwargs): 2909 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2910 2911 return [ 2912 ModuleInput( 2913 constructor_input=FunctionInput(1), 2914 forward_input=FunctionInput(make_input((3, 4))), 2915 reference_fn=no_batch_dim_reference_fn, 2916 ), 2917 ModuleInput( 2918 constructor_input=FunctionInput((1, 2)), 2919 forward_input=FunctionInput(make_input((3, 4, 5))), 2920 ), 2921 ] 2922 2923def module_inputs_torch_nn_ZeroPad2d(module_info, device, dtype, requires_grad, training, **kwargs): 2924 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2925 2926 return [ 2927 ModuleInput( 2928 constructor_input=FunctionInput(1), 2929 forward_input=FunctionInput(make_input((1, 2, 3))), 2930 reference_fn=no_batch_dim_reference_fn 2931 ), 2932 ModuleInput( 2933 constructor_input=FunctionInput((1, 2, 3, 4)), 2934 forward_input=FunctionInput(make_input((1, 2, 3, 4))), 2935 ), 2936 ] 2937 2938def module_inputs_torch_nn_ZeroPad3d(module_info, device, dtype, requires_grad, training, **kwargs): 2939 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2940 2941 return [ 2942 ModuleInput( 2943 constructor_input=FunctionInput(1), 2944 forward_input=FunctionInput(make_input((3, 4, 5, 6))), 2945 reference_fn=no_batch_dim_reference_fn, 2946 ), 2947 ModuleInput( 2948 constructor_input=FunctionInput((1, 2, 3, 4, 5, 6)), 2949 forward_input=FunctionInput(make_input((1, 2, 3, 4, 5))), 2950 ), 2951 ] 2952 2953def module_inputs_torch_nn_ConstantPad1d(module_info, device, dtype, requires_grad, training, **kwargs): 2954 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2955 2956 return [ 2957 ModuleInput( 2958 constructor_input=FunctionInput(1, 2), 2959 forward_input=FunctionInput(make_input((3, 4))), 2960 reference_fn=no_batch_dim_reference_fn, 2961 ), 2962 ModuleInput( 2963 constructor_input=FunctionInput((1, 2), 3), 2964 forward_input=FunctionInput(make_input((3, 4, 5))), 2965 ), 2966 ] 2967 2968def module_inputs_torch_nn_ConstantPad2d(module_info, device, dtype, requires_grad, training, **kwargs): 2969 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2970 2971 return [ 2972 ModuleInput( 2973 
constructor_input=FunctionInput(1, 3), 2974 forward_input=FunctionInput(make_input((3, 4, 5))), 2975 reference_fn=no_batch_dim_reference_fn 2976 ), 2977 ModuleInput( 2978 constructor_input=FunctionInput((1, 2, 3, 4), 5), 2979 forward_input=FunctionInput(make_input((1, 2, 3, 4))), 2980 ), 2981 ] 2982 2983def module_inputs_torch_nn_ConstantPad3d(module_info, device, dtype, requires_grad, training, **kwargs): 2984 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2985 2986 return [ 2987 ModuleInput( 2988 constructor_input=FunctionInput(1, 3), 2989 forward_input=FunctionInput(make_input((3, 4, 5, 6))), 2990 reference_fn=no_batch_dim_reference_fn, 2991 ), 2992 ModuleInput( 2993 constructor_input=FunctionInput((1, 2, 3, 4, 5, 6), 7), 2994 forward_input=FunctionInput(make_input((1, 2, 1, 2, 1))), 2995 ), 2996 ] 2997 2998def module_inputs_torch_nn_CircularPad1d(module_info, device, dtype, requires_grad, training, **kwargs): 2999 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 3000 3001 def padding1d_circular_ref(inp, pad): 3002 r""" input: 3003 [[[0., 1., 2.], 3004 [3., 4., 5.]]] 3005 pad: (1, 2) 3006 output: 3007 [[[2., 0., 1., 2., 0., 1.], 3008 [5., 3., 4., 5., 3., 4.]]] 3009 """ 3010 return torch.cat([inp[:, :, -pad[0]:], inp, inp[:, :, :pad[1]]], dim=2) 3011 3012 return [ 3013 ModuleInput( 3014 constructor_input=FunctionInput(1), 3015 forward_input=FunctionInput(make_input((3, 4))), 3016 reference_fn=no_batch_dim_reference_fn 3017 ), 3018 ModuleInput( 3019 constructor_input=FunctionInput((1, 2)), 3020 forward_input=FunctionInput(make_input((1, 2, 3))), 3021 reference_fn=lambda m, p, i: padding1d_circular_ref(i, m.padding), 3022 ), 3023 ModuleInput( 3024 constructor_input=FunctionInput((3, 1)), 3025 forward_input=FunctionInput(make_input((1, 2, 3))), 3026 reference_fn=lambda m, p, i: padding1d_circular_ref(i, m.padding), 3027 ), 3028 ModuleInput( 3029 constructor_input=FunctionInput((3, 3)), 3030 forward_input=FunctionInput(make_input((1, 2, 3))), 3031 reference_fn=lambda m, p, i: padding1d_circular_ref(i, m.padding), 3032 ), 3033 ] 3034 3035def module_inputs_torch_nn_CircularPad2d(module_info, device, dtype, requires_grad, training, **kwargs): 3036 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 3037 3038 def padding2d_circular_ref(inp, pad): 3039 r"""input: 3040 [[[[0., 1., 2], 3041 [3., 4., 5.]]]] 3042 pad: (1, 2, 2, 1) 3043 output: 3044 [[[[2., 0., 1., 2., 0., 1.], 3045 [5., 3., 4., 5., 3., 4.], 3046 [2., 0., 1., 2., 0., 1.], 3047 [5., 3., 4., 5., 3., 4.], 3048 [2., 0., 1., 2., 0., 1.]]]] 3049 """ 3050 inp = torch.cat([inp[:, :, -pad[2]:], inp, inp[:, :, :pad[3]]], dim=2) 3051 return torch.cat([inp[:, :, :, -pad[0]:], inp, inp[:, :, :, :pad[1]]], dim=3) 3052 3053 return [ 3054 ModuleInput( 3055 constructor_input=FunctionInput(1), 3056 forward_input=FunctionInput(make_input((3, 4, 5))), 3057 reference_fn=no_batch_dim_reference_fn, 3058 ), 3059 ModuleInput( 3060 constructor_input=FunctionInput((1, 2, 2, 1)), 3061 forward_input=FunctionInput(make_input((1, 1, 2, 3))), 3062 reference_fn=lambda m, p, i: padding2d_circular_ref(i, m.padding), 3063 ), 3064 ModuleInput( 3065 constructor_input=FunctionInput((2, 3, 2, 2)), 3066 forward_input=FunctionInput(make_input((1, 1, 2, 3))), 3067 reference_fn=lambda m, p, i: padding2d_circular_ref(i, m.padding), 3068 ), 3069 ModuleInput( 3070 constructor_input=FunctionInput((3, 3, 3, 1)), 3071 forward_input=FunctionInput(make_input((1, 
1, 3, 3))), 3072 reference_fn=lambda m, p, i: padding2d_circular_ref(i, m.padding), 3073 ), 3074 ] 3075 3076def module_inputs_torch_nn_CircularPad3d(module_info, device, dtype, requires_grad, training, **kwargs): 3077 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 3078 3079 3080 def padding3d_circular_ref(inp, pad): 3081 r"""input: 3082 [[[[[ 0., 1., 2.], 3083 [ 3., 4., 5.]], 3084 [[ 6., 7., 8.], 3085 [ 9., 10., 11.]]]]] 3086 pad: (1, 2, 2, 1, 1, 2) 3087 output: [[[[[ 8., 6., 7., 8., 6., 7.], 3088 [11., 9., 10., 11., 9., 10.], 3089 [ 8., 6., 7., 8., 6., 7.], 3090 [11., 9., 10., 11., 9., 10.], 3091 [ 8., 6., 7., 8., 6., 7.]], 3092 3093 [[ 2., 0., 1., 2., 0., 1.], 3094 [ 5., 3., 4., 5., 3., 4.], 3095 [ 2., 0., 1., 2., 0., 1.], 3096 [ 5., 3., 4., 5., 3., 4.], 3097 [ 2., 0., 1., 2., 0., 1.]], 3098 3099 [[ 8., 6., 7., 8., 6., 7.], 3100 [11., 9., 10., 11., 9., 10.], 3101 [ 8., 6., 7., 8., 6., 7.], 3102 [11., 9., 10., 11., 9., 10.], 3103 [ 8., 6., 7., 8., 6., 7.]], 3104 3105 [[ 2., 0., 1., 2., 0., 1.], 3106 [ 5., 3., 4., 5., 3., 4.], 3107 [ 2., 0., 1., 2., 0., 1.], 3108 [ 5., 3., 4., 5., 3., 4.], 3109 [ 2., 0., 1., 2., 0., 1.]], 3110 3111 [[ 8., 6., 7., 8., 6., 7.], 3112 [11., 9., 10., 11., 9., 10.], 3113 [ 8., 6., 7., 8., 6., 7.], 3114 [11., 9., 10., 11., 9., 10.], 3115 [ 8., 6., 7., 8., 6., 7.]]]]] 3116 """ 3117 inp = torch.cat([inp[:, :, -pad[4]:], inp, inp[:, :, :pad[5]]], dim=2) 3118 inp = torch.cat([inp[:, :, :, -pad[2]:], inp, inp[:, :, :, :pad[3]]], dim=3) 3119 return torch.cat([inp[:, :, :, :, -pad[0]:], inp, inp[:, :, :, :, :pad[1]]], dim=4) 3120 3121 return [ 3122 ModuleInput( 3123 constructor_input=FunctionInput(1), 3124 forward_input=FunctionInput(make_input((3, 4, 5, 6))), 3125 reference_fn=no_batch_dim_reference_fn, 3126 ), 3127 ModuleInput( 3128 constructor_input=FunctionInput((1, 2, 1, 2, 1, 2)), 3129 forward_input=FunctionInput(make_input((1, 1, 2, 2, 3))), 3130 reference_fn=lambda m, p, i: padding3d_circular_ref(i, m.padding) 3131 ), 3132 ModuleInput( 3133 constructor_input=FunctionInput((3, 2, 2, 1, 1, 2)), 3134 forward_input=FunctionInput(make_input((1, 1, 2, 2, 3))), 3135 reference_fn=lambda m, p, i: padding3d_circular_ref(i, m.padding) 3136 ), 3137 ModuleInput( 3138 constructor_input=FunctionInput((3, 3, 2, 1, 2, 2)), 3139 forward_input=FunctionInput(make_input((1, 1, 2, 2, 3))), 3140 reference_fn=lambda m, p, i: padding3d_circular_ref(i, m.padding) 3141 ), 3142 ] 3143 3144 3145# All these operators share similar issues on cuDNN and MIOpen 3146rnn_gru_lstm_module_info_decorators = ( 3147 # RuntimeError: Batching rule not implemented for aten::_cudnn_rnn_backward. 3148 # We could not generate a fallback 3149 DecorateInfo( 3150 unittest.expectedFailure, "TestModule", "test_grad", 3151 active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda' 3152 ), 3153 # NotImplementedError: the derivative for '_cudnn_rnn_backward' is not implemented. 3154 # Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API 3155 DecorateInfo( 3156 unittest.expectedFailure, "TestModule", "test_gradgrad", 3157 active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda' 3158 ), 3159 # CUDNN GRU doesn't accept non-contiguous hx 3160 DecorateInfo( 3161 unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors", 3162 active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda' 3163 ), 3164 # MIOPEN GRU doesn't accept non-contiguous hx (this is dispatched to miopen only for float). 
3165 DecorateInfo( 3166 unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors", 3167 active_if=(TEST_CUDNN and TEST_WITH_ROCM), dtypes=(torch.float,), device_type='cuda' 3168 ), 3169 DecorateInfo( 3170 skipCUDAVersionIn([(11, 7)]), "TestExpandedWeightModule", "test_module", 3171 device_type='cuda' 3172 ), 3173 DecorateInfo( 3174 skipCUDAVersionIn([(11, 7)]), "TestDecomp", "test_rnn_decomp_module", 3175 device_type='cuda' 3176 ) 3177) 3178 3179# Start of module error inputs functions. 3180 3181def module_error_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_grad, training, **kwargs): 3182 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 3183 samples = [ 3184 ErrorModuleInput( 3185 ModuleInput( 3186 constructor_input=FunctionInput(10, 20), 3187 forward_input=FunctionInput(make_input(3, 11), make_input(3, 20)), 3188 ), 3189 error_on=ModuleErrorEnum.FORWARD_ERROR, 3190 error_type=RuntimeError, 3191 error_regex="input has inconsistent input_size: got 11 expected 10" 3192 ), 3193 ErrorModuleInput( 3194 ModuleInput( 3195 constructor_input=FunctionInput(10, 20), 3196 forward_input=FunctionInput(make_input(3, 10), make_input(3, 21)), 3197 ), 3198 error_on=ModuleErrorEnum.FORWARD_ERROR, 3199 error_type=RuntimeError, 3200 error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20" 3201 ), 3202 ErrorModuleInput( 3203 ModuleInput( 3204 constructor_input=FunctionInput(10, 20), 3205 forward_input=FunctionInput(make_input(3, 10), make_input(5, 20)), 3206 ), 3207 error_on=ModuleErrorEnum.FORWARD_ERROR, 3208 error_type=RuntimeError, 3209 error_regex="Input batch size 3 doesn't match hidden0 batch size 5" 3210 ), 3211 ErrorModuleInput( 3212 ModuleInput( 3213 constructor_input=FunctionInput(10, 20), 3214 forward_input=FunctionInput(make_input(3, 10), make_input(3, 1, 1, 20)), 3215 ), 3216 error_on=ModuleErrorEnum.FORWARD_ERROR, 3217 error_type=ValueError, 3218 error_regex="Expected hidden to be 1D or 2D, got 4D instead" 3219 ), 3220 ErrorModuleInput( 3221 ModuleInput( 3222 constructor_input=FunctionInput(10, 20, 'relu'), 3223 forward_input=FunctionInput(make_input(3, 10), make_input(3, 21)), 3224 ), 3225 error_on=ModuleErrorEnum.FORWARD_ERROR, 3226 error_type=RuntimeError, 3227 error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20" 3228 ), 3229 ErrorModuleInput( 3230 ModuleInput( 3231 constructor_input=FunctionInput(10, 20, 'tanh'), 3232 forward_input=FunctionInput(make_input(3, 10), make_input(3, 21)), 3233 ), 3234 error_on=ModuleErrorEnum.FORWARD_ERROR, 3235 error_type=RuntimeError, 3236 error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20" 3237 ), 3238 ] 3239 return samples 3240 3241def module_error_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_grad, training, **kwargs): 3242 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 3243 samples = [ 3244 ErrorModuleInput( 3245 ModuleInput( 3246 constructor_input=FunctionInput(10, 20), 3247 forward_input=FunctionInput(make_input(3, 11), (make_input(3, 20), make_input(3, 20))), 3248 ), 3249 error_on=ModuleErrorEnum.FORWARD_ERROR, 3250 error_type=RuntimeError, 3251 error_regex="input has inconsistent input_size: got 11 expected 10" 3252 ), 3253 ErrorModuleInput( 3254 ModuleInput( 3255 constructor_input=FunctionInput(10, 20), 3256 forward_input=FunctionInput(make_input(3, 10), (make_input(3, 21), make_input(3, 21))), 3257 ), 3258 error_on=ModuleErrorEnum.FORWARD_ERROR, 3259 
error_type=RuntimeError, 3260 error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20" 3261 ), 3262 ErrorModuleInput( 3263 ModuleInput( 3264 constructor_input=FunctionInput(10, 20), 3265 forward_input=FunctionInput(make_input(3, 10), (make_input(5, 20), make_input(5, 20))), 3266 ), 3267 error_on=ModuleErrorEnum.FORWARD_ERROR, 3268 error_type=RuntimeError, 3269 error_regex="Input batch size 3 doesn't match hidden0 batch size 5" 3270 ), 3271 ErrorModuleInput( 3272 ModuleInput( 3273 constructor_input=FunctionInput(10, 20), 3274 forward_input=FunctionInput(make_input(3, 10), (make_input(3, 1, 1, 20), make_input(3, 1, 1, 20))), 3275 ), 3276 error_on=ModuleErrorEnum.FORWARD_ERROR, 3277 error_type=ValueError, 3278 error_regex="Expected hx\\[0\\] to be 1D or 2D, got 4D instead" 3279 ), 3280 ] 3281 return samples 3282 3283 3284def module_error_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, training, **kwargs): 3285 samples = [ 3286 ErrorModuleInput( 3287 ModuleInput(constructor_input=FunctionInput(10, 0, 1)), 3288 error_on=ModuleErrorEnum.CONSTRUCTION_ERROR, 3289 error_type=ValueError, 3290 error_regex="hidden_size must be greater than zero" 3291 ), 3292 ErrorModuleInput( 3293 ModuleInput(constructor_input=FunctionInput(10, 10, 0)), 3294 error_on=ModuleErrorEnum.CONSTRUCTION_ERROR, 3295 error_type=ValueError, 3296 error_regex="num_layers must be greater than zero" 3297 ), 3298 ] 3299 return samples 3300 3301def module_error_inputs_torch_nn_Pad1d(module_info, device, dtype, requires_grad, training, **kwargs): 3302 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 3303 3304 is_constant = kwargs.get('is_constant', False) 3305 3306 return [ 3307 ErrorModuleInput( 3308 ModuleInput( 3309 constructor_input=FunctionInput(1, 3) if is_constant else FunctionInput(3), 3310 forward_input=FunctionInput(make_input((2, 3, 4, 5))), 3311 ), 3312 error_on=ModuleErrorEnum.FORWARD_ERROR, 3313 error_type=ValueError, 3314 error_regex=r"expected 2D or 3D input \(got 4D input\)", 3315 3316 ), 3317 ] 3318 3319def module_error_inputs_torch_nn_Pad2d(module_info, device, dtype, requires_grad, training, **kwargs): 3320 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 3321 3322 is_constant = kwargs.get('is_constant', False) 3323 3324 return [ 3325 ErrorModuleInput( 3326 ModuleInput( 3327 constructor_input=FunctionInput(1, 3) if is_constant else FunctionInput(3), 3328 forward_input=FunctionInput(make_input((2, 3))), 3329 ), 3330 error_on=ModuleErrorEnum.FORWARD_ERROR, 3331 error_type=ValueError, 3332 error_regex=r"expected 3D or 4D input \(got 2D input\)", 3333 3334 ), 3335 ] 3336 3337def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad, training, **kwargs): 3338 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 3339 3340 is_constant = kwargs.get('is_constant', False) 3341 3342 return [ 3343 ErrorModuleInput( 3344 ModuleInput( 3345 constructor_input=FunctionInput(1, 3) if is_constant else FunctionInput(3), 3346 forward_input=FunctionInput(make_input((2, 3))), 3347 ), 3348 error_on=ModuleErrorEnum.FORWARD_ERROR, 3349 error_type=ValueError, 3350 error_regex=r"expected 4D or 5D input \(got 2D input\)", 3351 3352 ), 3353 ] 3354 3355 3356_macos15_or_newer = torch.backends.mps.is_available() and torch.backends.mps.is_macos_or_newer(15, 0) 3357 3358 3359# Database of ModuleInfo entries in alphabetical order. 
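# Note on reading the entries below (descriptive comment; these conventions are
# inferred from the entries themselves rather than from a formal spec):
#   * module_inputs_func (and, where present, module_error_inputs_func) supplies
#     the ModuleInput / ErrorModuleInput samples exercised for the module class.
#   * skips and decorators both take DecorateInfo entries targeting specific test
#     classes / test names, optionally narrowed by device_type, dtypes, or an
#     active_if predicate; tolerance and precision overrides typically live under
#     decorators, while expected failures and skips typically live under skips.
# A minimal hypothetical entry (not part of the database, shown only for shape)
# would look like:
#   ModuleInfo(torch.nn.Identity,
#              module_inputs_func=module_inputs_torch_nn_Identity,  # hypothetical helper
#              skips=(DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)),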
3360module_db: List[ModuleInfo] = [ 3361 ModuleInfo(torch.nn.AdaptiveAvgPool1d, 3362 module_inputs_func=module_inputs_torch_nn_AdaptiveAvgPool1d, 3363 skips=( 3364 # Fails on MPS backend if input/output sizes are not divisible 3365 DecorateInfo(skipMPS),) 3366 ), 3367 ModuleInfo(torch.nn.AdaptiveAvgPool2d, 3368 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 3369 module_inputs_func=module_inputs_torch_nn_AdaptiveAvgPool2d, 3370 skips=( 3371 # Fails on MPS backend if input/output sizes are not divisible 3372 DecorateInfo(skipMPS), 3373 # Fails on backward check if output size is 1x1 3374 DecorateInfo( 3375 unittest.expectedFailure, 3376 'TestModule', 3377 'test_memory_format', 3378 active_if=operator.itemgetter('training'), 3379 ),) 3380 ), 3381 ModuleInfo(torch.nn.AdaptiveAvgPool3d, 3382 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 3383 module_inputs_func=module_inputs_torch_nn_AdaptiveAvgPool3d, 3384 skips=( 3385 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), 3386 # not supported on MPS backend 3387 DecorateInfo(skipMPS),) 3388 ), 3389 ModuleInfo(torch.nn.AdaptiveMaxPool1d, 3390 module_inputs_func=module_inputs_torch_nn_AdaptiveMaxPool1d, 3391 ), 3392 ModuleInfo(torch.nn.AdaptiveMaxPool2d, 3393 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 3394 module_inputs_func=module_inputs_torch_nn_AdaptiveMaxPool2d, 3395 ), 3396 ModuleInfo(torch.nn.AdaptiveMaxPool3d, 3397 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 3398 module_inputs_func=module_inputs_torch_nn_AdaptiveMaxPool3d, 3399 skips=( 3400 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), 3401 # not supported on MPS backend 3402 DecorateInfo(skipMPS),) 3403 ), 3404 ModuleInfo(torch.nn.AvgPool1d, 3405 module_inputs_func=module_inputs_torch_nn_AvgPool1d, 3406 ), 3407 ModuleInfo(torch.nn.AvgPool2d, 3408 module_inputs_func=module_inputs_torch_nn_AvgPool2d, 3409 skips=( 3410 # The difference between channels last backward and 3411 # channels first backward of AvgPool2d on CUDA is too large 3412 # See https://github.com/pytorch/pytorch/issues/107201 3413 DecorateInfo( 3414 unittest.expectedFailure, 3415 'TestModule', 3416 'test_memory_format', 3417 active_if=operator.itemgetter('training'), 3418 device_type='cuda', 3419 ), 3420 # error: input types 'tensor<f32>' and 'tensor<15x10xf16>' are not broadcast compatible 3421 DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float16]),), 3422 ), 3423 ModuleInfo(torch.nn.AvgPool3d, 3424 module_inputs_func=module_inputs_torch_nn_AvgPool3d, 3425 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 3426 skips=( 3427 # No channels_last support for AvgPool1d as it does not take 4D inputs 3428 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), 3429 # not supported on MPS backend 3430 DecorateInfo(skipMPS),) 3431 ), 3432 ModuleInfo(torch.nn.BatchNorm1d, 3433 train_and_eval_differ=True, 3434 module_inputs_func=module_inputs_torch_nn_BatchNorm1d, 3435 skips=( 3436 # tracking here rather than in the list in test_aotdispatch.py as eval mode passes 3437 # RuntimeError: tried to get Double out of SymInt 3438 DecorateInfo( 3439 unittest.expectedFailure, 'TestEagerFusionModuleInfo', 3440 'test_aot_autograd_symbolic_module_exhaustive', 3441 active_if=operator.itemgetter('training') 3442 ), 3443 # torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default 3444 DecorateInfo( 3445 unittest.expectedFailure, 'TestEagerFusionModuleInfo', 3446 'test_aot_autograd_module_exhaustive', 3447 active_if=operator.itemgetter('training') 3448 
)) 3449 ), 3450 ModuleInfo(torch.nn.BatchNorm2d, 3451 train_and_eval_differ=True, 3452 module_inputs_func=module_inputs_torch_nn_BatchNorm2d, 3453 skips=( 3454 # See https://github.com/pytorch/pytorch/issues/134580 3455 DecorateInfo(expectedFailureMPS, 'TestModule', 'test_memory_format', active_if=operator.itemgetter('training')), 3456 # tracking here rather than in the list in test_aotdispatch.py as eval mode passes 3457 # RuntimeError: tried to get Double out of SymInt 3458 DecorateInfo( 3459 unittest.expectedFailure, 'TestEagerFusionModuleInfo', 3460 'test_aot_autograd_symbolic_module_exhaustive', 3461 active_if=operator.itemgetter('training') 3462 ), 3463 # torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default 3464 DecorateInfo( 3465 unittest.expectedFailure, 'TestEagerFusionModuleInfo', 3466 'test_aot_autograd_module_exhaustive', 3467 active_if=operator.itemgetter('training') 3468 ),) 3469 ), 3470 ModuleInfo(torch.nn.BatchNorm3d, 3471 train_and_eval_differ=True, 3472 module_inputs_func=module_inputs_torch_nn_BatchNorm3d, 3473 skips=( 3474 # not supported on MPS backend 3475 DecorateInfo(skipMPS), 3476 # tracking here rather than in the list in test_aotdispatch.py as eval mode passes 3477 # RuntimeError: tried to get Double out of SymInt 3478 DecorateInfo( 3479 unittest.expectedFailure, 'TestEagerFusionModuleInfo', 3480 'test_aot_autograd_symbolic_module_exhaustive', 3481 active_if=operator.itemgetter('training') 3482 ), 3483 # torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default 3484 DecorateInfo( 3485 unittest.expectedFailure, 'TestEagerFusionModuleInfo', 3486 'test_aot_autograd_module_exhaustive', 3487 active_if=operator.itemgetter('training') 3488 ),) 3489 ), 3490 ModuleInfo(torch.nn.CELU, 3491 module_inputs_func=module_inputs_torch_nn_CELU, 3492 # not MPS specific, will be xfailed for all devices in next PR 3493 skips=( 3494 DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_check_inplace', 3495 device_type='mps', dtypes=[torch.float16]),) 3496 ), 3497 ModuleInfo(torch.nn.Conv1d, 3498 module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=False), 3499 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 3500 module_memformat_affects_out=True, 3501 skips=( 3502 # channels_last support on cuda requires cudnn >= 7603 3503 DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'), 3504 # Failure on ROCM for float32 issue #70125 3505 DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), 3506 # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32' 3507 # xfail does not work due to Fatal Python error: Aborted 3508 DecorateInfo(skipIfMps, "TestModule", "test_memory_format", 3509 device_type='mps', dtypes=[torch.float16]), 3510 DecorateInfo(skipIfMps, "TestModule", "test_non_contiguous_tensors", 3511 device_type='mps', dtypes=[torch.float16]), 3512 ), 3513 decorators=( 3514 DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), 3515 )), 3516 ModuleInfo(torch.nn.Conv2d, 3517 module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=False), 3518 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 3519 module_memformat_affects_out=True, 3520 skips=( 3521 # channels_last support on cuda requires cudnn >= 7603 3522 DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'), 3523 # Failure on ROCM for float32 issue 
#70125 3524 DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), 3525 # This was wrongly being skipped before and needs investigation. 3526 # See https://github.com/pytorch/pytorch/issues/80247 3527 DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", 3528 device_type='cuda', dtypes=[torch.float64]), 3529 # Fails with channels last test on MPS backend 3530 DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", 3531 device_type='mps', dtypes=[torch.float32]), 3532 # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32' 3533 # xfail does not work due to Fatal Python error: Aborted 3534 DecorateInfo(skipIfMps, "TestModule", "test_memory_format", 3535 device_type='mps', dtypes=[torch.float16]), 3536 DecorateInfo(skipIfMps, "TestModule", "test_non_contiguous_tensors", 3537 device_type='mps', dtypes=[torch.float16]), 3538 ), 3539 decorators=( 3540 DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), 3541 )), 3542 ModuleInfo(torch.nn.Conv3d, 3543 module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=False), 3544 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 3545 module_memformat_affects_out=True, 3546 skips=( 3547 # channels_last support on cuda requires cudnn >= 8005 3548 DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=8005), 'TestModule', 'test_memory_format'), 3549 # Failure on ROCM for float32 issue #70125 3550 DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), 3551 # Conv3d is not supported on MPS backend 3552 DecorateInfo(skipMPS), 3553 # This was wrongly being skipped before and needs investigation. 3554 # See https://github.com/pytorch/pytorch/issues/80247 3555 DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"), 3556 ), 3557 decorators=( 3558 DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), 3559 )), 3560 ModuleInfo(torch.nn.ConvTranspose1d, 3561 module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=False, transposed=True), 3562 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 3563 module_memformat_affects_out=True, 3564 dtypes=floating_and_complex_types_and(torch.chalf), 3565 skips=( 3566 # channels_last support on cuda requires cudnn >= 7603 3567 DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'), 3568 # Failure on ROCM for float32 issue #70125 3569 DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), 3570 # Not implemented for chalf on CPU 3571 DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity', 3572 dtypes=(torch.chalf,), device_type='cuda'), 3573 # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32' 3574 # xfail does not work due to Fatal Python error: Aborted 3575 DecorateInfo(skipIfMps, "TestModule", "test_memory_format", 3576 device_type='mps', dtypes=[torch.float16]), 3577 DecorateInfo(skipIfMps, "TestModule", "test_non_contiguous_tensors", 3578 device_type='mps', dtypes=[torch.float16]),), 3579 decorators=( 3580 DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), 3581 DecorateInfo(precisionOverride({torch.chalf: 5e-03}), 'TestModule', 'test_memory_format'), 3582 )), 3583 ModuleInfo(torch.nn.ConvTranspose2d, 3584 module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=False,
transposed=True), 3585 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 3586 module_memformat_affects_out=True, 3587 dtypes=floating_and_complex_types_and(torch.chalf), 3588 skips=( 3589 # channels_last support on cuda requires cudnn >= 7603 3590 DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'), 3591 # Failure on ROCM for float32 issue #70125 3592 DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), 3593 # Fails on backward check because ViewAsRealBackward apply contiguous for grad 3594 DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_memory_format', 3595 dtypes=(torch.complex32, torch.complex64, torch.complex128)), 3596 # This was wrongly being skipped before and needs investigation. 3597 # See https://github.com/pytorch/pytorch/issues/80247 3598 DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda', 3599 dtypes=[torch.float64, torch.complex128]), 3600 # Fails with channels last test on MPS backend 3601 DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", 3602 device_type='mps', dtypes=[torch.float32]), 3603 # Not implemented for chalf on CPU 3604 DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity', 3605 dtypes=(torch.chalf,), device_type='cuda'), 3606 # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32' 3607 # xfail does not work due to Fatal Python error: Aborted 3608 DecorateInfo(skipIfMps, "TestModule", "test_memory_format", 3609 device_type='mps', dtypes=[torch.float16]), 3610 DecorateInfo(skipIfMps, "TestModule", "test_non_contiguous_tensors", 3611 device_type='mps', dtypes=[torch.float16]), 3612 ), 3613 decorators=( 3614 DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), 3615 DecorateInfo(precisionOverride({torch.chalf: 5e-03}), 'TestModule', 'test_memory_format'), 3616 )), 3617 ModuleInfo(torch.nn.ConvTranspose3d, 3618 module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=False, transposed=True), 3619 dtypes=floating_and_complex_types_and(torch.chalf), 3620 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 3621 module_memformat_affects_out=True, 3622 skips=( 3623 # channels_last support on cuda requires cudnn >= 8005 3624 DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=8005), 'TestModule', 'test_memory_format'), 3625 # Failure on ROCM for float32 issue #70125 3626 DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), 3627 # ConvTranspose3d is not supported on MPS backend 3628 DecorateInfo(skipMPS), 3629 # This was wrongly being skipped before and needs investigation. 
3630 # See https://github.com/pytorch/pytorch/issues/80247 3631 DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"), 3632 # These fail only on ROCm 3633 DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda', 3634 dtypes=[torch.complex32, torch.complex64], active_if=TEST_WITH_ROCM), 3635 # Not implemented for chalf on CPU 3636 DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity', 3637 dtypes=(torch.chalf,), device_type='cuda'), 3638 ), 3639 decorators=( 3640 DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), 3641 DecorateInfo(precisionOverride({torch.complex64: 1e-04}), 'TestModule', 'test_cpu_gpu_parity'), 3642 DecorateInfo(precisionOverride({torch.chalf: 5e-03}), 'TestModule', 'test_memory_format'), 3643 )), 3644 ModuleInfo(torch.nn.CosineEmbeddingLoss, 3645 module_inputs_func=module_inputs_torch_nn_CosineEmbeddingLoss, 3646 skips=( 3647 # No channels_last support for loss functions. 3648 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) 3649 ), 3650 ModuleInfo(torch.nn.ELU, 3651 module_inputs_func=module_inputs_torch_nn_ELU, 3652 # not MPS specific, will be xfailed for all devices in next PR 3653 skips=( 3654 DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_check_inplace', 3655 device_type='mps', dtypes=[torch.float16]),) 3656 ), 3657 ModuleInfo(torch.nn.FractionalMaxPool2d, 3658 module_inputs_func=module_inputs_torch_nn_FractionalMaxPool2d, 3659 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 3660 skips=( 3661 # not supported on MPS backend 3662 DecorateInfo(skipMPS), 3663 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) 3664 ), 3665 ModuleInfo(torch.nn.FractionalMaxPool3d, 3666 module_inputs_func=module_inputs_torch_nn_FractionalMaxPool3d, 3667 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 3668 skips=( 3669 # not supported on MPS backend 3670 DecorateInfo(skipMPS), 3671 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) 3672 ), 3673 ModuleInfo(torch.nn.L1Loss, 3674 module_inputs_func=module_inputs_torch_nn_L1Loss, 3675 skips=( 3676 # No channels_last support for loss functions. 3677 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) 3678 ), 3679 ModuleInfo(torch.nn.SmoothL1Loss, 3680 module_inputs_func=module_inputs_torch_nn_SmoothL1Loss, 3681 skips=( 3682 # No channels_last support for loss functions. 3683 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), 3684 # See #119108: input types 'tensor<f32>' and 'tensor<15x10xf16>' are not broadcast compatible 3685 DecorateInfo(skipIfMps, 'TestModule', 'test_non_contiguous_tensors', dtypes=[torch.float16]),) 3686 ), 3687 ModuleInfo(torch.nn.LazyConv1d, 3688 module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=True), 3689 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 3690 module_memformat_affects_out=True, 3691 skips=( 3692 # channels_last support on cuda requires cudnn >= 7603 3693 DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'), 3694 # Failure on ROCM for float32 issue #70125 3695 DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), 3696 # Lazy modules don't currently play well with ModuleInfo tests on the meta device. 3697 # See https://github.com/pytorch/pytorch/issues/70505 for more info.
3698 DecorateInfo(skipMeta), 3699 # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32' 3700 # xfail does not work due to Fatal Python error: Aborted 3701 DecorateInfo(skipIfMps, "TestModule", "test_memory_format", 3702 device_type='mps', dtypes=[torch.float16]), 3703 DecorateInfo(skipIfMps, "TestModule", "test_non_contiguous_tensors", 3704 device_type='mps', dtypes=[torch.float16]), 3705 ), 3706 decorators=( 3707 DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), 3708 )), 3709 ModuleInfo(torch.nn.LazyConv2d, 3710 module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=True), 3711 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 3712 module_memformat_affects_out=True, 3713 skips=( 3714 # channels_last support on cuda requires cudnn >= 7603 3715 DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'), 3716 # Failure on ROCM for float32 issue #70125 3717 DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), 3718 # Lazy modules don't currently play well with ModuleInfo tests on the meta device. 3719 # See https://github.com/pytorch/pytorch/issues/70505 for more info. 3720 DecorateInfo(skipMeta), 3721 # This was wrongly being skipped before and needs investigation. 3722 # See https://github.com/pytorch/pytorch/issues/80247 3723 DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", 3724 device_type='cuda', dtypes=[torch.float64]), 3725 # Fails with channels last test on MPS backend 3726 DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", 3727 device_type='mps', dtypes=[torch.float32]), 3728 # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32' 3729 # xfail does not work due to Fatal Python error: Aborted 3730 DecorateInfo(skipIfMps, "TestModule", "test_memory_format", 3731 device_type='mps', dtypes=[torch.float16]), 3732 DecorateInfo(skipIfMps, "TestModule", "test_non_contiguous_tensors", 3733 device_type='mps', dtypes=[torch.float16]), 3734 ), 3735 decorators=( 3736 DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), 3737 )), 3738 ModuleInfo(torch.nn.LazyConv3d, 3739 module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=True), 3740 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 3741 module_memformat_affects_out=True, 3742 skips=( 3743 # channels_last support on cuda requires cudnn >= 8005 3744 DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=8005), 'TestModule', 'test_memory_format'), 3745 # Failure on ROCM for float32 issue #70125 3746 DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), 3747 # Lazy modules don't currently play well with ModuleInfo tests on the meta device. 3748 # See https://github.com/pytorch/pytorch/issues/70505 for more info. 3749 DecorateInfo(skipMeta), 3750 # LazyConv3d is not supported on MPS backend 3751 DecorateInfo(skipMPS), 3752 # This was wrongly being skipped before and needs investigation. 
3753 # See https://github.com/pytorch/pytorch/issues/80247 3754 DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"), 3755 ), 3756 decorators=( 3757 DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), 3758 )), 3759 ModuleInfo(torch.nn.LazyConvTranspose1d, 3760 module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=True, transposed=True), 3761 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 3762 module_memformat_affects_out=True, 3763 skips=( 3764 # channels_last support on cuda requires cudnn >= 7603 3765 DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'), 3766 # Failure on ROCM for float32 issue #70125 3767 DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), 3768 # Lazy modules don't currently play well with ModuleInfo tests on the meta device. 3769 # See https://github.com/pytorch/pytorch/issues/70505 for more info. 3770 DecorateInfo(skipMeta), 3771 # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32' 3772 # xfail does not work due to Fatal Python error: Aborted 3773 DecorateInfo(skipIfMps, "TestModule", "test_memory_format", 3774 device_type='mps', dtypes=[torch.float16]), 3775 DecorateInfo(skipIfMps, "TestModule", "test_non_contiguous_tensors", 3776 device_type='mps', dtypes=[torch.float16]), 3777 ), 3778 decorators=( 3779 DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), 3780 )), 3781 ModuleInfo(torch.nn.LazyConvTranspose2d, 3782 module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=True, transposed=True), 3783 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 3784 module_memformat_affects_out=True, 3785 skips=( 3786 # channels_last support on cuda requires cudnn >= 7603 3787 DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'), 3788 # Failure on ROCM for float32 issue #70125 3789 DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), 3790 # Lazy modules don't currently play well with ModuleInfo tests on the meta device. 3791 # See https://github.com/pytorch/pytorch/issues/70505 for more info. 3792 DecorateInfo(skipMeta), 3793 # This was wrongly being skipped before and needs investigation. 
3794 # See https://github.com/pytorch/pytorch/issues/80247 3795 DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda', 3796 dtypes=[torch.float64]), 3797 # Fails with channels last test on MPS backend 3798 DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", 3799 device_type='mps', dtypes=[torch.float32]), 3800 # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32' 3801 # xfail does not work due to Fatal Python error: Aborted 3802 DecorateInfo(skipIfMps, "TestModule", "test_memory_format", 3803 device_type='mps', dtypes=[torch.float16]), 3804 DecorateInfo(skipIfMps, "TestModule", "test_non_contiguous_tensors", 3805 device_type='mps', dtypes=[torch.float16]), 3806 ), 3807 decorators=( 3808 DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), 3809 )), 3810 ModuleInfo(torch.nn.LazyConvTranspose3d, 3811 module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=True, transposed=True), 3812 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 3813 module_memformat_affects_out=True, 3814 skips=( 3815 # channels_last support on cuda requires cudnn >= 8005 3816 DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=8005), 'TestModule', 'test_memory_format'), 3817 # Failure on ROCM for float32 issue #70125 3818 DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]), 3819 # Lazy modules don't currently play well with ModuleInfo tests on the meta device. 3820 # See https://github.com/pytorch/pytorch/issues/70505 for more info. 3821 DecorateInfo(skipMeta), 3822 # LazyConvTranspose3d is not supported on MPS backend 3823 DecorateInfo(skipMPS), 3824 # This was wrongly being skipped before and needs investigation. 3825 # See https://github.com/pytorch/pytorch/issues/80247 3826 DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"), 3827 ), 3828 decorators=( 3829 DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), 3830 )), 3831 ModuleInfo(torch.nn.Linear, 3832 module_inputs_func=module_inputs_torch_nn_Linear, 3833 skips=( 3834 # No channels_last support for Linear currently. 3835 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) 3836 ), 3837 ModuleInfo(torch.nn.Bilinear, 3838 module_inputs_func=module_inputs_torch_nn_Bilinear, 3839 decorators=[ 3840 DecorateInfo( 3841 toleranceOverride({ 3842 torch.float32: tol(atol=1e-4, rtol=1e-4), 3843 torch.float64: tol(atol=1e-4, rtol=1e-4)}), 3844 'TestModule', 'test_forward', device_type='cpu'), 3845 ], 3846 skips=( 3847 # No channels_last support for Bilinear currently. 
3848 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), 3849 # See #119108: tolerance issue 3850 DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", 3851 device_type='mps', dtypes=[torch.float16]),) 3852 ), 3853 ModuleInfo(torch.nn.LPPool1d, 3854 module_inputs_func=module_inputs_torch_nn_LPPool1d, 3855 skips=( 3856 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'), 3857 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),) 3858 ), 3859 ModuleInfo(torch.nn.LPPool2d, 3860 module_inputs_func=module_inputs_torch_nn_LPPool2d, 3861 skips=( 3862 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'), 3863 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'), 3864 # Fails on backward check on MPS 3865 # See https://github.com/pytorch/pytorch/issues/107214 3866 DecorateInfo( 3867 unittest.expectedFailure, 3868 'TestModule', 3869 'test_memory_format', 3870 active_if=operator.itemgetter('training'), 3871 device_type='mps', 3872 ),) 3873 ), 3874 ModuleInfo(torch.nn.LPPool3d, 3875 module_inputs_func=module_inputs_torch_nn_LPPool3d, 3876 skips=( 3877 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'), 3878 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'), 3879 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), 3880 DecorateInfo(skipIfMps),) 3881 ), 3882 ModuleInfo(torch.nn.MaxPool1d, 3883 module_inputs_func=module_inputs_torch_nn_MaxPool1d, 3884 ), 3885 ModuleInfo(torch.nn.MaxPool2d, 3886 module_inputs_func=module_inputs_torch_nn_MaxPool2d, 3887 ), 3888 ModuleInfo(torch.nn.MaxPool3d, 3889 module_inputs_func=module_inputs_torch_nn_MaxPool3d, 3890 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 3891 skips=( 3892 # not supported on MPS backend 3893 DecorateInfo(skipMPS),) 3894 ), 3895 ModuleInfo(torch.nn.KLDivLoss, 3896 module_inputs_func=module_inputs_torch_nn_KLDivLoss, 3897 skips=( 3898 # No channels_last support for loss functions. 3899 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), 3900 # https://github.com/pytorch/pytorch/issues/115588 3901 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_cpu_gpu_parity'), 3902 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'), 3903 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),) 3904 ), 3905 ModuleInfo(torch.nn.MSELoss, 3906 module_inputs_func=module_inputs_torch_nn_MSELoss, 3907 skips=( 3908 # No channels_last support for loss functions. 3909 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), 3910 # See #119108: input types 'tensor<f32>' and 'tensor<15x10xf16>' are not broadcast compatible 3911 DecorateInfo(skipIfMps, 'TestModule', 'test_non_contiguous_tensors', dtypes=[torch.float16]), 3912 # See #119108: tolerance issue 3913 DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", 3914 device_type='mps', dtypes=[torch.float16]),) 3915 ), 3916 ModuleInfo(torch.nn.MarginRankingLoss, 3917 module_inputs_func=module_inputs_torch_nn_MarginRankingLoss, 3918 skips=( 3919 # No channels_last support for loss functions. 3920 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) 3921 ), 3922 ModuleInfo(torch.nn.MultiLabelMarginLoss, 3923 module_inputs_func=module_inputs_torch_nn_MultiLabelMarginLoss, 3924 skips=( 3925 # No channels_last support for loss functions. 
3926 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), 3927 # 'aten::multilabel_margin_loss_forward' is not currently implemented for the MPS device. 3928 DecorateInfo(skipIfMps, 'TestModule'), 3929 # derivative for aten::multilabel_margin_loss_backward is not implemented 3930 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),) 3931 ), 3932 ModuleInfo(torch.nn.MultiMarginLoss, 3933 module_inputs_func=module_inputs_torch_nn_MultiMarginLoss, 3934 skips=( 3935 # No channels_last support for loss functions. 3936 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), 3937 # 'aten::multi_margin_loss' is not currently implemented for the MPS device. 3938 DecorateInfo(skipIfMps, 'TestModule'), 3939 # RuntimeError: derivative for aten::multi_margin_loss_backward is not implemented 3940 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),) 3941 ), 3942 ModuleInfo(torch.nn.SoftMarginLoss, 3943 module_inputs_func=module_inputs_torch_nn_SoftMarginLoss, 3944 skips=( 3945 # No channels_last support for loss functions. 3946 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), 3947 # See #119108: tolerance issue 3948 DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", 3949 device_type='mps', dtypes=[torch.float16]),) 3950 ), 3951 ModuleInfo(torch.nn.MultiLabelSoftMarginLoss, 3952 module_inputs_func=module_inputs_torch_nn_MultiLabelSoftMarginLoss, 3953 skips=( 3954 # No channels_last support for loss functions. 3955 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) 3956 ), 3957 ModuleInfo(torch.nn.NLLLoss, 3958 module_inputs_func=module_inputs_torch_nn_NLLLoss, 3959 skips=( 3960 # No channels_last support for loss functions. 3961 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), 3962 # See #119108: tolerance issue 3963 DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", 3964 device_type='mps', dtypes=[torch.float16]),) 3965 ), 3966 ModuleInfo(torch.nn.GaussianNLLLoss, 3967 module_inputs_func=module_inputs_torch_nn_GaussianNLLLoss, 3968 skips=( 3969 # No channels_last support for loss functions. 3970 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)), 3971 ModuleInfo(torch.nn.PoissonNLLLoss, 3972 module_inputs_func=module_inputs_torch_nn_PoissonNLLLoss, 3973 skips=( 3974 # No channels_last support for loss functions. 3975 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)), 3976 ModuleInfo(torch.nn.HingeEmbeddingLoss, 3977 module_inputs_func=module_inputs_torch_nn_HingeEmbeddingLoss, 3978 skips=( 3979 # No channels_last support for loss functions. 3980 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) 3981 ), 3982 ModuleInfo(torch.nn.HuberLoss, 3983 module_inputs_func=module_inputs_torch_nn_HuberLoss, 3984 skips=( 3985 # No channels_last support for loss functions. 3986 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), 3987 # See #119108: seemingly incorrect output dtype 3988 DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", 3989 device_type='mps', dtypes=[torch.float16]),) 3990 ), 3991 ModuleInfo(torch.nn.BCELoss, 3992 module_inputs_func=module_inputs_torch_nn_BCELoss, 3993 skips=( 3994 # No channels_last support for loss functions. 
3995 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), 3996 # error: input types 'tensor<f32>' and 'tensor<15x10xf16>' are not broadcast compatible 3997 DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float16]),) 3998 ), 3999 ModuleInfo(torch.nn.BCEWithLogitsLoss, 4000 module_inputs_func=module_inputs_torch_nn_BCEWithLogitsLoss, 4001 skips=( 4002 # No channels_last support for loss functions. 4003 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), 4004 # see #119108: tolerance issue 4005 DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float16]),) 4006 ), 4007 ModuleInfo(torch.nn.CrossEntropyLoss, 4008 module_inputs_func=module_inputs_torch_nn_CrossEntropyLoss, 4009 dtypes=get_all_fp_dtypes(include_half=True, include_bfloat16=False), 4010 decorators=( 4011 # No channels_last support for loss functions. 4012 DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_memory_format'), 4013 DecorateInfo(toleranceOverride({torch.float16: tol(atol=3e-2, rtol=1e-3)}), "TestModule", 4014 "test_forward", dtypes=[torch.float16], device_type='cpu'), 4015 DecorateInfo(unittest.expectedFailure, "TestModule", "test_cpu_gpu_parity", dtypes=[torch.float16], 4016 device_type='cuda'),), 4017 ), 4018 ModuleInfo(torch.nn.CTCLoss, 4019 module_inputs_func=module_inputs_torch_nn_CTCLoss, 4020 skips=( 4021 # No channels_last support for loss functions. 4022 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), 4023 # The operator aten::_ctc_loss is not currently implemented for the MPS device. 4024 DecorateInfo(skipIfMps, 'TestModule'), 4025 # derivative for aten::_ctc_loss_backward is not implemented 4026 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'), 4027 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'), 4028 # https://github.com/pytorch/pytorch/issues/115585 4029 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_non_contiguous_tensors'),) 4030 ), 4031 ModuleInfo(torch.nn.GELU, 4032 module_inputs_func=module_inputs_torch_nn_GELU, 4033 skips=( 4034 # See #119108: tolerance issue 4035 DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", 4036 device_type='mps', dtypes=[torch.float16]),) 4037 ), 4038 ModuleInfo(torch.nn.GLU, 4039 module_inputs_func=module_inputs_torch_nn_GLU, 4040 ), 4041 ModuleInfo(torch.nn.GroupNorm, 4042 module_inputs_func=module_inputs_torch_nn_GroupNorm, 4043 dtypes=get_all_fp_dtypes(include_bfloat16=True, include_half=True), 4044 skips=( 4045 # Tracking at https://github.com/pytorch/pytorch/issues/98089 4046 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_cpu_gpu_parity'), 4047 DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}), 4048 'TestModule', 'test_memory_format', device_type='cpu'), 4049 # No channels_last support for GroupNorm currently. 
4050 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', device_type='cuda'), 4051 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', device_type='mps'), 4052 DecorateInfo(unittest.skip("Skipped!"), "TestModule", "test_grad", 4053 active_if=TEST_WITH_ROCM, device_type='cuda'),) 4054 ), 4055 ModuleInfo(torch.nn.Hardshrink, 4056 module_inputs_func=module_inputs_torch_nn_Hardshrink, 4057 skips=( 4058 # not supported on MPS backend 4059 DecorateInfo(skipMPS),), 4060 ), 4061 ModuleInfo(torch.nn.Hardswish, 4062 module_inputs_func=module_inputs_torch_nn_Hardswish, 4063 skips=None if _macos15_or_newer else ( 4064 # Fails on backward check on MPS 4065 # See https://github.com/pytorch/pytorch/issues/107214 4066 DecorateInfo( 4067 unittest.expectedFailure, 4068 'TestModule', 4069 'test_memory_format', 4070 active_if=operator.itemgetter('training'), 4071 device_type='mps', 4072 ),), 4073 supports_gradgrad=False), 4074 ModuleInfo(torch.nn.Hardtanh, 4075 module_inputs_func=module_inputs_torch_nn_Hardtanh, 4076 ), 4077 ModuleInfo(torch.nn.InstanceNorm1d, 4078 module_inputs_func=partial(module_inputs_torch_nn_InstanceNormNd, N=1), 4079 train_and_eval_differ=True, 4080 skips=( 4081 # No channels_last support for InstanceNorm1d currently. 4082 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) 4083 ), 4084 ModuleInfo(torch.nn.InstanceNorm2d, 4085 module_inputs_func=partial(module_inputs_torch_nn_InstanceNormNd, N=2), 4086 train_and_eval_differ=True, 4087 skips=( 4088 # No channels_last support for InstanceNorm2d currently. 4089 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) 4090 ), 4091 ModuleInfo(torch.nn.InstanceNorm3d, 4092 module_inputs_func=partial(module_inputs_torch_nn_InstanceNormNd, N=3), 4093 train_and_eval_differ=True, 4094 skips=( 4095 # not supported on MPS backend 4096 DecorateInfo(skipMPS), 4097 # No channels_last support for InstanceNorm3d currently. 4098 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) 4099 ), 4100 ModuleInfo(torch.nn.LocalResponseNorm, 4101 module_inputs_func=module_inputs_torch_nn_LocalResponseNorm, 4102 skips=( 4103 # uses avg_pool3d which is not supported on MPS backend 4104 DecorateInfo(skipMPS),) 4105 ), 4106 ModuleInfo(torch.nn.LayerNorm, 4107 module_inputs_func=module_inputs_torch_nn_LayerNorm, 4108 skips=( 4109 # No channels_last support for LayerNorm currently. 4110 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) 4111 ), 4112 ModuleInfo(torch.nn.RMSNorm, 4113 module_inputs_func=module_inputs_torch_nn_RMSNorm, 4114 ), 4115 # TransformerEncoder takes the same inputs as TransformerEncoderLayer 4116 ModuleInfo(torch.nn.TransformerEncoder, 4117 train_and_eval_differ=True, 4118 module_inputs_func=module_inputs_torch_nn_TransformerEncoder, 4119 decorators=[ 4120 # Not implemented for SDPA backward derivative 4121 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad', 4122 device_type='cpu'), 4123 ], 4124 skips=( 4125 # No channels_last support for TransformerEncoderLayer currently. 4126 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), 4127 # Doesn't support device / dtype kwargs directly because it is just a 4128 # container of TransformerEncoderLayers. 
4129 DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_factory_kwargs'),) 4130 ), 4131 ModuleInfo(torch.nn.TransformerEncoderLayer, 4132 train_and_eval_differ=True, 4133 module_inputs_func=module_inputs_torch_nn_TransformerEncoderLayer, 4134 decorators=[ 4135 DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}), 4136 'TestModule', 'test_non_contiguous_tensors', 4137 device_type='cpu', active_if=IS_WINDOWS), 4138 # Not implemented for SDPA backward derivative 4139 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad', 4140 device_type='cpu'), 4141 ], 4142 skips=( 4143 # No channels_last support for TransformerEncoderLayer currently. 4144 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) 4145 ), 4146 ModuleInfo(torch.nn.TransformerDecoderLayer, 4147 module_inputs_func=module_inputs_torch_nn_TransformerDecoderLayer, 4148 decorators=[ 4149 # Not implemented for SDPA backward derivative 4150 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad', 4151 device_type='cpu'), 4152 ], 4153 skips=( 4154 # No channels_last support for TransformerDecoderLayer currently. 4155 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) 4156 ), 4157 ModuleInfo(torch.nn.Transformer, 4158 module_inputs_func=module_inputs_torch_nn_Transformer, 4159 decorators=[ 4160 # Not implemented for SDPA backward derivative 4161 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad', 4162 device_type='cpu'), 4163 ], 4164 skips=( 4165 # No channels_last support for Transformer currently. 4166 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) 4167 ), 4168 ModuleInfo(torch.nn.MultiheadAttention, 4169 train_and_eval_differ=True, 4170 module_inputs_func=module_inputs_torch_nn_MultiheadAttention, 4171 skips=( 4172 # No channels_last support for MultiheadAttention currently. 4173 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) 4174 ), 4175 ModuleInfo(torch.nn.Embedding, 4176 module_inputs_func=module_inputs_torch_nn_Embedding, 4177 decorators=[ 4178 DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}), 4179 'TestModule', 'test_non_contiguous_tensors', 4180 device_type='mps')], 4181 skips=( 4182 DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),) 4183 ), 4184 ModuleInfo(torch.nn.ReLU, 4185 module_inputs_func=module_inputs_torch_nn_ReLU, 4186 skips=None if _macos15_or_newer else ( 4187 # Fails on backward check on MPS 4188 # See https://github.com/pytorch/pytorch/issues/107214 4189 DecorateInfo( 4190 unittest.expectedFailure, 4191 'TestModule', 4192 'test_memory_format', 4193 active_if=operator.itemgetter('training'), 4194 device_type='mps', 4195 ),) 4196 ), 4197 ModuleInfo(torch.nn.LeakyReLU, 4198 module_inputs_func=module_inputs_torch_nn_LeakyReLU, 4199 ), 4200 ModuleInfo(torch.nn.ReLU6, 4201 module_inputs_func=module_inputs_torch_nn_ReLU6, 4202 skips=( 4203 # test fails on MPS backend and is being investigated. 4204 # See https://github.com/pytorch/pytorch/issues/100914 4205 DecorateInfo(skipMPS),) 4206 ), 4207 ModuleInfo(torch.nn.PReLU, 4208 module_inputs_func=module_inputs_torch_nn_PReLU, 4209 skips=( 4210 # test fails on MPS backend and is being investigated. 
                   # See https://github.com/pytorch/pytorch/issues/100914
                   DecorateInfo(skipMPS),)
               ),
    ModuleInfo(torch.nn.RNNCell,
               module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU_Cell, is_rnn=True),
               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU_Cell,
               ),
    ModuleInfo(torch.nn.GRUCell,
               module_inputs_func=module_inputs_torch_nn_RNN_GRU_Cell,
               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU_Cell,
               ),
    ModuleInfo(torch.nn.LSTMCell,
               module_inputs_func=module_inputs_torch_nn_LSTMCell,
               module_error_inputs_func=module_error_inputs_torch_nn_LSTMCell,
               ),
    ModuleInfo(torch.nn.Sigmoid,
               module_inputs_func=module_inputs_torch_nn_Sigmoid,
               skips=None if _macos15_or_newer else (
                   # Fails on backward check on MPS
                   # See https://github.com/pytorch/pytorch/issues/107214
                   DecorateInfo(
                       unittest.expectedFailure,
                       'TestModule',
                       'test_memory_format',
                       active_if=operator.itemgetter('training'),
                       device_type='mps',
                   ),)
               ),
    ModuleInfo(torch.nn.LogSigmoid,
               module_inputs_func=module_inputs_torch_nn_LogSigmoid,
               skips=(
                   # See #119108: tolerance issue
                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", device_type='mps', dtypes=[torch.float16]),)
               ),
    ModuleInfo(torch.nn.SiLU,
               module_inputs_func=module_inputs_torch_nn_SiLU,
               ),
    ModuleInfo(torch.nn.Softmax,
               module_inputs_func=module_inputs_torch_nn_Softmax,
               ),
    ModuleInfo(torch.nn.Softmax2d,
               module_inputs_func=module_inputs_torch_nn_Softmax2d,
               skips=(
                   # no channels last support for Softmax2d currently
                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
                   # See #119108: tolerance issue
                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", device_type='mps', dtypes=[torch.float16]),)
               ),
    ModuleInfo(torch.nn.LogSoftmax,
               module_inputs_func=module_inputs_torch_nn_LogSoftmax,
               skips=(
                   # no channels last support for LogSoftmax currently
                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
                   # See #119108: inf nan error
                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", device_type='mps', dtypes=[torch.float16]),)
               ),
    ModuleInfo(torch.nn.Softmin,
               module_inputs_func=module_inputs_torch_nn_Softmin,
               skips=(
                   # no channels last support for Softmin currently
                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
               ),
    ModuleInfo(torch.nn.Softplus,
               module_inputs_func=module_inputs_torch_nn_Softplus,
               skips=(
                   # test fails on MPS backend and is being investigated.
                   # See https://github.com/pytorch/pytorch/issues/100914
                   DecorateInfo(skipMPS),)
               ),
    ModuleInfo(torch.nn.Softshrink,
               module_inputs_func=module_inputs_torch_nn_Softshrink,
               skips=(
                   # not supported on MPS backend
                   DecorateInfo(skipMPS),)
               ),
    ModuleInfo(torch.nn.Softsign,
               module_inputs_func=module_inputs_torch_nn_Softsign,
               ),
    ModuleInfo(torch.nn.Tanh,
               module_inputs_func=module_inputs_torch_nn_Tanh,
               skips=None if _macos15_or_newer else (
                   # Fails on backward check on MPS
                   # See https://github.com/pytorch/pytorch/issues/107214
                   DecorateInfo(
                       unittest.expectedFailure,
                       'TestModule',
                       'test_memory_format',
                       active_if=operator.itemgetter('training'),
                       device_type='mps',
                   ),)
               ),
    ModuleInfo(torch.nn.Tanhshrink,
               module_inputs_func=module_inputs_torch_nn_Tanhshrink,
               skips=None if _macos15_or_newer else (
                   # Fails on backward check on MPS
                   # See https://github.com/pytorch/pytorch/issues/107214
                   DecorateInfo(
                       unittest.expectedFailure,
                       'TestModule',
                       'test_memory_format',
                       active_if=operator.itemgetter('training'),
                       device_type='mps',
                   ),)
               ),
    ModuleInfo(torch.nn.Threshold,
               module_inputs_func=module_inputs_torch_nn_Threshold,
               skips=(
                   # test fails on MPS backend and is being investigated.
                   # See https://github.com/pytorch/pytorch/issues/100914
                   DecorateInfo(skipMPS),)
               ),
    ModuleInfo(torch.nn.Mish,
               module_inputs_func=module_inputs_torch_nn_Mish,
               skips=(
                   # not supported on MPS backend
                   DecorateInfo(skipMPS),)
               ),
    ModuleInfo(torch.nn.RNN,
               train_and_eval_differ=True,
               module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=True),
               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU,
               decorators=rnn_gru_lstm_module_info_decorators
               ),
    ModuleInfo(torch.nn.GRU,
               train_and_eval_differ=True,
               module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=False),
               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU,
               decorators=rnn_gru_lstm_module_info_decorators),
    ModuleInfo(torch.nn.LSTM,
               train_and_eval_differ=True,
               module_inputs_func=module_inputs_torch_nn_LSTM,
               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU,
               skips=(
                   # LSTM with projections is not currently supported with MPS
                   DecorateInfo(skipMPS),),
               decorators=rnn_gru_lstm_module_info_decorators),
    ModuleInfo(torch.nn.ReflectionPad1d,
               module_inputs_func=module_inputs_torch_nn_ReflectionPad1d,
               ),
    ModuleInfo(torch.nn.ReflectionPad2d,
               module_inputs_func=module_inputs_torch_nn_ReflectionPad2d,
               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
               skips=(
                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
                                device_type='cuda'),
                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
                                device_type='mps'),)
               ),
    ModuleInfo(torch.nn.ReflectionPad3d,
               module_inputs_func=module_inputs_torch_nn_ReflectionPad3d,
               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
               skips=(
                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
                                device_type='cuda'),
                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
                                device_type='mps'),)
               ),
    ModuleInfo(torch.nn.ReplicationPad1d,
               module_inputs_func=module_inputs_torch_nn_ReplicationPad1d,
               ),
    ModuleInfo(torch.nn.ReplicationPad2d,
               module_inputs_func=module_inputs_torch_nn_ReplicationPad2d,
               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
               skips=(
                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
                                device_type='cuda'),
                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
                                device_type='mps'),)
               ),
    ModuleInfo(torch.nn.ReplicationPad3d,
               module_inputs_func=module_inputs_torch_nn_ReplicationPad3d,
               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
               skips=(
                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
                                device_type='cuda'),
                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
                                device_type='mps'),)
               ),
    ModuleInfo(torch.nn.SELU,
               module_inputs_func=module_inputs_torch_nn_SELU,
               skips=(
                   # test fails on MPS backend and is being investigated.
                   # See https://github.com/pytorch/pytorch/issues/100914
                   DecorateInfo(skipMPS),)
               ),
    ModuleInfo(torch.nn.ZeroPad1d,
               module_inputs_func=module_inputs_torch_nn_ZeroPad1d,
               ),
    ModuleInfo(torch.nn.ZeroPad2d,
               module_inputs_func=module_inputs_torch_nn_ZeroPad2d,
               skips=(
                   # Fails with channels last test on MPS backend
                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
               ),
    ModuleInfo(torch.nn.ZeroPad3d,
               module_inputs_func=module_inputs_torch_nn_ZeroPad3d,
               skips=(
                   # Fails with channels last test on MPS backend
                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
               ),
    ModuleInfo(torch.nn.CircularPad1d,
               module_inputs_func=module_inputs_torch_nn_CircularPad1d,
               module_error_inputs_func=module_error_inputs_torch_nn_Pad1d,
               ),
    ModuleInfo(torch.nn.CircularPad2d,
               module_inputs_func=module_inputs_torch_nn_CircularPad2d,
               module_error_inputs_func=module_error_inputs_torch_nn_Pad2d,
               ),
    ModuleInfo(torch.nn.CircularPad3d,
               module_inputs_func=module_inputs_torch_nn_CircularPad3d,
               module_error_inputs_func=module_error_inputs_torch_nn_Pad3d,
               skips=(
                   # Fails with channels last test on MPS backend
                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),)
               ),
    ModuleInfo(torch.nn.ConstantPad1d,
               module_inputs_func=module_inputs_torch_nn_ConstantPad1d,
               ),
    ModuleInfo(torch.nn.ConstantPad2d,
               module_inputs_func=module_inputs_torch_nn_ConstantPad2d,
               skips=(
                   # Fails with channels last test on MPS backend
                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
               ),
    ModuleInfo(torch.nn.ConstantPad3d,
               module_inputs_func=module_inputs_torch_nn_ConstantPad3d,
               skips=(
                   # Fails with channels last test on MPS backend
                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
               )
]
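
# The entries above are not run from this file; they are consumed by ModuleInfo-driven
# device-type tests. Below is a minimal sketch of such a consumer, kept as a comment so it
# does not affect this module. It assumes the enclosing list is bound to `module_db` (as
# elsewhere in this file); the class name `TestModuleSmoke`, the test name, and its placement
# in a separate test file (e.g. test/test_modules.py) are illustrative assumptions only.
#
#     from torch.testing._internal.common_device_type import instantiate_device_type_tests
#     from torch.testing._internal.common_modules import module_db, modules
#     from torch.testing._internal.common_utils import TestCase, run_tests
#
#     class TestModuleSmoke(TestCase):
#         # Parametrize over a single ModuleInfo entry; device / dtype / training are supplied
#         # by the @modules decorator together with instantiate_device_type_tests(), and the
#         # entry's skips / decorators (as declared above) are applied automatically.
#         @modules([mi for mi in module_db if mi.module_cls is torch.nn.ReLU])
#         def test_forward_runs(self, device, dtype, module_info, training):
#             for module_input in module_info.module_inputs_func(
#                     module_info, device=device, dtype=dtype, requires_grad=False, training=training):
#                 if module_input.forward_input is None:
#                     continue
#                 # Instantiate the module from the recorded constructor args and run forward.
#                 m = module_info.module_cls(*module_input.constructor_input.args,
#                                            **module_input.constructor_input.kwargs)
#                 m.to(device=device, dtype=dtype).train(training)
#                 m(*module_input.forward_input.args, **module_input.forward_input.kwargs)
#
#     instantiate_device_type_tests(TestModuleSmoke, globals())
#
#     if __name__ == '__main__':
#         run_tests()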