xref: /aosp_15_r20/external/pytorch/torch/testing/_internal/common_modules.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3import torch
4import unittest
5from copy import deepcopy
6from enum import Enum
7from functools import wraps, partial
8from itertools import chain, product
9import itertools
10import math
11import torch.nn.functional as F
12from torch.nn.utils.rnn import pack_padded_sequence
13from torch.testing import make_tensor
14from torch.testing._internal.common_cuda import TEST_CUDNN
15from torch.testing._internal.common_dtype import (
16    floating_types, floating_and_complex_types_and, get_all_fp_dtypes)
17from torch.testing._internal.common_device_type import (
18    _TestParametrizer, _update_param_kwargs, expectedFailureMPS, toleranceOverride, tol,
19    skipCUDAIfCudnnVersionLessThan, skipCUDAIfRocm, precisionOverride, skipMeta, skipMPS,
20    skipCUDAVersionIn)
21from torch.testing._internal.common_methods_invocations import DecorateInfo
22from torch.testing._internal.common_nn import (
23    cosineembeddingloss_reference, cross_entropy_loss_reference, ctcloss_reference,
24    hingeembeddingloss_reference, huberloss_reference, kldivloss_reference,
25    marginrankingloss_reference, multimarginloss_reference, multilabelmarginloss_reference,
26    nllloss_reference, nlllossNd_reference, smoothl1loss_reference, softmarginloss_reference, get_reduction)
27from torch.testing._internal.common_utils import (
28    freeze_rng_state, skipIfMps, GRADCHECK_NONDET_TOL, TEST_WITH_ROCM, IS_WINDOWS,
29    skipIfTorchDynamo)
30from types import ModuleType
31from typing import List, Tuple, Type, Set, Dict
32import operator
33
34# List of all namespaces containing modules to test.
35MODULE_NAMESPACES: List[ModuleType] = [
36    torch.nn.modules,
37    torch.ao.nn.qat.modules,
38    torch.ao.nn.quantizable.modules,
39    torch.ao.nn.quantized.modules,
40    torch.ao.nn.quantized.modules,
41]
42
43# Modules that shouldn't be tested for one reason or another.
44MODULES_TO_SKIP: Set[Type] = {
45    torch.nn.Module,  # abstract base class
46    torch.nn.Container,  # deprecated
47    torch.nn.NLLLoss2d,  # deprecated
48    torch.ao.nn.quantized.MaxPool2d,  # aliases to nn.MaxPool2d
49    torch.ao.nn.quantized.MaxPool2d,  # aliases to nn.MaxPool2d
50}
51
52# List of all module classes to test.
53MODULE_CLASSES: List[Type] = list(chain(*[
54    [getattr(namespace, module_name) for module_name in namespace.__all__]  # type: ignore[attr-defined]
55    for namespace in MODULE_NAMESPACES]))
56MODULE_CLASSES = [cls for cls in MODULE_CLASSES if cls not in MODULES_TO_SKIP]
57
58# Dict of module class -> common name. Useful for making test names more intuitive.
59# Example: torch.nn.modules.linear.Linear -> "nn.Linear"
60MODULE_CLASS_NAMES: Dict[Type, str] = {}
61for namespace in MODULE_NAMESPACES:
62    for module_name in namespace.__all__:  # type: ignore[attr-defined]
63        module_cls = getattr(namespace, module_name)
64        namespace_name = namespace.__name__.replace('torch.', '').replace('.modules', '')
65
66        # Deal with any aliases by preferring earlier names.
67        if module_cls not in MODULE_CLASS_NAMES:
68            MODULE_CLASS_NAMES[module_cls] = f'{namespace_name}.{module_name}'
69
70
71# Specifies the modes (i.e. train, eval) to test over.
72TrainEvalMode = Enum('TrainEvalMode', ('train_only', 'eval_only', 'train_and_eval'))
73
74
75class modules(_TestParametrizer):
76    """ PROTOTYPE: Decorator for specifying a list of modules over which to run a test. """
77
78    def __init__(self, module_info_iterable, allowed_dtypes=None,
79                 train_eval_mode=TrainEvalMode.train_and_eval, skip_if_dynamo=True):
80        self.module_info_list = list(module_info_iterable)
81        self.allowed_dtypes = set(allowed_dtypes) if allowed_dtypes is not None else None
82        self.train_eval_mode = train_eval_mode
83        self.skip_if_dynamo = skip_if_dynamo
84
85    def _get_training_flags(self, module_info):
86        training_flags = []
87        if (self.train_eval_mode == TrainEvalMode.train_only or
88                self.train_eval_mode == TrainEvalMode.train_and_eval):
89            training_flags.append(True)
90
91        if (self.train_eval_mode == TrainEvalMode.eval_only or
92                self.train_eval_mode == TrainEvalMode.train_and_eval):
93            training_flags.append(False)
94
95        # If train and eval modes don't differ for the module, don't bother using more than one.
96        if not module_info.train_and_eval_differ:
97            training_flags = training_flags[:1]
98
99        return training_flags
100
101    def _parametrize_test(self, test, generic_cls, device_cls):
102        if device_cls is None:
103            raise RuntimeError('The @modules decorator is only intended to be used in a device-specific '
104                               'context; use it with instantiate_device_type_tests() instead of '
105                               'instantiate_parametrized_tests()')
106
107        for module_info in self.module_info_list:
108            dtypes = set(module_info.supported_dtypes(device_cls.device_type))
109            if self.allowed_dtypes is not None:
110                dtypes = dtypes.intersection(self.allowed_dtypes)
111
112            training_flags = self._get_training_flags(module_info)
113            for (training, dtype) in product(training_flags, dtypes):
114                # Construct the test name; device / dtype parts are handled outside.
115                # See [Note: device and dtype suffix placement]
116                test_name = module_info.formatted_name
117                if len(training_flags) > 1:
118                    test_name += f"_{'train_mode' if training else 'eval_mode'}"
119
120                # Construct parameter kwargs to pass to the test.
121                param_kwargs = {'module_info': module_info}
122                _update_param_kwargs(param_kwargs, 'dtype', dtype)
123                _update_param_kwargs(param_kwargs, 'training', training)
124
125                try:
126
127                    @wraps(test)
128                    def test_wrapper(*args, **kwargs):
129                        return test(*args, **kwargs)
130
131                    if self.skip_if_dynamo and not torch.testing._internal.common_utils.TEST_WITH_TORCHINDUCTOR:
132                        test_wrapper = skipIfTorchDynamo("Policy: we don't run ModuleInfo tests w/ Dynamo")(test_wrapper)
133
134                    decorator_fn = partial(module_info.get_decorators, generic_cls.__name__,
135                                           test.__name__, device_cls.device_type, dtype)
136
137                    yield (test_wrapper, test_name, param_kwargs, decorator_fn)
138                except Exception as ex:
139                    # Provides an error message for debugging before rethrowing the exception
140                    print(f"Failed to instantiate {test_name} for module {module_info.name}!")
141                    raise ex
142
143
144def get_module_common_name(module_cls):
145    if module_cls in MODULE_CLASS_NAMES:
146        # Example: "nn.Linear"
147        return MODULE_CLASS_NAMES[module_cls]
148    else:
149        return module_cls.__name__
150
151
152class FunctionInput:
153    """ Contains args and kwargs to pass as input to a function. """
154    __slots__ = ['args', 'kwargs']
155
156    def __init__(self, *args, **kwargs):
157        self.args = args
158        self.kwargs = kwargs
159
160
161class ModuleInput:
162    """ Contains args / kwargs for module instantiation + forward pass. """
163    __slots__ = ['constructor_input', 'forward_input', 'desc', 'reference_fn']
164
165    def __init__(self, constructor_input, forward_input=None, desc='', reference_fn=None):
166        self.constructor_input = constructor_input  # Inputs to pass during construction
167        self.forward_input = forward_input  # Inputs to pass to forward()
168        self.desc = desc  # Description for this set of inputs
169        self.reference_fn = reference_fn  # Reference with signature: reference_fn(module, parameters, *args, **kwargs)
170
171        if reference_fn is not None:
172
173            @wraps(reference_fn)
174            def copy_reference_fn(m, *args, **kwargs):
175                # Copy inputs to avoid undesired side effects from calling the reference.
176                args, kwargs = deepcopy(args), deepcopy(kwargs)
177
178                # Note that module parameters are passed in for convenience.
179                return reference_fn(m, list(m.parameters()), *args, **kwargs)
180
181            self.reference_fn = copy_reference_fn
182
183class ModuleErrorEnum(Enum):
184    """ Enumerates when error is raised when testing modules. """
185    CONSTRUCTION_ERROR = 0
186    FORWARD_ERROR = 1
187
188class ErrorModuleInput:
189    """
190    A ModuleInput that will cause the operation to throw an error plus information
191    about the resulting error.
192    """
193
194    __slots__ = ["module_error_input", "error_on", "error_type", "error_regex"]
195
196    def __init__(self,
197                 module_error_input,
198                 *,
199                 error_on=ModuleErrorEnum.CONSTRUCTION_ERROR,
200                 error_type=RuntimeError,
201                 error_regex):
202        self.module_error_input = module_error_input
203        self.error_on = error_on
204        self.error_type = error_type
205        self.error_regex = error_regex
206
207
208class ModuleInfo:
209    """ Module information to be used in testing. """
210
211    def __init__(self,
212                 module_cls,  # Class object for the module under test
213                 *,
214                 module_inputs_func,  # Function to generate module inputs
215                 skips=(),  # Indicates which tests to skip
216                 decorators=None,  # Additional decorators to apply to generated tests
217                 dtypes=floating_types(),  # dtypes this function is expected to work with
218                 dtypesIfMPS=(torch.float16, torch.float32,),  # dtypes this function is expected to work with on MPS
219                 dtypesIfHpu=(torch.bfloat16, torch.float32,),
220                 supports_gradgrad=True,  # whether the op supports second order gradients
221                 gradcheck_nondet_tol=0.0,  # tolerance for nondeterminism while performing gradcheck
222                 module_memformat_affects_out=False,  # whether converting module to channels last will generate
223                                                      # channels last output
224                 train_and_eval_differ=False,  # whether the module has differing behavior between train and eval
225                 module_error_inputs_func=None,  # Function to generate module inputs that error
226                 ):
227        self.module_cls = module_cls
228        self.module_inputs_func = module_inputs_func
229        self.decorators = (*(decorators if decorators else []), *(skips if skips else []))
230        self.dtypes = dtypes
231        self.dtypesIfMPS = dtypesIfMPS
232        self.dtypesIfHpu = dtypesIfHpu
233        self.supports_gradgrad = supports_gradgrad
234        self.gradcheck_nondet_tol = gradcheck_nondet_tol
235        self.module_memformat_affects_out = module_memformat_affects_out
236        self.train_and_eval_differ = train_and_eval_differ
237        self.module_error_inputs_func = module_error_inputs_func
238        self.is_lazy = issubclass(module_cls, torch.nn.modules.lazy.LazyModuleMixin)
239
240    def get_decorators(self, test_class, test_name, device, dtype, param_kwargs):
241        result = []
242        for decorator in self.decorators:
243            if isinstance(decorator, DecorateInfo):
244                if decorator.is_active(test_class, test_name, device, dtype, param_kwargs):
245                    result.extend(decorator.decorators)
246            else:
247                result.append(decorator)
248        return result
249
250    def supported_dtypes(self, device_type):
251        if device_type == 'mps':
252            return self.dtypesIfMPS
253        elif device_type == 'hpu':
254            return self.dtypesIfHpu
255        else:
256            return self.dtypes
257
258    @property
259    def name(self):
260        return get_module_common_name(self.module_cls)
261
262    @property
263    def formatted_name(self):
264        return self.name.replace('.', '_')
265
266# Start of module inputs functions.
267
268def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, training, **kwargs):
269    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
270
271    module_inputs = [
272        ModuleInput(constructor_input=FunctionInput(10, 8),
273                    forward_input=FunctionInput(input=make_input((4, 10))),
274                    reference_fn=lambda m, p, input: torch.mm(input, p[0].t()) + p[1].view(1, -1).expand(4, 8)),
275        ModuleInput(constructor_input=FunctionInput(10, 8, bias=False),
276                    forward_input=FunctionInput(make_input((4, 10))),
277                    desc='no_bias',
278                    reference_fn=lambda m, p, i: torch.mm(i, p[0].t())),
279        ModuleInput(constructor_input=FunctionInput(3, 5),
280                    forward_input=FunctionInput(make_input(3)),
281                    desc='no_batch_dim',
282                    reference_fn=lambda m, p, i: torch.mm(i.view(1, -1), p[0].t()).view(-1) + p[1])
283    ]
284
285    return module_inputs
286
287
288def module_inputs_torch_nn_Bilinear(module_info, device, dtype, requires_grad, training, **kwargs):
289    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
290
291    def bilinear_reference_fn(m, p, x1, x2, bias=True):
292        result = torch.einsum('bn,anm,bm->ba', x1, p[0], x2)
293        if bias:
294            if x1.shape[0] == 1:
295                result = result.view(-1) + p[1]
296            else:
297                result = result + p[1].view(1, -1).expand(x1.shape[0], p[0].shape[0])
298        return result
299
300    module_inputs = [
301        ModuleInput(constructor_input=FunctionInput(2, 3, 4),
302                    forward_input=FunctionInput(make_input((8, 2)), make_input((8, 3))),
303                    reference_fn=bilinear_reference_fn),
304        ModuleInput(constructor_input=FunctionInput(2, 3, 4, bias=False),
305                    forward_input=FunctionInput(make_input((8, 2)), make_input((8, 3))),
306                    desc='no_bias',
307                    reference_fn=lambda m, p, x1, x2: bilinear_reference_fn(m, p, x1, x2, bias=False)),
308        ModuleInput(constructor_input=FunctionInput(2, 3, 4),
309                    forward_input=FunctionInput(make_input(2), make_input(3)),
310                    desc='no_batch_dim',
311                    reference_fn=lambda m, p, x1, x2: bilinear_reference_fn(m, p, x1.view(1, -1), x2.view(1, -1))),
312    ]
313
314    return module_inputs
315
316
317def module_inputs_torch_nn_KLDivLoss(module_info, device, dtype, requires_grad, training, **kwargs):
318    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
319
320    cases: List[Tuple[str, dict]] = [
321        ('', {}),
322        ('reduction_sum', {'reduction': 'sum'}),
323        ('reduction_batchmean', {'reduction': 'batchmean'}),
324        ('reduction_none', {'reduction': 'none'}),
325        ('log_target', {'log_target': True})
326    ]
327
328    module_inputs = []
329    for desc, constructor_kwargs in cases:
330        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
331            return kldivloss_reference(i, t, **constructor_kwargs)
332
333        input = make_input((10, 10)).log()
334        target = make_input((10, 10)) if kwargs.get('log_target', False) else make_input((10, 10)).log()
335        module_inputs.append(
336            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
337                        forward_input=FunctionInput(input, target),
338                        desc=desc,
339                        reference_fn=reference_fn)
340        )
341
342        scalar_input = make_input(()).log()
343        scalar_target = make_input(()) if kwargs.get('log_target', False) else make_input(()).log()
344        module_inputs.append(
345            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
346                        forward_input=FunctionInput(scalar_input, scalar_input),
347                        desc='scalar_' + desc,
348                        reference_fn=reference_fn)
349        )
350
351    return module_inputs
352
353
354def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, training, **kwargs):
355    def make_input(shape, device=device, dtype=dtype, requires_grad=requires_grad):
356        return make_tensor(shape, device=device, dtype=dtype,
357                           requires_grad=False).log_softmax(dim=1).requires_grad_(requires_grad)
358    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
359
360    cases: List[Tuple[str, dict]] = [
361        ('', {}),
362        ('reduction_sum', {'reduction': 'sum'}),
363        ('reduction_none', {'reduction': 'none'}),
364        ('ignore_index', {'ignore_index': 2}),
365        ('weights', {'weight': make_weight(4).abs()}),
366        ('weights_ignore_index', {'weight': make_weight(4).abs(), 'ignore_index': 2}),
367        ('weights_ignore_index_neg', {'weight': make_weight(4).abs(), 'ignore_index': -1})
368    ]
369
370    # TODO: Uncomment when negative weights is supported.
371    # negative_weight = make_weight(10)
372    # negative_weight[0] = -1
373    # cases.append(('weights_negative', {'weight': negative_weight}))
374    module_inputs = []
375    for desc, constructor_kwargs in cases:
376
377        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
378            return nllloss_reference(i, t, **constructor_kwargs)
379
380        module_inputs.append(
381            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
382                        forward_input=FunctionInput(make_input((15, 4)),
383                                                    torch.empty(15, device=device).uniform_().mul(4).floor().long()),
384                        desc=desc,
385                        reference_fn=reference_fn)
386        )
387
388        def nd_reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
389            return nlllossNd_reference(i, t, **constructor_kwargs)
390
391        module_inputs.append(
392            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
393                        forward_input=FunctionInput(
394                            make_input((2, 4, 5, 5)),
395                            torch.empty(2, 5, 5, device=device).uniform_().mul(4).floor().long()),
396                        desc=f"nd_{desc}",
397                        reference_fn=nd_reference_fn)
398        )
399
400        module_inputs.append(
401            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
402                        forward_input=FunctionInput(
403                            make_input((2, 4, 5, 5, 2, 2)),
404                            torch.empty(2, 5, 5, 2, 2, device=device).uniform_().mul(4).floor().long()),
405                        desc=f"higher_dim_{desc}",
406                        reference_fn=nd_reference_fn)
407        )
408
409        module_inputs.append(
410            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
411                        forward_input=FunctionInput(
412                            make_input((2, 4, 5)),
413                            torch.empty(2, 5, device=device).uniform_().mul(4).floor().long()),
414                        desc=f"3d_{desc}",
415                        reference_fn=nd_reference_fn)
416        )
417
418    return module_inputs
419
420
421def module_inputs_torch_nn_GaussianNLLLoss(module_info, device, dtype, requires_grad, training, **kwargs):
422    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
423    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
424
425    cases: List[Tuple[str, dict]] = [
426        ('', {}),
427        ('reduction_sum', {'reduction': 'sum'}),
428        ('reduction_mean', {'reduction': 'mean'}),
429        ('reduction_none', {'reduction': 'none'}),
430    ]
431
432    module_inputs = []
433    for desc, constructor_kwargs in cases:
434        module_inputs.append(
435            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
436                        forward_input=FunctionInput(make_input(3),
437                                                    make_target(3),
438                                                    make_input(1).abs()),
439                        desc=desc,
440                        reference_fn=no_batch_dim_reference_fn)
441        )
442
443    return module_inputs
444
445
446def module_inputs_torch_nn_PoissonNLLLoss(module_info, device, dtype, requires_grad, training, **kwargs):
447    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
448    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
449
450    cases: List[Tuple[str, dict]] = [
451        ('', {}),
452        ('reduction_sum', {'reduction': 'sum'}),
453        ('reduction_mean', {'reduction': 'mean'}),
454        ('reduction_none', {'reduction': 'none'}),
455        ('full', {'full': True}),
456        ('no_log_input', {'log_input': False}),
457        ('full_no_log_input', {'full': True, 'log_input': False}),
458    ]
459
460    def poissonnllloss_reference_fn(i, t, log_input=True, full=False, reduction='mean', eps=1e-8):
461        if log_input:
462            result = i.exp() - t.mul(i)
463        else:
464            result = i - t.mul((i + eps).log())
465
466        if full:
467            result += (t.mul(t.log()) - t + 0.5 * (2. * math.pi * t).log()).masked_fill(t <= 1, 0)
468
469        if reduction == 'none':
470            return result
471        elif reduction == 'mean':
472            return result.sum() / i.numel()
473        else:
474            return result.sum()
475
476    module_inputs = []
477    for desc, constructor_kwargs in cases:
478        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
479            return poissonnllloss_reference_fn(i, t, **constructor_kwargs)
480
481        log_input = constructor_kwargs.get('log_input', True)
482        input = make_input((2, 3, 4, 5)) if log_input else make_input((2, 3, 4, 5)).abs().add(0.001)
483        module_inputs.append(
484            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
485                        forward_input=FunctionInput(input,
486                                                    make_target((2, 3, 4, 5)).floor_().abs_()),
487                        desc=desc,
488                        reference_fn=reference_fn)
489        )
490
491    return module_inputs
492
493
494def module_inputs_torch_nn_MSELoss(module_info, device, dtype, requires_grad, training, **kwargs):
495    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
496    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
497
498    cases: List[Tuple[str, dict]] = [
499        ('', {}),
500        ('reduction_sum', {'reduction': 'sum'}),
501        ('reduction_mean', {'reduction': 'mean'}),
502        ('reduction_none', {'reduction': 'none'}),
503    ]
504
505    def mse_loss_reference_fn(m, p, i, t, reduction='mean'):
506        if reduction == 'none':
507            return (i - t).pow(2)
508        elif reduction == 'mean':
509            return (i - t).pow(2).sum() / i.numel()
510        else:
511            return (i - t).pow(2).sum()
512
513    module_inputs = []
514    for desc, constructor_kwargs in cases:
515        module_inputs.append(
516            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
517                        forward_input=FunctionInput(make_input((2, 3, 4, 5)),
518                                                    make_target((2, 3, 4, 5))),
519                        desc=desc,
520                        reference_fn=partial(mse_loss_reference_fn, **constructor_kwargs))
521        )
522        module_inputs.append(
523            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
524                        forward_input=FunctionInput(make_input(()),
525                                                    make_target(())),
526                        desc=f'{desc}_scalar',
527                        reference_fn=partial(mse_loss_reference_fn, **constructor_kwargs))
528        )
529
530    return module_inputs
531
532
533def no_batch_dim_reference_fn(m, p, *args, **kwargs):
534    """Reference function for modules supporting no batch dimensions.
535
536    Unbatched inputs are unsqueezed to form a
537    single batch input before passing them to the module.
538    The output is squeezed to compare with the
539    output of unbatched input to the module.
540
541    Currently it only supports modules which return a single Tensor as output.
542    You can bind the following kwargs.
543    Kwargs:
544        batch_first[bool] : If True, all the Tensors in `args` while be unsqueezed at dim `0` .
545                        and output will be squeezed at dim `0` else dim `1` for both.
546        kwargs_to_batchify[dict] : Dictionary specifying the name of the argument and dimension to unsqueeze.
547                               Useful if there are few arguments whose batch dimension are different
548                               from the ones selected by `batch_first`.
549        is_criterion[bool] : Specify if the module is a criterion and handle the reduction for output accordingly.
550    """
551    def get_and_pop(key, default):
552        v = kwargs.get(key, default)
553        if key in kwargs:
554            kwargs.pop(key)
555        return v
556
557    batch_dim = 0 if get_and_pop('batch_first', True) else 1
558    kwargs_to_batchify = get_and_pop('kwargs_to_batchify', None)
559    is_criterion = get_and_pop('is_criterion', False)
560
561    if kwargs_to_batchify is not None:
562        assert isinstance(kwargs_to_batchify, dict)
563        for k, v in kwargs.items():
564            if k in kwargs_to_batchify and v is not None:
565                bdim = kwargs_to_batchify[k]
566                kwargs[k] = v.unsqueeze(bdim)
567
568    single_batch_input_args = [input.unsqueeze(batch_dim) for input in args]
569    with freeze_rng_state():
570        output = m(*single_batch_input_args, **kwargs).squeeze(batch_dim)
571
572    if is_criterion:
573        reduction = get_reduction(m)
574        if reduction == 'none':
575            return output.squeeze(0)
576    return output
577
578
579def no_batch_dim_reference_mha(m, p, *args, **kwargs):
580    """Reference function for MultiheadAttention supporting no batch dimensions.
581
582    Unbatched inputs are unsqueezed to form a
583    single batch input before passing them to the module.
584    The output is squeezed to compare with the
585    output of unbatched input to the module.
586    """
587    batch_dim = 0 if kwargs.get('batch_first', True) else 1
588    if 'batch_first' in kwargs:
589        kwargs.pop('batch_first')
590    if 'key_padding_mask' in kwargs and kwargs['key_padding_mask'] is not None:
591        kwargs['key_padding_mask'] = kwargs['key_padding_mask'].unsqueeze(0)
592    single_batch_input_args = [input.unsqueeze(batch_dim) for input in args]
593    with freeze_rng_state():
594        output = m(*single_batch_input_args, **kwargs)
595        return (output[0].squeeze(batch_dim), output[1].squeeze(0))
596
597
598def no_batch_dim_reference_rnn_gru(m, p, *args, **kwargs):
599    """Reference function for RNN and GRU supporting no batch dimensions.
600
601    Unbatched inputs are unsqueezed to form a
602    single batch input before passing them to the module.
603    The output is squeezed to compare with the
604    output of unbatched input to the module.
605    """
606    if len(args) == 1:
607        inp, = args
608        h = None
609    elif len(args) == 2:
610        inp, h = args
611        h = h.unsqueeze(1)
612
613    batch_dim = 0 if kwargs['batch_first'] else 1
614    kwargs.pop('batch_first')
615    inp = inp.unsqueeze(batch_dim)
616    single_batch_input_args = (inp, h)
617    with freeze_rng_state():
618        output = m(*single_batch_input_args, **kwargs)
619        return (output[0].squeeze(batch_dim), output[1].squeeze(1))
620
621
622def no_batch_dim_reference_lstm(m, p, *args, **kwargs):
623    """Reference function for LSTM supporting no batch dimensions.
624
625    Unbatched inputs are unsqueezed to form a
626    single batch input before passing them to the module.
627    The output is squeezed to compare with the
628    output of unbatched input to the module.
629    """
630    if len(args) == 1:
631        inp, = args
632        h = None
633    elif len(args) == 2:
634        inp, h = args
635        h = (h[0].unsqueeze(1), h[1].unsqueeze(1))
636
637    batch_dim = 0 if kwargs['batch_first'] else 1
638    kwargs.pop('batch_first')
639    inp = inp.unsqueeze(batch_dim)
640    single_batch_input_args = (inp, h)
641    with freeze_rng_state():
642        output = m(*single_batch_input_args, **kwargs)
643        return (output[0].squeeze(batch_dim), (output[1][0].squeeze(1), output[1][1].squeeze(1)))
644
645
646def no_batch_dim_reference_lstmcell(m, p, *args, **kwargs):
647    """Reference function for LSTMCell supporting no batch dimensions.
648
649    The module is passed the input and target in batched form with a single item.
650    The output is squeezed to compare with the no-batch input.
651    """
652    inp, (h, c) = args
653    single_batch_input_args = (inp.unsqueeze(0), (h.unsqueeze(0), c.unsqueeze(0)))
654    with freeze_rng_state():
655        output = m(*single_batch_input_args, **kwargs)
656        return (output[0].squeeze(0), output[1].squeeze(0))
657
658
659def generate_regression_criterion_inputs(make_input):
660    return [
661        ModuleInput(
662            constructor_input=FunctionInput(reduction=reduction),
663            forward_input=FunctionInput(make_input((4, )), make_input(4,)),
664            reference_fn=partial(no_batch_dim_reference_fn, is_criterion=True),
665            desc=f'no_batch_dim_{reduction}'
666        ) for reduction in ['none', 'mean', 'sum']]
667
668
669def module_inputs_torch_nn_AvgPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
670    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
671
672    return [
673        ModuleInput(constructor_input=FunctionInput(kernel_size=2),
674                    forward_input=FunctionInput(make_input((3, 6))),
675                    desc='no_batch_dim',
676                    reference_fn=no_batch_dim_reference_fn),
677        ModuleInput(constructor_input=FunctionInput(2),
678                    forward_input=FunctionInput(make_input((2, 3, 6)))),
679        ModuleInput(constructor_input=FunctionInput((2,), (2,)),
680                    forward_input=FunctionInput(make_input((2, 3, 6))),
681                    desc='stride'),
682        ModuleInput(constructor_input=FunctionInput(2, 2, 1),
683                    forward_input=FunctionInput(make_input((2, 3, 6))),
684                    desc='stride_pad')]
685
686
687def module_inputs_torch_nn_AvgPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
688    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
689
690    return [
691        ModuleInput(constructor_input=FunctionInput((2, 2)),
692                    forward_input=FunctionInput(make_input((3, 6, 6))),
693                    desc='no_batch_dim',
694                    reference_fn=no_batch_dim_reference_fn),
695        ModuleInput(constructor_input=FunctionInput((2, 2)),
696                    forward_input=FunctionInput(make_input((2, 3, 6, 6)))),
697        ModuleInput(constructor_input=FunctionInput((2, 2), (2, 2)),
698                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
699                    desc='stride'),
700        ModuleInput(constructor_input=FunctionInput((2, 2), (2, 2), (1, 1)),
701                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
702                    desc='stride_pad'),
703        ModuleInput(constructor_input=FunctionInput((2, 2), divisor_override=1),
704                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
705                    desc='divisor'),
706        ModuleInput(constructor_input=FunctionInput((2, 2), (2, 2), divisor_override=1),
707                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
708                    desc='divisor_stride'),
709        ModuleInput(constructor_input=FunctionInput((2, 2), (2, 2), (1, 1), divisor_override=1),
710                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
711                    desc='divisor_stride_pad')]
712
713
714
715def module_inputs_torch_nn_AvgPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
716    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
717
718    return [
719        ModuleInput(constructor_input=FunctionInput((2, 2, 2)),
720                    forward_input=FunctionInput(make_input((3, 4, 4, 4))),
721                    desc='no_batch_dim',
722                    reference_fn=no_batch_dim_reference_fn),
723        ModuleInput(constructor_input=FunctionInput((2, 2, 2)),
724                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4)))),
725        ModuleInput(constructor_input=FunctionInput(2, (2, 2, 2)),
726                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
727                    desc='stride'),
728        ModuleInput(constructor_input=FunctionInput(2, 2, (1, 1, 1)),
729                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
730                    desc='stride_pad'),
731        ModuleInput(constructor_input=FunctionInput(4, 2, (1, 2, 1)),
732                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
733                    desc='stride_pad_gpu_fixedkw_output'),
734        ModuleInput(constructor_input=FunctionInput((2, 4, 8), 1, (1, 1, 2)),
735                    forward_input=FunctionInput(make_input((2, 3, 2, 4, 8))),
736                    desc='stride_pad_gpu_general_output'),
737        ModuleInput(constructor_input=FunctionInput(3, 1, 0),
738                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
739                    desc='stride1_pad0_gpu_input'),
740        ModuleInput(constructor_input=FunctionInput(2, 2, (1, 1, 1)),
741                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
742                    desc='stride_pad_gpu_input_nooverlap'),
743        ModuleInput(constructor_input=FunctionInput((2, 2, 2), divisor_override=1),
744                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
745                    desc='divisor'),
746        ModuleInput(constructor_input=FunctionInput(2, (2, 2, 2), divisor_override=1),
747                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
748                    desc='divisor_stride'),
749        ModuleInput(constructor_input=FunctionInput(2, 2, (1, 1, 1), divisor_override=1),
750                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
751                    desc='divisor_stride_pad'),
752        ModuleInput(constructor_input=FunctionInput(4, 2, (1, 2, 1), divisor_override=1),
753                    forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
754                    desc='divisor_stride_pad_gpu_fixedkw_output'),
755        ModuleInput(constructor_input=FunctionInput((2, 4, 8), 1, (1, 1, 2), divisor_override=1),
756                    forward_input=FunctionInput(make_input((2, 3, 2, 4, 8))),
757                    desc='divisor_stride_pad_gpu_general_output'),
758        ModuleInput(constructor_input=FunctionInput(3, 1, 0, divisor_override=1),
759                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
760                    desc='divisor_stride1_pad0_gpu_input'),
761        ModuleInput(constructor_input=FunctionInput(2, 2, (1, 1, 1), divisor_override=1),
762                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
763                    desc='divisor_stride_pad_gpu_input_nooverlap')]
764
765
766
767def module_inputs_torch_nn_AdaptiveAvgPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
768    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
769
770    return [
771        ModuleInput(constructor_input=FunctionInput(3,),
772                    forward_input=FunctionInput(make_input((1, 3, 5))),
773                    desc='single'),
774        ModuleInput(constructor_input=FunctionInput(3,),
775                    forward_input=FunctionInput(make_input((3, 5))),
776                    reference_fn=no_batch_dim_reference_fn,
777                    desc='no_batch_dim'),
778        ModuleInput(constructor_input=FunctionInput(1,),
779                    forward_input=FunctionInput(make_input((1, 3, 5))),
780                    desc='one_output')]
781
782
783def module_inputs_torch_nn_AdaptiveAvgPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
784    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
785
786    return [
787        ModuleInput(constructor_input=FunctionInput(3,),
788                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
789                    desc='single'),
790        ModuleInput(constructor_input=FunctionInput(3,),
791                    forward_input=FunctionInput(make_input((3, 5, 6))),
792                    reference_fn=no_batch_dim_reference_fn,
793                    desc='no_batch_dim'),
794        ModuleInput(constructor_input=FunctionInput(1,),
795                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
796                    desc='single_1x1output'),
797        ModuleInput(constructor_input=FunctionInput((3, 4)),
798                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
799                    desc='tuple'),
800        ModuleInput(constructor_input=FunctionInput((3, None)),
801                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
802                    desc='tuple_none')]
803
804def module_inputs_torch_nn_AdaptiveAvgPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
805    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
806
807    return [
808        ModuleInput(constructor_input=FunctionInput(3,),
809                    forward_input=FunctionInput(make_input((2, 3, 5, 2, 7))),
810                    desc='single'),
811        ModuleInput(constructor_input=FunctionInput(3,),
812                    forward_input=FunctionInput(make_input((3, 5, 2, 7))),
813                    reference_fn=no_batch_dim_reference_fn,
814                    desc='no_batch_dim'),
815        ModuleInput(constructor_input=FunctionInput((3, 4, 5)),
816                    forward_input=FunctionInput(make_input((2, 3, 5, 3, 7))),
817                    desc='tuple'),
818        ModuleInput(constructor_input=FunctionInput((None, 4, 5)),
819                    forward_input=FunctionInput(make_input((2, 3, 5, 3, 7))),
820                    desc='tuple_none'),
821        ModuleInput(constructor_input=FunctionInput((3, 2, 2)),
822                    forward_input=FunctionInput(make_input((1, 1, 3, 2, 6))),
823                    desc='last_dim')]
824
825
826def module_inputs_torch_nn_AdaptiveMaxPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
827    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
828
829    return [
830        ModuleInput(constructor_input=FunctionInput(3,),
831                    forward_input=FunctionInput(make_input((1, 3, 5))),
832                    desc='single'),
833        ModuleInput(constructor_input=FunctionInput(3,),
834                    forward_input=FunctionInput(make_input((3, 5))),
835                    reference_fn=no_batch_dim_reference_fn,
836                    desc='no_batch_dim')]
837
838
839def module_inputs_torch_nn_AdaptiveMaxPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
840    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
841
842    return [
843        ModuleInput(constructor_input=FunctionInput(3,),
844                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
845                    desc='single'),
846        ModuleInput(constructor_input=FunctionInput(3,),
847                    forward_input=FunctionInput(make_input((3, 5, 6))),
848                    reference_fn=no_batch_dim_reference_fn,
849                    desc='no_batch_dim'),
850        ModuleInput(constructor_input=FunctionInput((3, 4)),
851                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
852                    desc='tuple'),
853        ModuleInput(constructor_input=FunctionInput((3, None)),
854                    forward_input=FunctionInput(make_input((1, 3, 5, 6))),
855                    desc='tuple_none')]
856
857
858def module_inputs_torch_nn_AdaptiveMaxPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
859    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
860
861    return [
862        ModuleInput(constructor_input=FunctionInput(3,),
863                    forward_input=FunctionInput(make_input((2, 3, 5, 6, 7))),
864                    desc='single'),
865        ModuleInput(constructor_input=FunctionInput(3,),
866                    forward_input=FunctionInput(make_input((3, 5, 6, 7))),
867                    reference_fn=no_batch_dim_reference_fn,
868                    desc='no_batch_dim'),
869        ModuleInput(constructor_input=FunctionInput((3, 4, 5)),
870                    forward_input=FunctionInput(make_input((2, 3, 5, 6, 7))),
871                    desc='tuple'),
872        ModuleInput(constructor_input=FunctionInput((3, None, 5)),
873                    forward_input=FunctionInput(make_input((2, 3, 5, 6, 7))),
874                    desc='tuple_none'),
875        ModuleInput(constructor_input=FunctionInput(3),
876                    forward_input=FunctionInput(make_input((2, 3, 12, 9, 3))),
877                    desc='single_nonatomic'),
878        ModuleInput(constructor_input=FunctionInput((3, 4, 5)),
879                    forward_input=FunctionInput(make_input((2, 3, 6, 4, 10))),
880                    desc='tuple_nonatomic')]
881
882
883def module_inputs_torch_nn_BatchNorm1d(module_info, device, dtype, requires_grad, training, **kwargs):
884    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
885
886    return [
887        ModuleInput(constructor_input=FunctionInput(10,),
888                    forward_input=FunctionInput(make_input((4, 10))),
889                    desc='affine'),
890        ModuleInput(constructor_input=FunctionInput(5,),
891                    forward_input=FunctionInput(make_input((4, 5, 3))),
892                    desc='3d_input'),
893        ModuleInput(constructor_input=FunctionInput(10, 1e-3, None),
894                    forward_input=FunctionInput(make_input((4, 10))),
895                    desc='affine_simple_average'),
896        ModuleInput(constructor_input=FunctionInput(10, 1e-3, 0.3, False),
897                    forward_input=FunctionInput(make_input((4, 10))),
898                    desc='not_affine'),
899        ModuleInput(constructor_input=FunctionInput(10, 1e-3, 0.3, True, False),
900                    forward_input=FunctionInput(make_input((4, 10))),
901                    desc='not_tracking_stats'),
902        ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False),
903                    forward_input=FunctionInput(make_input((4, 5, 3))),
904                    desc='3d_input_not_affine'),
905        ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False),
906                    forward_input=FunctionInput(make_input((0, 5, 9))),
907                    desc='zero_batch')]
908
909
910def module_inputs_torch_nn_BatchNorm2d(module_info, device, dtype, requires_grad, training, **kwargs):
911    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
912
913    return [
914        ModuleInput(constructor_input=FunctionInput(3,),
915                    forward_input=FunctionInput(make_input((2, 3, 6, 6)))),
916        ModuleInput(constructor_input=FunctionInput(3, 1e-3, None),
917                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
918                    desc='2d_simple_average'),
919        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.8),
920                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
921                    desc='momentum'),
922        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.8, False),
923                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
924                    desc='not_affine'),
925        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.8, True, False),
926                    forward_input=FunctionInput(make_input((2, 3, 6, 6))),
927                    desc='not_tracking_stats'),
928        ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False),
929                    forward_input=FunctionInput(make_input((0, 5, 2, 2))),
930                    desc='zero_batch')]
931
932
933def module_inputs_torch_nn_BatchNorm3d(module_info, device, dtype, requires_grad, training, **kwargs):
934    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
935
936    return [
937        ModuleInput(constructor_input=FunctionInput(3,),
938                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4)))),
939        ModuleInput(constructor_input=FunctionInput(3, 1e-3, None),
940                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
941                    desc='3d_simple_average'),
942        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.7),
943                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
944                    desc='momentum'),
945        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.7, False),
946                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
947                    desc='not_affine'),
948        ModuleInput(constructor_input=FunctionInput(3, 1e-3, 0.7, True, False),
949                    forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))),
950                    desc='not_tracking_stats'),
951        ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False),
952                    forward_input=FunctionInput(make_input((0, 5, 2, 2, 2))),
953                    desc='zero_batch')]
954
955
956def module_inputs_torch_nn_ConvNd(module_info, device, dtype, requires_grad, training, **kwargs):
957    N = kwargs['N']
958    lazy = kwargs.get('lazy', False)
959    transposed = kwargs.get('transposed', False)
960    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
961    conv_kwargs_list = [{}] if transposed else [{}, {'padding': 'same'}]
962    kernel_size, C_in, C_out = 3, 4, 5
963    input_no_batch_shape = (C_in,) + tuple(i + 3 for i in range(N))
964    input_batch_shape = (2,) + input_no_batch_shape
965    return [
966        ModuleInput(constructor_input=(FunctionInput(C_out, kernel_size, **conv_kwargs) if lazy else
967                                       FunctionInput(C_in, C_out, kernel_size, **conv_kwargs)),
968                    forward_input=FunctionInput(make_input(
969                        input_batch_shape if with_batch else input_no_batch_shape)),
970                    desc=('' if with_batch else 'no_batch_dim'),
971                    reference_fn=(None if with_batch else no_batch_dim_reference_fn))
972        for with_batch, conv_kwargs in itertools.product([True, False], conv_kwargs_list)
973    ]
974
975
976def module_inputs_torch_nn_CosineEmbeddingLoss(module_info, device, dtype, requires_grad, training, **kwargs):
977    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
978    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
979
980    cases: List[Tuple[str, dict]] = [
981        ('', {}),
982        ('reduction_sum', {'reduction': 'sum'}),
983        ('reduction_mean', {'reduction': 'mean'}),
984        ('reduction_none', {'reduction': 'none'}),
985        ('margin', {'margin': 0.7})
986    ]
987
988    module_inputs = []
989    for desc, constructor_kwargs in cases:
990        def reference_fn(m, p, i1, i2, t, constructor_kwargs=constructor_kwargs):
991            return cosineembeddingloss_reference(i1, i2, t, **constructor_kwargs)
992
993        module_inputs.append(
994            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
995                        forward_input=FunctionInput(make_input((15, 10)), make_input((15, 10)),
996                                                    make_target((15,)).sign()),
997                        desc=desc,
998                        reference_fn=reference_fn)
999        )
1000
1001    return module_inputs
1002
1003
1004def module_inputs_torch_nn_ELU(module_info, device, dtype, requires_grad, training, **kwargs):
1005    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1006
1007    return [
1008        ModuleInput(constructor_input=FunctionInput(alpha=2.),
1009                    forward_input=FunctionInput(make_input((3, 2, 5))),
1010                    reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2 * (i.exp() - 1))),
1011        ModuleInput(constructor_input=FunctionInput(alpha=2.),
1012                    forward_input=FunctionInput(make_input(())),
1013                    desc='scalar'),
1014        ModuleInput(constructor_input=FunctionInput(),
1015                    forward_input=FunctionInput(make_input((3,))),
1016                    desc='no_batch_dim',
1017                    reference_fn=no_batch_dim_reference_fn),
1018        ModuleInput(constructor_input=FunctionInput(alpha=2.),
1019                    forward_input=FunctionInput(make_input((2, 3, 2, 5))),
1020                    desc='4d_input')]
1021
1022
1023def module_inputs_torch_nn_CELU(module_info, device, dtype, requires_grad, training, **kwargs):
1024    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1025
1026    return [
1027        ModuleInput(constructor_input=FunctionInput(alpha=2.),
1028                    forward_input=FunctionInput(make_input((3, 2, 5))),
1029                    reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2. * ((.5 * i).exp() - 1))),
1030        ModuleInput(constructor_input=FunctionInput(alpha=2.),
1031                    forward_input=FunctionInput(make_input(())),
1032                    reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2. * ((.5 * i).exp() - 1)),
1033                    desc='scalar'),
1034        ModuleInput(constructor_input=FunctionInput(alpha=2.),
1035                    forward_input=FunctionInput(make_input((3,))),
1036                    desc='no_batch_dim',
1037                    reference_fn=no_batch_dim_reference_fn)]
1038
1039
1040def module_inputs_torch_nn_GLU(module_info, device, dtype, requires_grad, training, **kwargs):
1041    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1042
1043    return [
1044        ModuleInput(constructor_input=FunctionInput(),
1045                    forward_input=FunctionInput(make_input((5, 6)))),
1046        ModuleInput(constructor_input=FunctionInput(1),
1047                    forward_input=FunctionInput(make_input((5, 6, 7))),
1048                    desc='dim'),
1049        ModuleInput(constructor_input=FunctionInput(),
1050                    forward_input=FunctionInput(make_input((4,))),
1051                    desc='no_batch_dim',
1052                    reference_fn=no_batch_dim_reference_fn)]
1053
1054
1055def module_inputs_torch_nn_GELU(module_info, device, dtype, requires_grad, training, **kwargs):
1056    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1057
1058    return [
1059        ModuleInput(constructor_input=FunctionInput('none'),
1060                    forward_input=FunctionInput(make_input(())),
1061                    reference_fn=lambda m, p, x, *_: x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))),
1062                    desc='scalar'),
1063        ModuleInput(constructor_input=FunctionInput('none'),
1064                    forward_input=FunctionInput(make_input((3, 2, 5))),
1065                    reference_fn=lambda m, p, x, *_: x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))),
1066        ModuleInput(constructor_input=FunctionInput(),
1067                    forward_input=FunctionInput(make_input((3,))),
1068                    desc='no_batch_dim',
1069                    reference_fn=no_batch_dim_reference_fn)]
1070
1071
1072def module_inputs_torch_nn_ReLU(module_info, device, dtype, requires_grad, training, **kwargs):
1073    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1074
1075    return [
1076        ModuleInput(constructor_input=FunctionInput(),
1077                    forward_input=FunctionInput(make_input(())),
1078                    desc='scalar'),
1079        ModuleInput(constructor_input=FunctionInput(),
1080                    forward_input=FunctionInput(make_input(4)),
1081                    reference_fn=no_batch_dim_reference_fn,
1082                    desc='no_batch_dim'),
1083        ModuleInput(constructor_input=FunctionInput(),
1084                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
1085                    desc='channels_last_mem_format'),
1086        ModuleInput(constructor_input=FunctionInput(),
1087                    forward_input=FunctionInput(make_input((2, 3, 3, 4, 5))),
1088                    desc='channels_last_3d_mem_format')]
1089
1090
1091def module_inputs_torch_nn_ReLU6(module_info, device, dtype, requires_grad, training, **kwargs):
1092    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1093
1094    return [
1095        ModuleInput(constructor_input=FunctionInput(),
1096                    forward_input=FunctionInput(make_input(())),
1097                    desc='scalar'),
1098        ModuleInput(constructor_input=FunctionInput(),
1099                    forward_input=FunctionInput(make_input(4)),
1100                    reference_fn=no_batch_dim_reference_fn,
1101                    desc='no_batch_dim'),
1102        ModuleInput(constructor_input=FunctionInput(),
1103                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
1104                    desc='channels_last_mem_format'),
1105        ModuleInput(constructor_input=FunctionInput(),
1106                    forward_input=FunctionInput(make_input((2, 3, 3, 4, 5))),
1107                    desc='channels_last_3d_mem_format')]
1108
1109
1110def module_inputs_torch_nn_LeakyReLU(module_info, device, dtype, requires_grad, training, **kwargs):
1111    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1112
1113    return [
1114        ModuleInput(constructor_input=FunctionInput(),
1115                    forward_input=FunctionInput(make_input((3, 2, 5)))),
1116        ModuleInput(constructor_input=FunctionInput(),
1117                    forward_input=FunctionInput(make_input(4)),
1118                    reference_fn=no_batch_dim_reference_fn,
1119                    desc='no_batch_dim'),
1120        ModuleInput(constructor_input=FunctionInput(0.5),
1121                    forward_input=FunctionInput(make_input((3, 2, 5))),
1122                    desc='with_negval'),
1123        ModuleInput(constructor_input=FunctionInput(0.0),
1124                    forward_input=FunctionInput(make_input((10, 10))),
1125                    desc='with_zero_negval'),
1126        ModuleInput(constructor_input=FunctionInput(0.5),
1127                    forward_input=FunctionInput(make_input(())),
1128                    desc='with_negval_scalar')]
1129
1130
1131def module_inputs_torch_nn_PReLU(module_info, device, dtype, requires_grad, training, **kwargs):
1132    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1133
1134    return [
1135        ModuleInput(constructor_input=FunctionInput(),
1136                    forward_input=FunctionInput(make_input(())),
1137                    desc='scalar'),
1138        ModuleInput(constructor_input=FunctionInput(),
1139                    forward_input=FunctionInput(make_input(4)),
1140                    reference_fn=no_batch_dim_reference_fn,
1141                    desc='no_batch_dim'),
1142        ModuleInput(constructor_input=FunctionInput(),
1143                    forward_input=FunctionInput(make_input((2, 3, 4))),
1144                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
1145                    desc='1d'),
1146        ModuleInput(constructor_input=FunctionInput(3),
1147                    forward_input=FunctionInput(make_input((2, 3, 4))),
1148                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
1149                    desc='1d_multiparam'),
1150        ModuleInput(constructor_input=FunctionInput(),
1151                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
1152                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
1153                    desc='2d'),
1154        ModuleInput(constructor_input=FunctionInput(3),
1155                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
1156                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
1157                    desc='2d_multiparam'),
1158        ModuleInput(constructor_input=FunctionInput(),
1159                    forward_input=FunctionInput(make_input((2, 3, 4, 5, 6))),
1160                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
1161                    desc='3d'),
1162        ModuleInput(constructor_input=FunctionInput(3),
1163                    forward_input=FunctionInput(make_input((2, 3, 4, 5, 6))),
1164                    reference_fn=lambda m, p, i: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
1165                    desc='3d_multiparam')]
1166
1167
1168def module_inputs_torch_nn_SELU(module_info, device, dtype, requires_grad, training, **kwargs):
1169    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1170
1171    return [
1172        ModuleInput(constructor_input=FunctionInput(),
1173                    forward_input=FunctionInput(make_input((3, 2, 5)))),
1174        ModuleInput(constructor_input=FunctionInput(),
1175                    forward_input=FunctionInput(make_input(4)),
1176                    reference_fn=no_batch_dim_reference_fn,
1177                    desc='no_batch_dim'),
1178        ModuleInput(constructor_input=FunctionInput(),
1179                    forward_input=FunctionInput(make_input(())),
1180                    desc='scalar')]
1181
1182
1183def module_inputs_torch_nn_SiLU(module_info, device, dtype, requires_grad, training, **kwargs):
1184    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1185
1186    return [
1187        ModuleInput(constructor_input=FunctionInput(),
1188                    forward_input=FunctionInput(make_input(())),
1189                    reference_fn=lambda m, p, x, *_: x * torch.sigmoid(x),
1190                    desc='scalar'),
1191        ModuleInput(constructor_input=FunctionInput(),
1192                    forward_input=FunctionInput(make_input(4)),
1193                    reference_fn=no_batch_dim_reference_fn,
1194                    desc='no_batch_dim'),
1195        ModuleInput(constructor_input=FunctionInput(),
1196                    forward_input=FunctionInput(make_input((5, 6, 7))),
1197                    reference_fn=lambda m, p, x, *_: x * torch.sigmoid(x))]
1198
1199
1200def module_inputs_torch_nn_Softmax(module_info, device, dtype, requires_grad, training, **kwargs):
1201    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1202
1203    return [
1204        ModuleInput(constructor_input=FunctionInput(1),
1205                    forward_input=FunctionInput(make_input((10, 20))),
1206                    reference_fn=lambda m, p, i: torch.exp(i).div(torch.exp(i).sum(1, True).expand(10, 20))),
1207        ModuleInput(constructor_input=FunctionInput(0),
1208                    forward_input=FunctionInput(make_input(())),
1209                    reference_fn=lambda m, p, i: torch.exp(i).div(torch.exp(i).sum(0, True)),
1210                    desc='scalar'),
1211        ModuleInput(constructor_input=FunctionInput(-1),
1212                    forward_input=FunctionInput(make_input((4, 5))),
1213                    reference_fn=no_batch_dim_reference_fn,
1214                    desc='no_batch_dim')]
1215
1216
1217def module_inputs_torch_nn_Softmax2d(module_info, device, dtype, requires_grad, training, **kwargs):
1218    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1219
1220    return [
1221        ModuleInput(constructor_input=FunctionInput(),
1222                    forward_input=FunctionInput(make_input((1, 3, 10, 20))),
1223                    reference_fn=lambda m, p, i: torch.exp(i).div(torch.exp(i).sum(1, False))),
1224        ModuleInput(constructor_input=FunctionInput(),
1225                    forward_input=FunctionInput(make_input((3, 4, 5))),
1226                    reference_fn=no_batch_dim_reference_fn,
1227                    desc='no_batch_dim')]
1228
1229
1230def module_inputs_torch_nn_LogSoftmax(module_info, device, dtype, requires_grad, training, **kwargs):
1231    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1232
1233    return [
1234        ModuleInput(constructor_input=FunctionInput(1),
1235                    forward_input=FunctionInput(make_input((10, 20))),
1236                    reference_fn=lambda m, p, i: torch.exp(i).div_(torch.exp(i).sum(1, True).expand(10, 20)).log_()),
1237        ModuleInput(constructor_input=FunctionInput(1),
1238                    forward_input=FunctionInput(make_input((1, 3, 10, 20))),
1239                    reference_fn=lambda m, p, i: torch.exp(i).div_(torch.exp(i).sum(1, False)).log_(),
1240                    desc='multiparam'),
1241        ModuleInput(constructor_input=FunctionInput(0),
1242                    forward_input=FunctionInput(make_input(())),
1243                    reference_fn=lambda m, p, i: torch.exp(i).div_(torch.exp(i).sum(0, False)).log_(),
1244                    desc='multiparam_scalar'),
1245        ModuleInput(constructor_input=FunctionInput(-1),
1246                    forward_input=FunctionInput(make_input((4, 5))),
1247                    reference_fn=no_batch_dim_reference_fn,
1248                    desc='no_batch_dim')]
1249
1250
1251def module_inputs_torch_nn_Softmin(module_info, device, dtype, requires_grad, training, **kwargs):
1252    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1253
1254    return [
1255        ModuleInput(constructor_input=FunctionInput(1),
1256                    forward_input=FunctionInput(make_input((10, 20)))),
1257        ModuleInput(constructor_input=FunctionInput(1),
1258                    forward_input=FunctionInput(make_input((2, 3, 5, 10))),
1259                    desc='multidim'),
1260        ModuleInput(constructor_input=FunctionInput(0),
1261                    forward_input=FunctionInput(make_input(())),
1262                    desc='scalar'),
1263        ModuleInput(constructor_input=FunctionInput(-1),
1264                    forward_input=FunctionInput(make_input((3, 4, 10))),
1265                    reference_fn=no_batch_dim_reference_fn,
1266                    desc='no_batch_dim')]
1267
1268
1269def module_inputs_torch_nn_Softplus(module_info, device, dtype, requires_grad, training, **kwargs):
1270    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1271
1272    return [
1273        ModuleInput(constructor_input=FunctionInput(),
1274                    forward_input=FunctionInput(make_input((10, 20))),
1275                    reference_fn=lambda m, p, i: torch.log(1 + torch.exp(i))),
1276        ModuleInput(constructor_input=FunctionInput(2),
1277                    forward_input=FunctionInput(make_input((10, 20))),
1278                    reference_fn=lambda m, p, i: 1. / 2. * torch.log(1 + torch.exp(2 * i)),
1279                    desc='beta'),
1280        ModuleInput(constructor_input=FunctionInput(2, -100),
1281                    forward_input=FunctionInput(make_input((10, 20))),
1282                    reference_fn=(
1283                        lambda m, p, i: ((i * 2) > -100).type_as(i) * i
1284                        + ((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log(1 + torch.exp(2 * i))),
1285                    desc='beta_threshold'),
1286        ModuleInput(constructor_input=FunctionInput(2, -100),
1287                    forward_input=FunctionInput(make_input(())),
1288                    reference_fn=(
1289                        lambda m, p, i: ((i * 2) > -100).type_as(i) * i
1290                        + ((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log(1 + torch.exp(2 * i))),
1291                    desc='beta_threshold_scalar'),
1292        ModuleInput(constructor_input=FunctionInput(),
1293                    forward_input=FunctionInput(make_input(4)),
1294                    reference_fn=no_batch_dim_reference_fn,
1295                    desc='no_batch_dim')]
1296
1297
1298def module_inputs_torch_nn_Softshrink(module_info, device, dtype, requires_grad, training, **kwargs):
1299    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1300
1301    return [
1302        ModuleInput(constructor_input=FunctionInput(),
1303                    forward_input=FunctionInput(make_input((3, 2, 5)))),
1304        ModuleInput(constructor_input=FunctionInput(1,),
1305                    forward_input=FunctionInput(make_input((3, 2, 5))),
1306                    desc='lambda'),
1307        ModuleInput(constructor_input=FunctionInput(1,),
1308                    forward_input=FunctionInput(make_input(())),
1309                    desc='lambda_scalar'),
1310        ModuleInput(constructor_input=FunctionInput(),
1311                    forward_input=FunctionInput(make_input(4)),
1312                    reference_fn=no_batch_dim_reference_fn,
1313                    desc='no_batch_dim')]
1314
1315
1316def module_inputs_torch_nn_Softsign(module_info, device, dtype, requires_grad, training, **kwargs):
1317    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1318
1319    return [
1320        ModuleInput(constructor_input=FunctionInput(),
1321                    forward_input=FunctionInput(make_input((3, 2, 5))),
1322                    reference_fn=lambda m, p, i: i.div(1 + torch.abs(i))),
1323        ModuleInput(constructor_input=FunctionInput(),
1324                    forward_input=FunctionInput(make_input(())),
1325                    reference_fn=lambda m, p, i: i.div(1 + torch.abs(i)),
1326                    desc='scalar'),
1327        ModuleInput(constructor_input=FunctionInput(),
1328                    forward_input=FunctionInput(make_input(4)),
1329                    reference_fn=no_batch_dim_reference_fn,
1330                    desc='no_batch_dim')]
1331
1332
1333def module_inputs_torch_nn_Tanh(module_info, device, dtype, requires_grad, training, **kwargs):
1334    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1335
1336    return [
1337        ModuleInput(constructor_input=FunctionInput(),
1338                    forward_input=FunctionInput(make_input((2, 3, 4, 5)))),
1339        ModuleInput(constructor_input=FunctionInput(),
1340                    forward_input=FunctionInput(make_input(())),
1341                    desc='scalar'),
1342        ModuleInput(constructor_input=FunctionInput(),
1343                    forward_input=FunctionInput(make_input(4)),
1344                    reference_fn=no_batch_dim_reference_fn,
1345                    desc='no_batch_dim')]
1346
1347
1348
1349def module_inputs_torch_nn_Tanhshrink(module_info, device, dtype, requires_grad, training, **kwargs):
1350    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1351
1352    return [
1353        ModuleInput(constructor_input=FunctionInput(),
1354                    forward_input=FunctionInput(make_input((2, 3, 4, 5)))),
1355        ModuleInput(constructor_input=FunctionInput(),
1356                    forward_input=FunctionInput(make_input(())),
1357                    desc='scalar'),
1358        ModuleInput(constructor_input=FunctionInput(),
1359                    forward_input=FunctionInput(make_input(4)),
1360                    reference_fn=no_batch_dim_reference_fn,
1361                    desc='no_batch_dim')]
1362
1363
1364def module_inputs_torch_nn_Threshold(module_info, device, dtype, requires_grad, training, **kwargs):
1365    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1366
1367    return [
1368        ModuleInput(constructor_input=FunctionInput(2., 1.),
1369                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
1370                    desc='threshold_value'),
1371        ModuleInput(constructor_input=FunctionInput(2., 10.),
1372                    forward_input=FunctionInput(make_input((2, 3, 4, 5))),
1373                    desc='large_value'),
1374        ModuleInput(constructor_input=FunctionInput(2., 1.),
1375                    forward_input=FunctionInput(make_input(())),
1376                    desc='threshold_value_scalar'),
1377        ModuleInput(constructor_input=FunctionInput(2., 1.),
1378                    forward_input=FunctionInput(make_input(4)),
1379                    reference_fn=no_batch_dim_reference_fn,
1380                    desc='no_batch_dim')]
1381
1382
1383def module_inputs_torch_nn_Mish(module_info, device, dtype, requires_grad, training, **kwargs):
1384    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1385
1386    return [
1387        ModuleInput(constructor_input=FunctionInput(),
1388                    forward_input=FunctionInput(make_input((5, 6, 7))),
1389                    reference_fn=lambda m, p, i: i * torch.tanh(F.softplus(i))),
1390        ModuleInput(constructor_input=FunctionInput(),
1391                    forward_input=FunctionInput(make_input(())),
1392                    reference_fn=lambda m, p, i: i * torch.tanh(F.softplus(i)),
1393                    desc='scalar'),
1394        ModuleInput(constructor_input=FunctionInput(),
1395                    forward_input=FunctionInput(make_input(4)),
1396                    reference_fn=no_batch_dim_reference_fn,
1397                    desc='no_batch_dim')]
1398
1399
1400def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, training, **kwargs):
1401    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1402
1403    return [
1404        ModuleInput(constructor_input=FunctionInput(),
1405                    forward_input=FunctionInput(make_input((2, 3, 4)),
1406                                                make_input((2, 3, 4))),
1407                    reference_fn=lambda m, p, i, t: 1. / i.numel() * sum((a - b).abs().sum()
1408                                                                         for a, b in zip(i, t))),
1409        ModuleInput(constructor_input=FunctionInput(),
1410                    forward_input=FunctionInput(make_input(()), make_input(())),
1411                    reference_fn=lambda m, p, i, t: 1. / i.numel() * (i - t).abs().sum(),
1412                    desc='scalar')] + generate_regression_criterion_inputs(make_input)
1413
1414
1415def module_inputs_torch_nn_SmoothL1Loss(module_info, device, dtype, requires_grad, training, **kwargs):
1416    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1417
1418
1419    cases: List[Tuple[str, dict]] = [
1420        ('', {}),
1421        ('reduction_sum', {'reduction': 'sum'}),
1422        ('reduction_mean', {'reduction': 'mean'}),
1423        ('reduction_none', {'reduction': 'none'}),
1424    ]
1425
1426    module_inputs = []
1427    for desc, constructor_kwargs in cases:
1428        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
1429            return smoothl1loss_reference(i, t, **constructor_kwargs)
1430
1431        module_inputs.append(
1432            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
1433                        forward_input=FunctionInput(make_input((5, 10)),
1434                                                    make_input((5, 10))),
1435                        desc=desc,
1436                        reference_fn=reference_fn)
1437        )
1438        module_inputs.append(
1439            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
1440                        forward_input=FunctionInput(make_input(()),
1441                                                    make_input(())),
1442                        desc=f'scalar_{desc}',
1443                        reference_fn=reference_fn)
1444        )
1445
1446    return module_inputs
1447
1448
1449
1450def module_inputs_torch_nn_BCELoss(module_info, device, dtype, requires_grad, training, **kwargs):
1451    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1452    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
1453    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
1454
1455    cases: List[Tuple[str, dict]] = [
1456        ('', {}),
1457        ('reduction_sum', {'reduction': 'sum'}),
1458        ('reduction_mean', {'reduction': 'mean'}),
1459        ('reduction_none', {'reduction': 'none'}),
1460        ('weights', {'weight': make_weight((10,))}),
1461    ]
1462
1463    def bce_loss_reference_fn(m, p, i, t, reduction='mean', weight=None):
1464        result = -(t * i.log() + (1 - t) * (1 - i).log())
1465
1466        if weight is not None:
1467            result = result * weight
1468
1469        if reduction == 'none':
1470            return result
1471        elif reduction == 'mean':
1472            return result.sum() / i.numel()
1473        else:
1474            return result.sum()
1475
1476    module_inputs = []
1477    for desc, constructor_kwargs in cases:
1478        module_inputs.append(
1479            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
1480                        forward_input=FunctionInput(make_input((15, 10), low=1e-2, high=1 - 1e-2),
1481                                                    make_target((15, 10)).gt(0).to(dtype)),
1482                        desc=desc,
1483                        reference_fn=partial(bce_loss_reference_fn, **constructor_kwargs))
1484        )
1485
1486    scalar_weight = make_weight(())
1487    module_inputs.append(
1488        ModuleInput(constructor_input=FunctionInput(weight=scalar_weight),
1489                    forward_input=FunctionInput(make_input((), low=1e-2, high=1 - 1e-2),
1490                                                make_target(()).gt(0).to(dtype)),
1491                    desc='scalar_weight',
1492                    reference_fn=partial(bce_loss_reference_fn, weight=scalar_weight))
1493    )
1494
1495    return module_inputs
1496
1497
1498def module_inputs_torch_nn_BCEWithLogitsLoss(module_info, device, dtype, requires_grad, training, **kwargs):
1499    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1500    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
1501    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
1502
1503    cases: List[Tuple[str, dict]] = [
1504        ('', {}),
1505        ('reduction_sum', {'reduction': 'sum'}),
1506        ('reduction_mean', {'reduction': 'mean'}),
1507        ('reduction_none', {'reduction': 'none'}),
1508        ('weights', {'weight': make_weight((10,))}),
1509        ('scalar_weights', {'weight': make_weight(())})
1510    ]
1511
1512    def bce_withlogitsloss_reference_fn(m, p, i, t, reduction='mean', weight=None):
1513        # TODO: add pos_weight to the definition here and corresponding SampleInputs
1514        max_val = (-i).clamp(min=0)
1515        result = (1 - t).mul_(i).add_(max_val).add_((-max_val).exp_().add_((-i - max_val).exp_()).log_())
1516
1517        if weight is not None:
1518            result = result * weight
1519
1520        if reduction == 'none':
1521            return result
1522        elif reduction == 'mean':
1523            return result.sum() / i.numel()
1524        else:
1525            return result.sum()
1526
1527    module_inputs = []
1528    for desc, constructor_kwargs in cases:
1529        module_inputs.append(
1530            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
1531                        forward_input=FunctionInput(make_input((15, 10), low=1e-2, high=1 - 1e-2),
1532                                                    make_target((15, 10)).gt(0).to(dtype)),
1533                        desc=desc,
1534                        reference_fn=partial(bce_withlogitsloss_reference_fn, **constructor_kwargs))
1535        )
1536
1537    return module_inputs
1538
1539
1540def module_inputs_torch_nn_CrossEntropyLoss(module_info, device, dtype, requires_grad, training, **kwargs):
1541    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1542    make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
1543    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
1544
1545    reductions: List[str] = ['mean', 'sum', 'none']
1546    cases: List[Tuple[str, dict]] = [
1547        ('', {}),
1548        ('weights', {'weight': make_weight((3,))}),
1549        ('ignore_index', {'ignore_index': 1}),
1550        ('label_smoothing', {'label_smoothing': 0.15}),
1551        ('ignore_index_label_smoothing', {'ignore_index': 1, 'label_smoothing': 0.15})
1552    ]
1553
1554    module_inputs = []
1555    for reduction, (desc, constructor_kwargs) in product(reductions, cases):
1556        def reference_fn(m, p, i, t, reduction=reduction, constructor_kwargs=constructor_kwargs):
1557            return cross_entropy_loss_reference(i, t, reduction=reduction, **constructor_kwargs)
1558
1559        module_inputs.append(
1560            ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
1561                        forward_input=FunctionInput(make_input((2, 3, 5, 5)),
1562                                                    make_target((2, 5, 5), low=0, high=3)),
1563                        desc=f"4d_{desc}_{reduction}",
1564                        reference_fn=reference_fn)
1565        )
1566        module_inputs.append(
1567            ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
1568                        forward_input=FunctionInput(make_input((2, 3, 5)),
1569                                                    make_target((2, 5), low=0, high=3)),
1570                        desc=f"3d_{desc}_{reduction}",
1571                        reference_fn=reference_fn)
1572        )
1573        module_inputs.append(
1574            ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
1575                        forward_input=FunctionInput(make_input((2, 3)),
1576                                                    make_target((2), low=0, high=3)),
1577                        desc=f"2d_{desc}_{reduction}",
1578                        reference_fn=reference_fn)
1579        )
1580        module_inputs.append(
1581            ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
1582                        forward_input=FunctionInput(make_input((2, 3, 5, 5, 2, 2)),
1583                                                    make_target((2, 5, 5, 2, 2), low=0, high=3)),
1584                        desc=f"higher_dim_{desc}_{reduction}",
1585                        reference_fn=reference_fn)
1586        )
1587
1588        if constructor_kwargs.get('ignore_index', None) is None:
1589            module_inputs.append(
1590                ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
1591                            forward_input=FunctionInput(make_input((5, 3, 4, 2)),
1592                                                        make_input((5, 3, 4, 2)).softmax(dim=1)),
1593                            desc=f"4d_prob_target_{desc}_{reduction}",
1594                            reference_fn=reference_fn)
1595            )
1596            module_inputs.append(
1597                ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
1598                            forward_input=FunctionInput(make_input((5, 3, 4)),
1599                                                        make_input((5, 3, 4)).softmax(dim=1)),
1600                            desc=f"3d_prob_target_{desc}_{reduction}",
1601                            reference_fn=reference_fn)
1602            )
1603            module_inputs.append(
1604                ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
1605                            forward_input=FunctionInput(make_input((5, 3)),
1606                                                        make_input((5, 3)).softmax(dim=1)),
1607                            desc=f"2d_prob_target_{desc}_{reduction}",
1608                            reference_fn=reference_fn)
1609            )
1610            module_inputs.append(
1611                ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
1612                            forward_input=FunctionInput(make_input((2, 3, 5, 5, 2, 2)),
1613                                                        make_input((2, 3, 5, 5, 2, 2)).softmax(dim=1)),
1614                            desc=f"higher_dim_prob_target_{desc}_{reduction}",
1615                            reference_fn=reference_fn)
1616            )
1617            module_inputs.append(
1618                ModuleInput(constructor_input=FunctionInput(reduction=reduction, **constructor_kwargs),
1619                            forward_input=FunctionInput(make_input((3,)),
1620                                                        make_target((), low=0, high=3)),
1621                            desc=f"no_batch_dim_{desc}_{reduction}",
1622                            reference_fn=partial(no_batch_dim_reference_fn, is_criterion=True))
1623            )
1624
1625    return module_inputs
1626
1627
1628
1629def module_inputs_torch_nn_CTCLoss(module_info, device, dtype, requires_grad, training, **kwargs):
1630    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1631    make_target = partial(make_tensor, device=device, requires_grad=False)
1632
1633    cases: List[Tuple[str, dict]] = [
1634        ('', {}),
1635        ('reduction_sum', {'reduction': 'sum'}),
1636        ('reduction_mean', {'reduction': 'mean'}),
1637        ('reduction_none', {'reduction': 'none'}),
1638        ('blank', {'blank': 14})
1639    ]
1640    target_dtypes = [torch.int, torch.long]
1641
1642    module_inputs = []
1643    for target_dtype, (desc, constructor_kwargs) in product(target_dtypes, cases):
1644        def reference_fn(m, p, i, t, il, tl, constructor_kwargs=constructor_kwargs):
1645            return ctcloss_reference(i, t, il, tl, **constructor_kwargs)
1646
1647        blank = constructor_kwargs.get('blank', 0)
1648        low = 0 if blank == 14 else 1
1649        high = 14 if blank == 14 else 15
1650
1651        module_inputs.append(
1652            ModuleInput(
1653                constructor_input=FunctionInput(**constructor_kwargs),
1654                forward_input=FunctionInput(make_input((50, 3, 15)).log_softmax(2),
1655                                            make_target((3, 30), dtype=target_dtype, low=low, high=high),
1656                                            (50, 50, 50), (30, 25, 20)),
1657                desc=f'{desc}_lengths_intlists',
1658                reference_fn=reference_fn)
1659        )
1660        module_inputs.append(
1661            ModuleInput(
1662                constructor_input=FunctionInput(**constructor_kwargs),
1663                forward_input=FunctionInput(make_input((50, 3, 15)).log_softmax(2),
1664                                            make_target((3, 30), dtype=target_dtype, low=low, high=high),
1665                                            torch.tensor((50, 50, 50), device=device),
1666                                            torch.tensor((30, 25, 20), device=device)),
1667                desc=f'{desc}_lengths_tensors',
1668                reference_fn=reference_fn)
1669        )
1670        module_inputs.append(
1671            ModuleInput(
1672                constructor_input=FunctionInput(**constructor_kwargs),
1673                forward_input=FunctionInput(make_input((50, 3, 15)).log_softmax(2),
1674                                            make_target((30 + 25 + 20,), dtype=target_dtype, low=low, high=high),
1675                                            (50, 50, 50), (30, 25, 20)),
1676                desc=f'{desc}_1d_target_lengths_intlists',
1677                reference_fn=reference_fn)
1678        )
1679        module_inputs.append(
1680            ModuleInput(
1681                constructor_input=FunctionInput(**constructor_kwargs),
1682                forward_input=FunctionInput(make_input((50, 3, 15)).log_softmax(2),
1683                                            make_target((30 + 25 + 20,), dtype=target_dtype, low=low, high=high),
1684                                            torch.tensor((50, 50, 50), device=device),
1685                                            torch.tensor((30, 25, 20), device=device)),
1686                desc=f'{desc}_1d_target_lengths_tensors',
1687                reference_fn=reference_fn)
1688        )
1689
1690    return module_inputs
1691
1692
1693def module_inputs_torch_nn_GroupNorm(module_info, device, dtype, requires_grad, training, **kwargs):
1694    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1695
1696    return [
1697        ModuleInput(
1698            constructor_input=FunctionInput(3, 6, 1e-3),
1699            forward_input=FunctionInput(make_input((4, 6, 5))),
1700            desc='1d_affine'),
1701        ModuleInput(
1702            constructor_input=FunctionInput(3, 12, 1e-3),
1703            forward_input=FunctionInput(make_input((4, 12))),
1704            desc='1d_affine_GN'),
1705        ModuleInput(
1706            constructor_input=FunctionInput(1, 6, 1e-3),
1707            forward_input=FunctionInput(make_input((150, 6))),
1708            desc='1d_affine_large_batch'),
1709        ModuleInput(
1710            constructor_input=FunctionInput(5, 5, 1e-3, False),
1711            forward_input=FunctionInput(make_input((4, 5, 5))),
1712            desc='1d_no_affine_IN'),
1713        ModuleInput(
1714            constructor_input=FunctionInput(1, 10, 1e-3, False),
1715            forward_input=FunctionInput(make_input((4, 10))),
1716            desc='1d_no_affine_LN'),
1717        ModuleInput(
1718            constructor_input=FunctionInput(3, 6, 1e-3),
1719            forward_input=FunctionInput(make_input((4, 6, 2, 3))),
1720            desc='2d_affine'),
1721        ModuleInput(
1722            constructor_input=FunctionInput(3, 3, 1e-3, False),
1723            forward_input=FunctionInput(make_input((4, 3, 2, 3))),
1724            desc='2d_no_affine_IN'),
1725        ModuleInput(
1726            constructor_input=FunctionInput(1, 3, 1e-3, False),
1727            forward_input=FunctionInput(make_input((4, 3, 2, 3))),
1728            desc='2d_no_affine_LN'),
1729    ]
1730
1731
1732def module_inputs_torch_nn_Hardshrink(module_info, device, dtype, requires_grad, training, **kwargs):
1733    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1734
1735    return [
1736        ModuleInput(
1737            constructor_input=FunctionInput(2.),
1738            forward_input=FunctionInput(make_input((4, 3, 2, 4))),
1739        ),
1740        ModuleInput(
1741            constructor_input=FunctionInput(2.),
1742            forward_input=FunctionInput(make_input(())),
1743            desc='scalar',
1744        ),
1745        ModuleInput(
1746            constructor_input=FunctionInput(),
1747            forward_input=FunctionInput(make_input(4)),
1748            reference_fn=no_batch_dim_reference_fn,
1749            desc='no_batch_dim',
1750        )
1751    ]
1752
1753
1754def module_inputs_torch_nn_Hardswish(module_info, device, dtype, requires_grad, training, **kwargs):
1755    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1756
1757    return [
1758        ModuleInput(
1759            constructor_input=FunctionInput(),
1760            forward_input=FunctionInput(make_input(4)),
1761            reference_fn=no_batch_dim_reference_fn,
1762            desc='no_batch_dim',
1763        ),
1764        ModuleInput(
1765            constructor_input=FunctionInput(),
1766            forward_input=FunctionInput(make_input((2, 3, 2, 5))),
1767            desc='4d_input')
1768    ]
1769
1770
1771def module_inputs_torch_nn_Hardtanh(module_info, device, dtype, requires_grad, training, **kwargs):
1772    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1773
1774    return [
1775        ModuleInput(
1776            constructor_input=FunctionInput(),
1777            forward_input=FunctionInput(make_input((3, 2, 5))),
1778            reference_fn=lambda m, p, i: i.clamp(-1, 1),
1779        ),
1780        ModuleInput(
1781            constructor_input=FunctionInput(),
1782            forward_input=FunctionInput(make_input(())),
1783            reference_fn=lambda m, p, i: i.clamp(-1, 1),
1784            desc='scalar',
1785        ),
1786        ModuleInput(
1787            constructor_input=FunctionInput(),
1788            forward_input=FunctionInput(make_input(4)),
1789            reference_fn=no_batch_dim_reference_fn,
1790            desc='no_batch_dim',
1791        )
1792    ]
1793
1794
1795def module_inputs_torch_nn_HingeEmbeddingLoss(module_info, device, dtype, requires_grad, training, **kwargs):
1796    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1797    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
1798
1799    cases: List[Tuple[str, dict]] = [
1800        ('', {}),
1801        ('reduction_sum', {'reduction': 'sum'}),
1802        ('reduction_mean', {'reduction': 'mean'}),
1803        ('reduction_none', {'reduction': 'none'}),
1804        ('margin', {'margin': 0.5})
1805    ]
1806
1807    module_inputs = []
1808    for desc, constructor_kwargs in cases:
1809        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
1810            return hingeembeddingloss_reference(i, t, **constructor_kwargs)
1811
1812        module_inputs.append(
1813            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
1814                        forward_input=FunctionInput(make_input((10,)),
1815                                                    make_target((10,)).gt(0).to(dtype).mul_(2).sub_(1)),
1816                        desc=desc,
1817                        reference_fn=reference_fn)
1818        )
1819        module_inputs.append(
1820            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
1821                        forward_input=FunctionInput(make_input(()),
1822                                                    make_target(()).gt(0).to(dtype).mul_(2).sub_(1)),
1823                        desc=f'scalar_{desc}',
1824                        reference_fn=reference_fn)
1825        )
1826
1827    return module_inputs
1828
1829
1830def module_inputs_torch_nn_HuberLoss(module_info, device, dtype, requires_grad, training, **kwargs):
1831    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1832
1833    cases: List[Tuple[str, dict]] = [
1834        ('', {}),
1835        ('reduction_sum', {'reduction': 'sum'}),
1836        ('reduction_mean', {'reduction': 'mean'}),
1837        ('reduction_none', {'reduction': 'none'}),
1838    ]
1839
1840    module_inputs = []
1841    for desc, constructor_kwargs in cases:
1842        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
1843            return huberloss_reference(i, t, **constructor_kwargs)
1844
1845        module_inputs.append(
1846            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
1847                        forward_input=FunctionInput(make_input((5, 10)),
1848                                                    make_input((5, 10))),
1849                        desc=desc,
1850                        reference_fn=reference_fn)
1851        )
1852
1853    return module_inputs
1854
1855
1856def module_inputs_torch_nn_InstanceNormNd(module_info, device, dtype, requires_grad, training, **kwargs):
1857    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1858    lazy = kwargs.get('lazy', False)
1859    N = kwargs['N']
1860    num_features, eps, momentum, affine, track_running_stats = 3, 1e-3, 0.3, False, True
1861    input_no_batch_shape_dict = {1: (3, 15), 2: (3, 6, 6), 3: (3, 4, 4, 4)}
1862    input_no_batch_shape = input_no_batch_shape_dict[N]
1863    input_batch_shape = (4,) + input_no_batch_shape
1864
1865    return [
1866        ModuleInput(
1867            constructor_input=(
1868                FunctionInput(eps, momentum) if lazy else FunctionInput(num_features, eps, momentum)
1869            ),
1870            forward_input=FunctionInput(make_input(input_batch_shape))),
1871        ModuleInput(
1872            constructor_input=(
1873                FunctionInput(eps, momentum, affine, track_running_stats) if lazy else
1874                FunctionInput(num_features, eps, momentum, affine, track_running_stats)
1875            ),
1876            forward_input=FunctionInput(make_input(input_batch_shape)),
1877            desc='tracking_stats'),
1878        ModuleInput(
1879            constructor_input=(
1880                FunctionInput(eps, momentum) if lazy else FunctionInput(num_features, eps, momentum)
1881            ),
1882            forward_input=FunctionInput(make_input(input_no_batch_shape)),
1883            reference_fn=no_batch_dim_reference_fn,
1884            desc='tracking_stats_no_batch_dim'),
1885        ModuleInput(
1886            constructor_input=(
1887                FunctionInput(eps, momentum, affine, track_running_stats) if lazy else
1888                FunctionInput(num_features, eps, momentum, affine, track_running_stats)
1889            ),
1890            forward_input=FunctionInput(make_input(input_no_batch_shape)),
1891            reference_fn=no_batch_dim_reference_fn,
1892            desc='no_batch_dim')
1893    ]
1894
1895def module_inputs_torch_nn_LayerNorm(module_info, device, dtype, requires_grad, training, **kwargs):
1896    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1897
1898    return [
1899        ModuleInput(
1900            constructor_input=FunctionInput([5], 1e-3),
1901            forward_input=FunctionInput(make_input((4, 5, 5))),
1902            desc='1d_elementwise_affine'),
1903        ModuleInput(
1904            constructor_input=FunctionInput([5], 1e-3),
1905            forward_input=FunctionInput(make_input((128, 5, 5))),
1906            desc='1d_elementwise_affine_large_batch'),
1907        ModuleInput(
1908            constructor_input=FunctionInput([5], 1e-3, False),
1909            forward_input=FunctionInput(make_input((4, 5, 5))),
1910            desc='1d_no_elementwise_affine'),
1911        ModuleInput(
1912            constructor_input=FunctionInput([2, 2, 5], 1e-3),
1913            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
1914            desc='3d_elementwise_affine'),
1915        ModuleInput(
1916            constructor_input=FunctionInput([2, 2, 5], 1e-3, False),
1917            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
1918            desc='3d_no_elementwise_affine'),
1919        ModuleInput(
1920            constructor_input=FunctionInput([5], 1e-3),
1921            forward_input=FunctionInput(make_input((0, 5))),
1922            desc='1d_empty_elementwise_affine'),
1923        ModuleInput(
1924            constructor_input=FunctionInput([2, 2, 5], 1e-3, elementwise_affine=True, bias=False),
1925            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
1926            desc='3d_elementwise_affine_no_bias'),
1927    ]
1928
1929def module_inputs_torch_nn_RMSNorm(module_info, device, dtype, requires_grad, training, **kwargs):
1930    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1931
1932    def rms_norm_reference_fn(m, p, i):
1933        eps = m.eps
1934        if eps is None:
1935            eps = torch.finfo(i.dtype).eps
1936        ndim = i.ndim
1937        normalized_shape = m.normalized_shape
1938        weight = m.weight
1939        dims = [ndim - i - 1 for i in range(len(normalized_shape))]
1940        result = i * torch.rsqrt(i.pow(2).mean(dim=dims, keepdim=True) + m.eps)
1941        if weight is not None:
1942            result *= weight
1943        return result
1944
1945    return [
1946        ModuleInput(
1947            constructor_input=FunctionInput([5], 1e-3),
1948            forward_input=FunctionInput(make_input((4, 5, 5))),
1949            desc='1d_elementwise_affine',
1950            reference_fn=rms_norm_reference_fn),
1951        ModuleInput(
1952            constructor_input=FunctionInput([5], 1e-3),
1953            forward_input=FunctionInput(make_input((128, 5, 5))),
1954            desc='1d_elementwise_affine_large_batch',
1955            reference_fn=rms_norm_reference_fn),
1956        ModuleInput(
1957            constructor_input=FunctionInput([5], 1e-3, False),
1958            forward_input=FunctionInput(make_input((4, 5, 5))),
1959            desc='1d_no_elementwise_affine',
1960            reference_fn=rms_norm_reference_fn),
1961        ModuleInput(
1962            constructor_input=FunctionInput([2, 2, 5], 1e-3),
1963            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
1964            desc='3d_elementwise_affine',
1965            reference_fn=rms_norm_reference_fn),
1966        ModuleInput(
1967            constructor_input=FunctionInput([2, 2, 5], 1e-3, False),
1968            forward_input=FunctionInput(make_input((4, 2, 2, 5))),
1969            desc='3d_no_elementwise_affine',
1970            reference_fn=rms_norm_reference_fn),
1971        ModuleInput(
1972            constructor_input=FunctionInput([5], 1e-3),
1973            forward_input=FunctionInput(make_input((0, 5))),
1974            desc='1d_empty_elementwise_affine',
1975            reference_fn=rms_norm_reference_fn),
1976    ]
1977
1978
1979def module_inputs_torch_nn_LocalResponseNorm(module_info, device, dtype, requires_grad, training, **kwargs):
1980    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1981
1982    return [
1983        ModuleInput(
1984            constructor_input=FunctionInput(3,),
1985            forward_input=FunctionInput(make_input((1, 5, 7))),
1986            desc='1d'),
1987        ModuleInput(
1988            constructor_input=FunctionInput(2,),
1989            forward_input=FunctionInput(make_input((1, 5, 7, 7))),
1990            desc='2d_uneven_pad'),
1991        ModuleInput(
1992            constructor_input=FunctionInput(1, 1., 0.5, 2.),
1993            forward_input=FunctionInput(make_input((1, 5, 7, 7, 7))),
1994            desc='3d_custom_params'),
1995    ]
1996
1997
1998def module_inputs_torch_nn_LPPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
1999    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2000
2001    return [
2002        ModuleInput(
2003            constructor_input=FunctionInput(1.5, 2),
2004            forward_input=FunctionInput(make_input((1, 3, 7))),
2005            desc='norm'),
2006        ModuleInput(
2007            constructor_input=FunctionInput(2, 2, 3),
2008            forward_input=FunctionInput(make_input((1, 3, 7)))),
2009        ModuleInput(
2010            constructor_input=FunctionInput(2, 2, 3),
2011            forward_input=FunctionInput(make_input((3, 7))),
2012            reference_fn=no_batch_dim_reference_fn,
2013            desc='no_batch_dim'),
2014    ]
2015
2016
2017
2018def module_inputs_torch_nn_LPPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
2019    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2020
2021    return [
2022        ModuleInput(
2023            constructor_input=FunctionInput(2, 2, 2),
2024            forward_input=FunctionInput(make_input((1, 3, 7, 7)))),
2025        ModuleInput(
2026            constructor_input=FunctionInput(2, 2, 2),
2027            forward_input=FunctionInput(make_input((3, 7, 7))),
2028            reference_fn=no_batch_dim_reference_fn,
2029            desc='no_batch_dim'),
2030        ModuleInput(
2031            constructor_input=FunctionInput(1.5, 2),
2032            forward_input=FunctionInput(make_input((1, 3, 7, 7))),
2033            desc='norm'),
2034    ]
2035
2036
2037def module_inputs_torch_nn_LPPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
2038    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2039
2040    return [
2041        ModuleInput(
2042            constructor_input=FunctionInput(2, 2, 2),
2043            forward_input=FunctionInput(make_input((1, 3, 7, 7, 7)))),
2044        ModuleInput(
2045            constructor_input=FunctionInput(2, 2, 2),
2046            forward_input=FunctionInput(make_input((3, 7, 7, 7))),
2047            reference_fn=no_batch_dim_reference_fn,
2048            desc='no_batch_dim'),
2049        ModuleInput(
2050            constructor_input=FunctionInput(1.5, 2),
2051            forward_input=FunctionInput(make_input((1, 3, 7, 7, 7))),
2052            desc='norm'),
2053    ]
2054
2055
2056def module_inputs_torch_nn_MaxPool1d(module_info, device, dtype, requires_grad, training, **kwargs):
2057    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2058
2059    return [
2060        ModuleInput(
2061            constructor_input=FunctionInput(4),
2062            forward_input=FunctionInput(make_input((2, 10, 4))),
2063            desc='3d_input'),
2064        ModuleInput(
2065            constructor_input=FunctionInput(4, 4),
2066            forward_input=FunctionInput(make_input((2, 10, 4))),
2067            desc='stride'),
2068        ModuleInput(
2069            constructor_input=FunctionInput(4, return_indices=True),
2070            forward_input=FunctionInput(make_input((2, 10, 4))),
2071            desc='return_indices'),
2072    ]
2073
2074
2075def module_inputs_torch_nn_MaxPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
2076    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2077
2078    return [
2079        ModuleInput(
2080            constructor_input=FunctionInput((3, 3), (2, 2), (1, 1)),
2081            forward_input=FunctionInput(make_input((3, 7, 7))),
2082            desc='3d_input'),
2083        ModuleInput(
2084            constructor_input=FunctionInput((3, 3), (2, 2), (1, 1)),
2085            forward_input=FunctionInput(make_input((1, 3, 7, 7))),
2086            desc='4d_input'),
2087        ModuleInput(
2088            constructor_input=FunctionInput((3, 3), (2, 2), (1, 1), return_indices=True),
2089            forward_input=FunctionInput(make_input((1, 3, 7, 7))),
2090            desc='return_indices'),
2091    ]
2092
2093def module_inputs_torch_nn_MaxPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
2094    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2095
2096    return [
2097        ModuleInput(
2098            constructor_input=FunctionInput((2, 2, 2)),
2099            forward_input=FunctionInput(make_input((2, 3, 5, 5, 5)))),
2100        ModuleInput(
2101            constructor_input=FunctionInput(2, (2, 2, 2)),
2102            forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
2103            desc='stride'),
2104        ModuleInput(
2105            constructor_input=FunctionInput(2, 2, (1, 1, 1)),
2106            forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
2107            desc='stride_padding'),
2108        ModuleInput(
2109            constructor_input=FunctionInput(2, 2, (1, 1, 1), return_indices=True),
2110            forward_input=FunctionInput(make_input((2, 3, 5, 5, 5))),
2111            desc='return_indices'),
2112    ]
2113
2114
2115def module_inputs_torch_nn_FractionalMaxPool2d(module_info, device, dtype, requires_grad, training, **kwargs):
2116    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2117
2118    def make_random_samples():
2119        return torch.empty((1, 3, 2), dtype=torch.double, device=device).uniform_()
2120
2121    return [
2122        ModuleInput(
2123            constructor_input=FunctionInput(2, output_ratio=0.5, _random_samples=make_random_samples()),
2124            forward_input=FunctionInput(make_input((1, 3, 5, 7))),
2125            desc='ratio'),
2126        ModuleInput(
2127            constructor_input=FunctionInput((2, 3), output_size=(4, 3), _random_samples=make_random_samples()),
2128            forward_input=FunctionInput(make_input((1, 3, 7, 6))),
2129            desc='size'),
2130        ModuleInput(
2131            constructor_input=FunctionInput(
2132                2, output_ratio=0.5, _random_samples=make_random_samples(), return_indices=True
2133            ),
2134            forward_input=FunctionInput(make_input((1, 3, 5, 7))),
2135            desc='ratio_return_indices'),
2136        ModuleInput(
2137            constructor_input=FunctionInput(2, output_ratio=0.5, _random_samples=make_random_samples()),
2138            forward_input=FunctionInput(make_input((3, 5, 7))),
2139            reference_fn=no_batch_dim_reference_fn,
2140            desc='ratio_no_batch_dim'),
2141        ModuleInput(
2142            constructor_input=FunctionInput((2, 3), output_size=(4, 3), _random_samples=make_random_samples()),
2143            forward_input=FunctionInput(make_input((3, 7, 6))),
2144            reference_fn=no_batch_dim_reference_fn,
2145            desc='size_no_batch_dim'),
2146    ]
2147
2148
2149def module_inputs_torch_nn_FractionalMaxPool3d(module_info, device, dtype, requires_grad, training, **kwargs):
2150    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2151
2152    def make_random_samples():
2153        return torch.empty((2, 4, 3), dtype=torch.double, device=device).uniform_()
2154
2155    return [
2156        ModuleInput(
2157            constructor_input=FunctionInput(2, output_ratio=0.5, _random_samples=make_random_samples()),
2158            forward_input=FunctionInput(make_input((2, 4, 5, 5, 5))),
2159            desc='ratio'),
2160        ModuleInput(
2161            constructor_input=FunctionInput((2, 2, 2), output_size=(4, 4, 4), _random_samples=make_random_samples()),
2162            forward_input=FunctionInput(make_input((2, 4, 7, 7, 7))),
2163            desc='size'),
2164        ModuleInput(
2165            constructor_input=FunctionInput((4, 2, 3), output_size=(10, 3, 2), _random_samples=make_random_samples()),
2166            forward_input=FunctionInput(make_input((2, 4, 16, 7, 5))),
2167            desc='asymsize'),
2168        ModuleInput(
2169            constructor_input=FunctionInput(
2170                2, output_ratio=0.5, _random_samples=make_random_samples(), return_indices=True
2171            ),
2172            forward_input=FunctionInput(make_input((2, 4, 5, 5, 5))),
2173            desc='ratio_return_indices'),
2174        ModuleInput(
2175            constructor_input=FunctionInput(2, output_ratio=0.5, _random_samples=make_random_samples()),
2176            forward_input=FunctionInput(make_input((4, 5, 5, 5))),
2177            reference_fn=no_batch_dim_reference_fn,
2178            desc='ratio_no_batch_dim'),
2179        ModuleInput(
2180            constructor_input=FunctionInput((2, 2, 2), output_size=(4, 4, 4), _random_samples=make_random_samples()),
2181            forward_input=FunctionInput(make_input((4, 7, 7, 7))),
2182            reference_fn=no_batch_dim_reference_fn,
2183            desc='size_no_batch_dim'),
2184    ]
2185
2186
2187def module_inputs_torch_nn_Sigmoid(module_info, device, dtype, requires_grad, training, **kwargs):
2188    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2189
2190    return [
2191        ModuleInput(
2192            constructor_input=FunctionInput(),
2193            forward_input=FunctionInput(make_input(())),
2194            desc='scalar'
2195        ),
2196        ModuleInput(
2197            constructor_input=FunctionInput(),
2198            forward_input=FunctionInput(make_input(4)),
2199            reference_fn=no_batch_dim_reference_fn,
2200            desc='no_batch_dim',
2201        ),
2202        ModuleInput(
2203            constructor_input=FunctionInput(),
2204            forward_input=FunctionInput(make_input((2, 3, 4, 5))),
2205            desc='channels_last_mem_format'
2206        ),
2207        ModuleInput(
2208            constructor_input=FunctionInput(),
2209            forward_input=FunctionInput(make_input((2, 3, 3, 4, 5))),
2210            desc='channels_last_3d_mem_format'
2211        )
2212    ]
2213
2214
2215def module_inputs_torch_nn_LogSigmoid(module_info, device, dtype, requires_grad, training, **kwargs):
2216    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2217
2218    return [
2219        ModuleInput(
2220            constructor_input=FunctionInput(),
2221            forward_input=FunctionInput(make_input(())),
2222            reference_fn=lambda m, p, i: i.sigmoid().log(),
2223            desc='scalar'
2224        ),
2225        ModuleInput(
2226            constructor_input=FunctionInput(),
2227            forward_input=FunctionInput(make_input((2, 3, 4))),
2228            reference_fn=lambda m, p, i: i.sigmoid().log(),
2229        ),
2230        ModuleInput(
2231            constructor_input=FunctionInput(),
2232            forward_input=FunctionInput(make_input(4)),
2233            reference_fn=no_batch_dim_reference_fn,
2234            desc='no_batch_dim',
2235        ),
2236    ]
2237
2238
2239def module_inputs_torch_nn_MarginRankingLoss(module_info, device, dtype, requires_grad, training, **kwargs):
2240    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2241    make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
2242
2243    cases: List[Tuple[str, dict]] = [
2244        ('', {}),
2245        ('reduction_sum', {'reduction': 'sum'}),
2246        ('reduction_mean', {'reduction': 'mean'}),
2247        ('reduction_none', {'reduction': 'none'}),
2248        ('margin', {'margin': 0.5})
2249    ]
2250
2251    module_inputs = []
2252    for desc, constructor_kwargs in cases:
2253        def reference_fn(m, p, i1, i2, t, constructor_kwargs=constructor_kwargs):
2254            return marginrankingloss_reference(i1, i2, t, **constructor_kwargs)
2255
2256        module_inputs.append(
2257            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
2258                        forward_input=FunctionInput(make_input((50,)), make_input((50,)),
2259                                                    make_target((50,)).sign()),
2260                        desc=desc,
2261                        reference_fn=reference_fn)
2262        )
2263
2264    return module_inputs
2265
2266
2267def module_inputs_torch_nn_MultiLabelMarginLoss(module_info, device, dtype, requires_grad, training, **kwargs):
2268    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2269    make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
2270
2271    cases: List[Tuple[str, dict]] = [
2272        ('', {}),
2273        ('reduction_sum', {'reduction': 'sum'}),
2274        ('reduction_mean', {'reduction': 'mean'}),
2275        ('reduction_none', {'reduction': 'none'}),
2276    ]
2277
2278    module_inputs = []
2279    for desc, constructor_kwargs in cases:
2280        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
2281            return multilabelmarginloss_reference(i, t, **constructor_kwargs)
2282
2283        module_inputs.append(
2284            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
2285                        forward_input=FunctionInput(make_input((10,)),
2286                                                    make_target((10), low=0, high=10)),
2287                        desc=f'1d_{desc}',
2288                        reference_fn=reference_fn)
2289        )
2290
2291        module_inputs.append(
2292            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
2293                        forward_input=FunctionInput(make_input((5, 10)),
2294                                                    make_target((5, 10), low=0, high=10)),
2295                        desc=desc,
2296                        reference_fn=reference_fn)
2297        )
2298
2299    return module_inputs
2300
2301
2302def module_inputs_torch_nn_MultiMarginLoss(module_info, device, dtype, requires_grad, training, **kwargs):
2303    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2304    make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
2305    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
2306
2307    cases: List[Tuple[str, dict]] = [
2308        ('', {}),
2309        ('reduction_sum', {'reduction': 'sum'}),
2310        ('reduction_mean', {'reduction': 'mean'}),
2311        ('reduction_none', {'reduction': 'none'}),
2312        ('p', {'p': 2}),
2313        ('margin', {'margin': 0.5}),
2314        ('weights', {'weight': make_weight(10)})
2315    ]
2316
2317    module_inputs = []
2318    for desc, constructor_kwargs in cases:
2319        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
2320            return multimarginloss_reference(i, t, **constructor_kwargs)
2321
2322        module_inputs.append(
2323            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
2324                        forward_input=FunctionInput(make_input((5, 10)),
2325                                                    make_target((5), low=0, high=10)),
2326                        desc=desc,
2327                        reference_fn=reference_fn)
2328        )
2329
2330    return module_inputs
2331
2332
2333def module_inputs_torch_nn_MultiLabelSoftMarginLoss(module_info, device, dtype, requires_grad, training, **kwargs):
2334    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2335    make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
2336    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
2337
2338    cases: List[Tuple[str, dict]] = [
2339        ('', {}),
2340        ('reduction_sum', {'reduction': 'sum'}),
2341        ('reduction_mean', {'reduction': 'mean'}),
2342        ('reduction_none', {'reduction': 'none'}),
2343        ('weight', {'weight': make_weight(10)}),
2344    ]
2345
2346    def multilabelsoftmargin_loss_reference_fn(m, p, i, t, reduction='mean', weight=None):
2347        result = t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()
2348        if weight is not None:
2349            result *= weight
2350        result = (-result).sum(i.dim() - 1) / i.size(-1)
2351
2352        if reduction == 'none':
2353            return result
2354        elif reduction == 'mean':
2355            return result.mean()
2356        else:
2357            return result.sum()
2358
2359    module_inputs = []
2360    for desc, constructor_kwargs in cases:
2361        module_inputs.append(
2362            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
2363                        forward_input=FunctionInput(make_input((5, 10)),
2364                                                    make_target((5, 10), low=0, high=2)),
2365                        desc=desc,
2366                        reference_fn=partial(multilabelsoftmargin_loss_reference_fn, **constructor_kwargs))
2367        )
2368
2369    return module_inputs
2370
2371
2372def module_inputs_torch_nn_SoftMarginLoss(module_info, device, dtype, requires_grad, training, **kwargs):
2373    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2374    make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
2375
2376    cases: List[Tuple[str, dict]] = [
2377        ('', {}),
2378        ('reduction_sum', {'reduction': 'sum'}),
2379        ('reduction_mean', {'reduction': 'mean'}),
2380        ('reduction_none', {'reduction': 'none'}),
2381    ]
2382
2383    module_inputs = []
2384    for desc, constructor_kwargs in cases:
2385        def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs):
2386            return softmarginloss_reference(i, t, **constructor_kwargs)
2387
2388        module_inputs.append(
2389            ModuleInput(constructor_input=FunctionInput(**constructor_kwargs),
2390                        forward_input=FunctionInput(make_input((5, 5)),
2391                                                    make_target((5, 5)).sign()),
2392                        desc=desc,
2393                        reference_fn=reference_fn)
2394        )
2395
2396    return module_inputs
2397
2398
2399def module_inputs_torch_nn_TransformerEncoder(module_info, device, dtype, requires_grad, training, **kwargs):
2400    # Reuse the TransformerEncoderLayer samples since the forward args are nearly the same.
2401    samples = []
2402    for layer_module_input in module_inputs_torch_nn_TransformerEncoderLayer(
2403            None, device, dtype, requires_grad, training):
2404        # Construct a TransformerEncoderLayer object to pass to TransformerEncoder.
2405        l_args, l_kwargs = (layer_module_input.constructor_input.args,
2406                            layer_module_input.constructor_input.kwargs)
2407        l_kwargs['device'] = device
2408        l_kwargs['dtype'] = dtype
2409        encoder_layer = torch.nn.TransformerEncoderLayer(*l_args, **l_kwargs)
2410        num_layers = 2
2411        # Note: TransformerEncoderLayer takes a "src_mask" while
2412        # TransformerEncoder takes a "mask"; rename kwarg appropriately.
2413        forward_input = layer_module_input.forward_input
2414        if 'src_mask' in forward_input.kwargs:
2415            forward_input.kwargs['mask'] = forward_input.kwargs['src_mask']
2416            del forward_input.kwargs['src_mask']
2417        samples.append(ModuleInput(
2418            constructor_input=FunctionInput(encoder_layer, num_layers),
2419            forward_input=forward_input,
2420            desc=layer_module_input.desc
2421        ))
2422    return samples
2423
2424def module_inputs_torch_nn_TransformerEncoderLayer(module_info, device, dtype, requires_grad, training, **kwargs):
2425    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2426
2427    samples = [
2428        ModuleInput(
2429            constructor_input=FunctionInput(4, 2, 16, 0.0),
2430            forward_input=FunctionInput(
2431                make_input((2, 3, 4))
2432            ),
2433            desc='relu_activation'
2434        ),
2435        ModuleInput(
2436            constructor_input=FunctionInput(4, 2, 8, 0.0, F.gelu),
2437            forward_input=FunctionInput(
2438                make_input((2, 3, 4))
2439            ),
2440            desc='gelu_activation'
2441        ),
2442        ModuleInput(
2443            constructor_input=FunctionInput(4, 2, 8, 0.0, bias=False),
2444            forward_input=FunctionInput(
2445                make_input((2, 3, 4))
2446            ),
2447            desc='no_bias'
2448        ), ]
2449
2450    # Samples below are for validating the no-batch-dim support.
2451    key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool))
2452    attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3)))
2453    for src_mask, src_key_padding_mask, norm_first, batch_first, bias in \
2454            itertools.product(attn_masks, key_padding_masks, (True, False), (True, False), (True, False)):
2455        samples.append(
2456            ModuleInput(
2457                constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
2458                                                dropout=0.0, batch_first=batch_first,
2459                                                norm_first=norm_first, bias=bias),
2460                forward_input=FunctionInput(
2461                    make_input((3, 4)), src_mask=src_mask, src_key_padding_mask=src_key_padding_mask
2462                ),
2463                reference_fn=partial(no_batch_dim_reference_fn,
2464                                     batch_first=batch_first, kwargs_to_batchify={'src_key_padding_mask': 0}),
2465                desc=f'no_batch_dim_batch_first_{batch_first}'
2466            ))
2467
2468    # Samples below where we pass reference_fn are for validating the fast path,
2469    # since the fast path requires no_grad mode, we run the fast path in .eval()
2470    # and no_grad() in the reference_fn and verify that against the results in train mode.
2471    def fast_path_reference_fn(module, parameters, *args, **kwargs):
2472        assert module.training
2473        module.train(False)
2474        with torch.no_grad():
2475            output = module(*args, **kwargs)
2476        module.train(True)
2477        return output
2478
2479    if training:
2480        for norm_first, bias in itertools.product((True, False), (True, False)):
2481            samples.append(
2482                ModuleInput(
2483                    constructor_input=FunctionInput(
2484                        4, 2, 8, dropout=0.0, batch_first=True, norm_first=norm_first, bias=bias
2485                    ),
2486                    forward_input=FunctionInput(
2487                        make_input((2, 3, 4)),
2488                    ),
2489                    # fastpath doesn't run when bias=False
2490                    reference_fn=fast_path_reference_fn if bias else None,
2491                    desc=f'fastpath_{bias}_norm_first_{norm_first}'
2492                )
2493            )
2494
2495    return samples
2496
2497
2498def module_inputs_torch_nn_TransformerDecoderLayer(module_info, device, dtype, requires_grad, training, **kwargs):
2499    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2500
2501    samples = [
2502        ModuleInput(
2503            constructor_input=FunctionInput(4, 2, 16, 0.0),
2504            forward_input=FunctionInput(
2505                make_input((2, 3, 4)), make_input((2, 3, 4))
2506            ),
2507            desc='relu_activation'
2508        ),
2509        ModuleInput(
2510            constructor_input=FunctionInput(4, 2, 8, 0.0, F.gelu),
2511            forward_input=FunctionInput(
2512                make_input((2, 3, 4)), make_input((2, 3, 4))
2513            ),
2514            desc='gelu_activation'
2515        ),
2516        ModuleInput(
2517            constructor_input=FunctionInput(4, 2, 8, 0.0, bias=False),
2518            forward_input=FunctionInput(
2519                make_input((2, 3, 4)), make_input((2, 3, 4))
2520            ),
2521            desc='no_bias'
2522        ), ]
2523
2524    key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool))
2525    attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3)))
2526    for tgt_mask, tgt_key_padding_mask, norm_first, bias, batch_first in \
2527            itertools.product(attn_masks, key_padding_masks, (True, False), (True, False), (True, False)):
2528        # Using same mask for tgt and memory
2529        memory_mask = tgt_mask
2530        memory_key_padding_mask = tgt_key_padding_mask
2531        samples.append(
2532            ModuleInput(
2533                constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
2534                                                dropout=0.0, batch_first=batch_first,
2535                                                norm_first=norm_first, bias=bias),
2536                forward_input=FunctionInput(
2537                    make_input((3, 4)), make_input((3, 4)), tgt_mask=tgt_mask, memory_mask=memory_mask,
2538                    tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask
2539                ),
2540                reference_fn=partial(no_batch_dim_reference_fn,
2541                                     batch_first=batch_first,
2542                                     kwargs_to_batchify={'tgt_key_padding_mask': 0, 'memory_key_padding_mask': 0}),
2543                desc=f'no_batch_dim_batch_first_{batch_first}'
2544            ))
2545        src, tgt = make_input((2, 3, 4)), make_input((2, 3, 4))
2546        if not batch_first:
2547            src, tgt = src.transpose(0, 1), tgt.transpose(0, 1)
2548        if tgt_key_padding_mask is not None:
2549            memory_key_padding_mask, tgt_key_padding_mask = (tgt_key_padding_mask.expand(2, 3),) * 2
2550        samples.append(
2551            ModuleInput(
2552                constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
2553                                                dropout=0.0, batch_first=batch_first,
2554                                                norm_first=norm_first, bias=bias),
2555                forward_input=FunctionInput(
2556                    src, tgt, tgt_mask=tgt_mask, memory_mask=memory_mask,
2557                    tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask
2558                ),
2559                desc=f'norm_first_{norm_first}_batch_first_{batch_first}_bias_{bias}'
2560            ))
2561
2562    return samples
2563
2564
2565def module_inputs_torch_nn_Transformer(module_info, device, dtype, requires_grad, training, **kwargs):
2566    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2567    samples = []
2568    # Samples below are for validating the no-batch-dim support.
2569    key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool))
2570    attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3)))
2571    for mask, key_padding_mask, norm_first, bias, batch_first in \
2572            itertools.product(attn_masks, key_padding_masks, (True, False), (True, False), (True, False)):
2573        # Using same mask for tgt and memory
2574        src_mask , tgt_mask = (mask,) * 2
2575        src_key_padding_mask, tgt_key_padding_mask = (key_padding_mask,) * 2
2576        samples.append(
2577            ModuleInput(
2578                constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
2579                                                num_encoder_layers=1, num_decoder_layers=1,
2580                                                dropout=0.0, batch_first=batch_first, norm_first=norm_first, bias=bias),
2581                forward_input=FunctionInput(
2582                    make_input((3, 4)), make_input((3, 4)), tgt_mask=tgt_mask, src_mask=src_mask,
2583                    tgt_key_padding_mask=tgt_key_padding_mask, src_key_padding_mask=src_key_padding_mask
2584                ),
2585                reference_fn=partial(no_batch_dim_reference_fn,
2586                                     batch_first=batch_first,
2587                                     kwargs_to_batchify={'tgt_key_padding_mask': 0, 'src_key_padding_mask': 0}),
2588                desc=f'no_batch_dim_batch_first_{batch_first}'
2589            ))
2590
2591        src, tgt = make_input((2, 3, 4)), make_input((2, 3, 4))
2592        if not batch_first:
2593            src = src.transpose(0, 1)
2594            tgt = tgt.transpose(0, 1)
2595        if key_padding_mask is not None:
2596            src_key_padding_mask, tgt_key_padding_mask = (key_padding_mask.expand(2, 3),) * 2
2597
2598        samples.append(
2599            ModuleInput(
2600                constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
2601                                                num_encoder_layers=1, num_decoder_layers=1,
2602                                                dropout=0.0, batch_first=batch_first, norm_first=norm_first, bias=bias),
2603                forward_input=FunctionInput(
2604                    src, tgt, tgt_mask=tgt_mask, src_mask=src_mask,
2605                    tgt_key_padding_mask=tgt_key_padding_mask, src_key_padding_mask=src_key_padding_mask
2606                ),
2607            ))
2608    return samples
2609
2610
2611def module_inputs_torch_nn_Embedding(module_info, device, dtype, requires_grad, training, **kwargs):
2612    make_empty = partial(torch.empty, device=device, dtype=torch.long, requires_grad=False)
2613    return [
2614        ModuleInput(
2615            constructor_input=FunctionInput(num_embeddings=4, embedding_dim=3),
2616            forward_input=FunctionInput(make_empty(2, 3).random_(4))
2617        ),
2618        ModuleInput(
2619            constructor_input=FunctionInput(num_embeddings=4, embedding_dim=3),
2620            forward_input=FunctionInput(make_empty(1, 512).random_(4).expand(7, 512)),
2621            desc='discontiguous'
2622        ),
2623    ]
2624
2625
2626def module_inputs_torch_nn_MultiheadAttention(module_info, device, dtype, requires_grad, training, **kwargs):
2627    # Currently all samples below are for validating the no-batch-dim support.
2628    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2629    samples = []
2630    bool_vals = (True, False)
2631    key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool))
2632    attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3, 3)))
2633    products = itertools.product(bool_vals, bool_vals, bool_vals, key_padding_masks, attn_masks)
2634    for bias, add_bias_kv, add_zero_attn, key_padding_mask, attn_mask in products:
2635        samples.append(
2636            ModuleInput(
2637                constructor_input=FunctionInput(embed_dim=3, num_heads=3, batch_first=True,
2638                                                bias=bias, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn),
2639                forward_input=FunctionInput(make_input((3, 3)), make_input((3, 3)), make_input((3, 3)),
2640                                            key_padding_mask=key_padding_mask, attn_mask=attn_mask),
2641                reference_fn=no_batch_dim_reference_mha,
2642            )
2643        )
2644        samples.append(
2645            ModuleInput(
2646                constructor_input=FunctionInput(embed_dim=3, num_heads=3, batch_first=False,
2647                                                bias=bias, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn),
2648                forward_input=FunctionInput(make_input((3, 3)), make_input((3, 3)), make_input((3, 3)),
2649                                            key_padding_mask=key_padding_mask, attn_mask=attn_mask),
2650                reference_fn=partial(no_batch_dim_reference_mha, batch_first=False),
2651            )
2652        )
2653
2654    return samples
2655
2656
2657def module_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_grad, training, **kwargs):
2658    # Currently all samples below are for validating the no-batch-dim support.
2659    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2660    samples = [
2661        ModuleInput(
2662            constructor_input=FunctionInput(5, 10),
2663            forward_input=FunctionInput(make_input(5), make_input(10)),
2664            reference_fn=no_batch_dim_reference_fn,
2665        ),
2666        ModuleInput(
2667            constructor_input=FunctionInput(5, 10, bias=True),
2668            forward_input=FunctionInput(make_input(5), make_input(10)),
2669            reference_fn=no_batch_dim_reference_fn,
2670        )
2671    ]
2672
2673    is_rnn = kwargs.get('is_rnn', False)
2674    if is_rnn:
2675        # RNN also supports `nonlinearity` argument.
2676        # `tanh` is the default, so we check with `relu`
2677        samples.append(
2678            ModuleInput(
2679                constructor_input=FunctionInput(5, 10, bias=True, nonlinearity='relu'),
2680                forward_input=FunctionInput(make_input(5), make_input(10)),
2681                reference_fn=no_batch_dim_reference_fn,
2682            )
2683        )
2684
2685    return samples
2686
2687
2688def module_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_grad, training, **kwargs):
2689    # Currently all samples below are for validating the no-batch-dim support.
2690    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2691    samples = (
2692        ModuleInput(
2693            constructor_input=FunctionInput(5, 10),
2694            forward_input=FunctionInput(make_input(5), (make_input(10), make_input(10))),
2695            reference_fn=no_batch_dim_reference_lstmcell,
2696        ),
2697        ModuleInput(
2698            constructor_input=FunctionInput(5, 10, bias=True),
2699            forward_input=FunctionInput(make_input(5), (make_input(10), make_input(10))),
2700            reference_fn=no_batch_dim_reference_lstmcell,
2701        ),
2702    )
2703
2704    return samples
2705
2706def make_packed_sequence(inp, batch_sizes):
2707    required_grad = inp.requires_grad
2708    inp.requires_grad_(False)  # user won't have access to inp so won't be able to get its grads
2709    seq = pack_padded_sequence(inp, batch_sizes)
2710    seq.data.requires_grad_(required_grad)
2711    return seq
2712
2713
2714def module_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, training, with_packed_sequence=False, **kwargs):
2715    # Currently all samples below are for validating the no-batch-dim support.
2716    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2717    is_rnn = kwargs['is_rnn']
2718    nonlinearity = ('relu', 'tanh')
2719    bias = (False, True)
2720    batch_first = (False, True)
2721    bidirectional = (False, True)
2722
2723    samples = []
2724    if is_rnn:
2725        prod_gen = product(nonlinearity, bias, batch_first, bidirectional)
2726    else:
2727        prod_gen = product(bias, batch_first, bidirectional)
2728
2729    for args in prod_gen:
2730        if is_rnn:
2731            nl, b, b_f, bidir = args
2732        else:
2733            b, b_f, bidir = args
2734
2735        cons_args = {'input_size': 2, 'hidden_size': 2, 'num_layers': 2,
2736                     'batch_first': b_f, 'bias': b, 'bidirectional': bidir}
2737        cons_args_hidden = {'input_size': 2, 'hidden_size': 3, 'num_layers': 2,
2738                            'batch_first': b_f, 'bias': b, 'bidirectional': bidir}
2739
2740        if is_rnn:
2741            cons_args['nonlinearity'] = nl
2742            cons_args_hidden['nonlinearity'] = nl
2743        samples.append(
2744            ModuleInput(
2745                constructor_input=FunctionInput(**cons_args),
2746                forward_input=FunctionInput(make_input((3, 2))),
2747                reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f),
2748            )
2749        )
2750        samples.append(
2751            ModuleInput(
2752                constructor_input=FunctionInput(**cons_args_hidden),
2753                forward_input=FunctionInput(make_input((3, 2)), make_input((4 if bidir else 2, 3))),
2754                reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f),
2755            )
2756        )
2757        if with_packed_sequence:
2758            samples.append(
2759                ModuleInput(
2760                    constructor_input=FunctionInput(**cons_args),
2761                    forward_input=FunctionInput(make_packed_sequence(make_input((5, 2, 2)), torch.tensor([5, 3]))),
2762                    reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f),
2763                )
2764            )
2765            samples.append(
2766                ModuleInput(
2767                    constructor_input=FunctionInput(**cons_args),
2768                    forward_input=FunctionInput(make_packed_sequence(make_input((5, 5, 2)), torch.tensor([5, 3, 3, 2, 2]))),
2769                    reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f),
2770                )
2771            )
2772
2773    return samples
2774
2775
2776def module_inputs_torch_nn_LSTM(module_info, device, dtype, requires_grad, training, **kwargs):
2777    # Currently all samples below are for validating the no-batch-dim support.
2778    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2779    bias = (False, True)
2780    batch_first = (False, True)
2781    bidirectional = (False, True)
2782    proj_sizes = (0, 2)
2783
2784    samples = []
2785    prod_gen = product(bias, batch_first, bidirectional, proj_sizes)
2786
2787    for args in prod_gen:
2788        b, b_f, bidir, proj_size = args
2789        hidden_size = 3
2790        cons_args = {'input_size': 2, 'hidden_size': hidden_size, 'num_layers': 2, 'proj_size': proj_size,
2791                     'batch_first': b_f, 'bias': b, 'bidirectional': bidir}
2792        cons_args_hidden = {'input_size': 2, 'hidden_size': hidden_size, 'num_layers': 2, 'proj_size': proj_size,
2793                            'batch_first': b_f, 'bias': b, 'bidirectional': bidir}
2794
2795        samples.append(
2796            ModuleInput(
2797                constructor_input=FunctionInput(**cons_args),
2798                forward_input=FunctionInput(make_input((2, 2))),
2799                reference_fn=partial(no_batch_dim_reference_lstm, batch_first=b_f),
2800            )
2801        )
2802
2803        h_out = proj_size if proj_size > 0 else hidden_size
2804        hx = (make_input((4 if bidir else 2, h_out)), make_input((4 if bidir else 2, hidden_size)))
2805        samples.append(
2806            ModuleInput(
2807                constructor_input=FunctionInput(**cons_args_hidden),
2808                forward_input=FunctionInput(make_input((3, 2)), hx),
2809                reference_fn=partial(no_batch_dim_reference_lstm, batch_first=b_f),
2810            )
2811        )
2812
2813
2814    return samples
2815
2816
2817
2818def module_inputs_torch_nn_ReflectionPad1d(module_info, device, dtype, requires_grad, training, **kwargs):
2819    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2820
2821    return [
2822        ModuleInput(
2823            constructor_input=FunctionInput(1),
2824            forward_input=FunctionInput(make_input((2, 3))),
2825            reference_fn=no_batch_dim_reference_fn,
2826        ),
2827        ModuleInput(
2828            constructor_input=FunctionInput((1, 2)),
2829            forward_input=FunctionInput(make_input((2, 3, 4))),
2830        ),
2831    ]
2832
2833def module_inputs_torch_nn_ReflectionPad2d(module_info, device, dtype, requires_grad, training, **kwargs):
2834    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2835
2836    return [
2837        ModuleInput(
2838            constructor_input=FunctionInput(1),
2839            forward_input=FunctionInput(make_input((3, 4, 5))),
2840            reference_fn=no_batch_dim_reference_fn,
2841        ),
2842        ModuleInput(
2843            constructor_input=FunctionInput((1, 2, 3, 4)),
2844            forward_input=FunctionInput(make_input((3, 4, 5, 6))),
2845        ),
2846    ]
2847
2848def module_inputs_torch_nn_ReflectionPad3d(module_info, device, dtype, requires_grad, training, **kwargs):
2849    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2850
2851    return [
2852        ModuleInput(
2853            constructor_input=FunctionInput(1),
2854            forward_input=FunctionInput(make_input((2, 3, 4, 5))),
2855            reference_fn=no_batch_dim_reference_fn
2856        ),
2857        ModuleInput(
2858            constructor_input=FunctionInput((1, 2, 1, 2, 1, 2)),
2859            forward_input=FunctionInput(make_input((3, 3, 3, 3, 3))),
2860        ),
2861    ]
2862
2863def module_inputs_torch_nn_ReplicationPad1d(module_info, device, dtype, requires_grad, training, **kwargs):
2864    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2865
2866    return [
2867        ModuleInput(
2868            constructor_input=FunctionInput(1),
2869            forward_input=FunctionInput(make_input((3, 4))),
2870            reference_fn=no_batch_dim_reference_fn
2871        ),
2872        ModuleInput(
2873            constructor_input=FunctionInput((1, 2)),
2874            forward_input=FunctionInput(make_input((3, 4, 5))),
2875        ),
2876    ]
2877
2878def module_inputs_torch_nn_ReplicationPad2d(module_info, device, dtype, requires_grad, training, **kwargs):
2879    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2880
2881    return [
2882        ModuleInput(
2883            constructor_input=FunctionInput(1),
2884            forward_input=FunctionInput(make_input((3, 4, 5))),
2885            reference_fn=no_batch_dim_reference_fn,
2886        ),
2887        ModuleInput(
2888            constructor_input=FunctionInput((1, 2, 3, 4)),
2889            forward_input=FunctionInput(make_input((3, 4, 5, 6))),
2890        ),
2891    ]
2892
2893def module_inputs_torch_nn_ReplicationPad3d(module_info, device, dtype, requires_grad, training, **kwargs):
2894    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2895
2896    return [
2897        ModuleInput(
2898            constructor_input=FunctionInput(1),
2899            forward_input=FunctionInput(make_input((3, 4, 5, 6))),
2900            reference_fn=no_batch_dim_reference_fn,
2901        ),
2902        ModuleInput(
2903            constructor_input=FunctionInput((1, 2, 3, 4, 5, 6)),
2904            forward_input=FunctionInput(make_input((3, 4, 5, 6, 7))),
2905        ),
2906    ]
2907
2908def module_inputs_torch_nn_ZeroPad1d(module_info, device, dtype, requires_grad, training, **kwargs):
2909    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2910
2911    return [
2912        ModuleInput(
2913            constructor_input=FunctionInput(1),
2914            forward_input=FunctionInput(make_input((3, 4))),
2915            reference_fn=no_batch_dim_reference_fn,
2916        ),
2917        ModuleInput(
2918            constructor_input=FunctionInput((1, 2)),
2919            forward_input=FunctionInput(make_input((3, 4, 5))),
2920        ),
2921    ]
2922
2923def module_inputs_torch_nn_ZeroPad2d(module_info, device, dtype, requires_grad, training, **kwargs):
2924    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2925
2926    return [
2927        ModuleInput(
2928            constructor_input=FunctionInput(1),
2929            forward_input=FunctionInput(make_input((1, 2, 3))),
2930            reference_fn=no_batch_dim_reference_fn
2931        ),
2932        ModuleInput(
2933            constructor_input=FunctionInput((1, 2, 3, 4)),
2934            forward_input=FunctionInput(make_input((1, 2, 3, 4))),
2935        ),
2936    ]
2937
2938def module_inputs_torch_nn_ZeroPad3d(module_info, device, dtype, requires_grad, training, **kwargs):
2939    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2940
2941    return [
2942        ModuleInput(
2943            constructor_input=FunctionInput(1),
2944            forward_input=FunctionInput(make_input((3, 4, 5, 6))),
2945            reference_fn=no_batch_dim_reference_fn,
2946        ),
2947        ModuleInput(
2948            constructor_input=FunctionInput((1, 2, 3, 4, 5, 6)),
2949            forward_input=FunctionInput(make_input((1, 2, 3, 4, 5))),
2950        ),
2951    ]
2952
2953def module_inputs_torch_nn_ConstantPad1d(module_info, device, dtype, requires_grad, training, **kwargs):
2954    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2955
2956    return [
2957        ModuleInput(
2958            constructor_input=FunctionInput(1, 2),
2959            forward_input=FunctionInput(make_input((3, 4))),
2960            reference_fn=no_batch_dim_reference_fn,
2961        ),
2962        ModuleInput(
2963            constructor_input=FunctionInput((1, 2), 3),
2964            forward_input=FunctionInput(make_input((3, 4, 5))),
2965        ),
2966    ]
2967
2968def module_inputs_torch_nn_ConstantPad2d(module_info, device, dtype, requires_grad, training, **kwargs):
2969    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2970
2971    return [
2972        ModuleInput(
2973            constructor_input=FunctionInput(1, 3),
2974            forward_input=FunctionInput(make_input((3, 4, 5))),
2975            reference_fn=no_batch_dim_reference_fn
2976        ),
2977        ModuleInput(
2978            constructor_input=FunctionInput((1, 2, 3, 4), 5),
2979            forward_input=FunctionInput(make_input((1, 2, 3, 4))),
2980        ),
2981    ]
2982
2983def module_inputs_torch_nn_ConstantPad3d(module_info, device, dtype, requires_grad, training, **kwargs):
2984    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2985
2986    return [
2987        ModuleInput(
2988            constructor_input=FunctionInput(1, 3),
2989            forward_input=FunctionInput(make_input((3, 4, 5, 6))),
2990            reference_fn=no_batch_dim_reference_fn,
2991        ),
2992        ModuleInput(
2993            constructor_input=FunctionInput((1, 2, 3, 4, 5, 6), 7),
2994            forward_input=FunctionInput(make_input((1, 2, 1, 2, 1))),
2995        ),
2996    ]
2997
2998def module_inputs_torch_nn_CircularPad1d(module_info, device, dtype, requires_grad, training, **kwargs):
2999    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
3000
3001    def padding1d_circular_ref(inp, pad):
3002        r""" input:
3003                [[[0., 1., 2.],
3004                  [3., 4., 5.]]]
3005                pad: (1, 2)
3006                output:
3007                    [[[2., 0., 1., 2., 0., 1.],
3008                      [5., 3., 4., 5., 3., 4.]]]
3009            """
3010        return torch.cat([inp[:, :, -pad[0]:], inp, inp[:, :, :pad[1]]], dim=2)
3011
3012    return [
3013        ModuleInput(
3014            constructor_input=FunctionInput(1),
3015            forward_input=FunctionInput(make_input((3, 4))),
3016            reference_fn=no_batch_dim_reference_fn
3017        ),
3018        ModuleInput(
3019            constructor_input=FunctionInput((1, 2)),
3020            forward_input=FunctionInput(make_input((1, 2, 3))),
3021            reference_fn=lambda m, p, i: padding1d_circular_ref(i, m.padding),
3022        ),
3023        ModuleInput(
3024            constructor_input=FunctionInput((3, 1)),
3025            forward_input=FunctionInput(make_input((1, 2, 3))),
3026            reference_fn=lambda m, p, i: padding1d_circular_ref(i, m.padding),
3027        ),
3028        ModuleInput(
3029            constructor_input=FunctionInput((3, 3)),
3030            forward_input=FunctionInput(make_input((1, 2, 3))),
3031            reference_fn=lambda m, p, i: padding1d_circular_ref(i, m.padding),
3032        ),
3033    ]
3034
3035def module_inputs_torch_nn_CircularPad2d(module_info, device, dtype, requires_grad, training, **kwargs):
3036    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
3037
3038    def padding2d_circular_ref(inp, pad):
3039        r"""input:
3040                [[[[0., 1., 2],
3041                   [3., 4., 5.]]]]
3042                pad: (1, 2, 2, 1)
3043        output:
3044            [[[[2., 0., 1., 2., 0., 1.],
3045               [5., 3., 4., 5., 3., 4.],
3046               [2., 0., 1., 2., 0., 1.],
3047               [5., 3., 4., 5., 3., 4.],
3048               [2., 0., 1., 2., 0., 1.]]]]
3049        """
3050        inp = torch.cat([inp[:, :, -pad[2]:], inp, inp[:, :, :pad[3]]], dim=2)
3051        return torch.cat([inp[:, :, :, -pad[0]:], inp, inp[:, :, :, :pad[1]]], dim=3)
3052
3053    return [
3054        ModuleInput(
3055            constructor_input=FunctionInput(1),
3056            forward_input=FunctionInput(make_input((3, 4, 5))),
3057            reference_fn=no_batch_dim_reference_fn,
3058        ),
3059        ModuleInput(
3060            constructor_input=FunctionInput((1, 2, 2, 1)),
3061            forward_input=FunctionInput(make_input((1, 1, 2, 3))),
3062            reference_fn=lambda m, p, i: padding2d_circular_ref(i, m.padding),
3063        ),
3064        ModuleInput(
3065            constructor_input=FunctionInput((2, 3, 2, 2)),
3066            forward_input=FunctionInput(make_input((1, 1, 2, 3))),
3067            reference_fn=lambda m, p, i: padding2d_circular_ref(i, m.padding),
3068        ),
3069        ModuleInput(
3070            constructor_input=FunctionInput((3, 3, 3, 1)),
3071            forward_input=FunctionInput(make_input((1, 1, 3, 3))),
3072            reference_fn=lambda m, p, i: padding2d_circular_ref(i, m.padding),
3073        ),
3074    ]
3075
3076def module_inputs_torch_nn_CircularPad3d(module_info, device, dtype, requires_grad, training, **kwargs):
3077    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
3078
3079
3080    def padding3d_circular_ref(inp, pad):
3081        r"""input:
3082                [[[[[ 0.,  1.,  2.],
3083                    [ 3.,  4.,  5.]],
3084                   [[ 6.,  7.,  8.],
3085                    [ 9., 10., 11.]]]]]
3086            pad: (1, 2, 2, 1, 1, 2)
3087            output: [[[[[ 8.,  6.,  7.,  8.,  6.,  7.],
3088                        [11.,  9., 10., 11.,  9., 10.],
3089                        [ 8.,  6.,  7.,  8.,  6.,  7.],
3090                        [11.,  9., 10., 11.,  9., 10.],
3091                        [ 8.,  6.,  7.,  8.,  6.,  7.]],
3092
3093                       [[ 2.,  0.,  1.,  2.,  0.,  1.],
3094                        [ 5.,  3.,  4.,  5.,  3.,  4.],
3095                        [ 2.,  0.,  1.,  2.,  0.,  1.],
3096                        [ 5.,  3.,  4.,  5.,  3.,  4.],
3097                        [ 2.,  0.,  1.,  2.,  0.,  1.]],
3098
3099                       [[ 8.,  6.,  7.,  8.,  6.,  7.],
3100                        [11.,  9., 10., 11.,  9., 10.],
3101                        [ 8.,  6.,  7.,  8.,  6.,  7.],
3102                        [11.,  9., 10., 11.,  9., 10.],
3103                        [ 8.,  6.,  7.,  8.,  6.,  7.]],
3104
3105                       [[ 2.,  0.,  1.,  2.,  0.,  1.],
3106                        [ 5.,  3.,  4.,  5.,  3.,  4.],
3107                        [ 2.,  0.,  1.,  2.,  0.,  1.],
3108                        [ 5.,  3.,  4.,  5.,  3.,  4.],
3109                        [ 2.,  0.,  1.,  2.,  0.,  1.]],
3110
3111                       [[ 8.,  6.,  7.,  8.,  6.,  7.],
3112                        [11.,  9., 10., 11.,  9., 10.],
3113                        [ 8.,  6.,  7.,  8.,  6.,  7.],
3114                        [11.,  9., 10., 11.,  9., 10.],
3115                        [ 8.,  6.,  7.,  8.,  6.,  7.]]]]]
3116        """
3117        inp = torch.cat([inp[:, :, -pad[4]:], inp, inp[:, :, :pad[5]]], dim=2)
3118        inp = torch.cat([inp[:, :, :, -pad[2]:], inp, inp[:, :, :, :pad[3]]], dim=3)
3119        return torch.cat([inp[:, :, :, :, -pad[0]:], inp, inp[:, :, :, :, :pad[1]]], dim=4)
3120
3121    return [
3122        ModuleInput(
3123            constructor_input=FunctionInput(1),
3124            forward_input=FunctionInput(make_input((3, 4, 5, 6))),
3125            reference_fn=no_batch_dim_reference_fn,
3126        ),
3127        ModuleInput(
3128            constructor_input=FunctionInput((1, 2, 1, 2, 1, 2)),
3129            forward_input=FunctionInput(make_input((1, 1, 2, 2, 3))),
3130            reference_fn=lambda m, p, i: padding3d_circular_ref(i, m.padding)
3131        ),
3132        ModuleInput(
3133            constructor_input=FunctionInput((3, 2, 2, 1, 1, 2)),
3134            forward_input=FunctionInput(make_input((1, 1, 2, 2, 3))),
3135            reference_fn=lambda m, p, i: padding3d_circular_ref(i, m.padding)
3136        ),
3137        ModuleInput(
3138            constructor_input=FunctionInput((3, 3, 2, 1, 2, 2)),
3139            forward_input=FunctionInput(make_input((1, 1, 2, 2, 3))),
3140            reference_fn=lambda m, p, i: padding3d_circular_ref(i, m.padding)
3141        ),
3142    ]
3143
3144
3145# All these operators share similar issues on cuDNN and MIOpen
3146rnn_gru_lstm_module_info_decorators = (
3147    # RuntimeError: Batching rule not implemented for aten::_cudnn_rnn_backward.
3148    # We could not generate a fallback
3149    DecorateInfo(
3150        unittest.expectedFailure, "TestModule", "test_grad",
3151        active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
3152    ),
3153    # NotImplementedError: the derivative for '_cudnn_rnn_backward' is not implemented.
3154    # Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API
3155    DecorateInfo(
3156        unittest.expectedFailure, "TestModule", "test_gradgrad",
3157        active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
3158    ),
3159    # CUDNN GRU doesn't accept non-contiguous hx
3160    DecorateInfo(
3161        unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
3162        active_if=(TEST_CUDNN and not TEST_WITH_ROCM), device_type='cuda'
3163    ),
3164    # MIOPEN GRU doesn't accept non-contiguous hx (this is dispatched to miopen only for float).
3165    DecorateInfo(
3166        unittest.expectedFailure, "TestModule", "test_non_contiguous_tensors",
3167        active_if=(TEST_CUDNN and TEST_WITH_ROCM), dtypes=(torch.float,), device_type='cuda'
3168    ),
3169    DecorateInfo(
3170        skipCUDAVersionIn([(11, 7)]), "TestExpandedWeightModule", "test_module",
3171        device_type='cuda'
3172    ),
3173    DecorateInfo(
3174        skipCUDAVersionIn([(11, 7)]), "TestDecomp", "test_rnn_decomp_module",
3175        device_type='cuda'
3176    )
3177)
3178
3179# Start of module error inputs functions.
3180
3181def module_error_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_grad, training, **kwargs):
3182    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
3183    samples = [
3184        ErrorModuleInput(
3185            ModuleInput(
3186                constructor_input=FunctionInput(10, 20),
3187                forward_input=FunctionInput(make_input(3, 11), make_input(3, 20)),
3188            ),
3189            error_on=ModuleErrorEnum.FORWARD_ERROR,
3190            error_type=RuntimeError,
3191            error_regex="input has inconsistent input_size: got 11 expected 10"
3192        ),
3193        ErrorModuleInput(
3194            ModuleInput(
3195                constructor_input=FunctionInput(10, 20),
3196                forward_input=FunctionInput(make_input(3, 10), make_input(3, 21)),
3197            ),
3198            error_on=ModuleErrorEnum.FORWARD_ERROR,
3199            error_type=RuntimeError,
3200            error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20"
3201        ),
3202        ErrorModuleInput(
3203            ModuleInput(
3204                constructor_input=FunctionInput(10, 20),
3205                forward_input=FunctionInput(make_input(3, 10), make_input(5, 20)),
3206            ),
3207            error_on=ModuleErrorEnum.FORWARD_ERROR,
3208            error_type=RuntimeError,
3209            error_regex="Input batch size 3 doesn't match hidden0 batch size 5"
3210        ),
3211        ErrorModuleInput(
3212            ModuleInput(
3213                constructor_input=FunctionInput(10, 20),
3214                forward_input=FunctionInput(make_input(3, 10), make_input(3, 1, 1, 20)),
3215            ),
3216            error_on=ModuleErrorEnum.FORWARD_ERROR,
3217            error_type=ValueError,
3218            error_regex="Expected hidden to be 1D or 2D, got 4D instead"
3219        ),
3220        ErrorModuleInput(
3221            ModuleInput(
3222                constructor_input=FunctionInput(10, 20, 'relu'),
3223                forward_input=FunctionInput(make_input(3, 10), make_input(3, 21)),
3224            ),
3225            error_on=ModuleErrorEnum.FORWARD_ERROR,
3226            error_type=RuntimeError,
3227            error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20"
3228        ),
3229        ErrorModuleInput(
3230            ModuleInput(
3231                constructor_input=FunctionInput(10, 20, 'tanh'),
3232                forward_input=FunctionInput(make_input(3, 10), make_input(3, 21)),
3233            ),
3234            error_on=ModuleErrorEnum.FORWARD_ERROR,
3235            error_type=RuntimeError,
3236            error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20"
3237        ),
3238    ]
3239    return samples
3240
3241def module_error_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_grad, training, **kwargs):
3242    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
3243    samples = [
3244        ErrorModuleInput(
3245            ModuleInput(
3246                constructor_input=FunctionInput(10, 20),
3247                forward_input=FunctionInput(make_input(3, 11), (make_input(3, 20), make_input(3, 20))),
3248            ),
3249            error_on=ModuleErrorEnum.FORWARD_ERROR,
3250            error_type=RuntimeError,
3251            error_regex="input has inconsistent input_size: got 11 expected 10"
3252        ),
3253        ErrorModuleInput(
3254            ModuleInput(
3255                constructor_input=FunctionInput(10, 20),
3256                forward_input=FunctionInput(make_input(3, 10), (make_input(3, 21), make_input(3, 21))),
3257            ),
3258            error_on=ModuleErrorEnum.FORWARD_ERROR,
3259            error_type=RuntimeError,
3260            error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20"
3261        ),
3262        ErrorModuleInput(
3263            ModuleInput(
3264                constructor_input=FunctionInput(10, 20),
3265                forward_input=FunctionInput(make_input(3, 10), (make_input(5, 20), make_input(5, 20))),
3266            ),
3267            error_on=ModuleErrorEnum.FORWARD_ERROR,
3268            error_type=RuntimeError,
3269            error_regex="Input batch size 3 doesn't match hidden0 batch size 5"
3270        ),
3271        ErrorModuleInput(
3272            ModuleInput(
3273                constructor_input=FunctionInput(10, 20),
3274                forward_input=FunctionInput(make_input(3, 10), (make_input(3, 1, 1, 20), make_input(3, 1, 1, 20))),
3275            ),
3276            error_on=ModuleErrorEnum.FORWARD_ERROR,
3277            error_type=ValueError,
3278            error_regex="Expected hx\\[0\\] to be 1D or 2D, got 4D instead"
3279        ),
3280    ]
3281    return samples
3282
3283
3284def module_error_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, training, **kwargs):
3285    samples = [
3286        ErrorModuleInput(
3287            ModuleInput(constructor_input=FunctionInput(10, 0, 1)),
3288            error_on=ModuleErrorEnum.CONSTRUCTION_ERROR,
3289            error_type=ValueError,
3290            error_regex="hidden_size must be greater than zero"
3291        ),
3292        ErrorModuleInput(
3293            ModuleInput(constructor_input=FunctionInput(10, 10, 0)),
3294            error_on=ModuleErrorEnum.CONSTRUCTION_ERROR,
3295            error_type=ValueError,
3296            error_regex="num_layers must be greater than zero"
3297        ),
3298    ]
3299    return samples
3300
3301def module_error_inputs_torch_nn_Pad1d(module_info, device, dtype, requires_grad, training, **kwargs):
3302    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
3303
3304    is_constant = kwargs.get('is_constant', False)
3305
3306    return [
3307        ErrorModuleInput(
3308            ModuleInput(
3309                constructor_input=FunctionInput(1, 3) if is_constant else FunctionInput(3),
3310                forward_input=FunctionInput(make_input((2, 3, 4, 5))),
3311            ),
3312            error_on=ModuleErrorEnum.FORWARD_ERROR,
3313            error_type=ValueError,
3314            error_regex=r"expected 2D or 3D input \(got 4D input\)",
3315
3316        ),
3317    ]
3318
3319def module_error_inputs_torch_nn_Pad2d(module_info, device, dtype, requires_grad, training, **kwargs):
3320    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
3321
3322    is_constant = kwargs.get('is_constant', False)
3323
3324    return [
3325        ErrorModuleInput(
3326            ModuleInput(
3327                constructor_input=FunctionInput(1, 3) if is_constant else FunctionInput(3),
3328                forward_input=FunctionInput(make_input((2, 3))),
3329            ),
3330            error_on=ModuleErrorEnum.FORWARD_ERROR,
3331            error_type=ValueError,
3332            error_regex=r"expected 3D or 4D input \(got 2D input\)",
3333
3334        ),
3335    ]
3336
3337def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad, training, **kwargs):
3338    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
3339
3340    is_constant = kwargs.get('is_constant', False)
3341
3342    return [
3343        ErrorModuleInput(
3344            ModuleInput(
3345                constructor_input=FunctionInput(1, 3) if is_constant else FunctionInput(3),
3346                forward_input=FunctionInput(make_input((2, 3))),
3347            ),
3348            error_on=ModuleErrorEnum.FORWARD_ERROR,
3349            error_type=ValueError,
3350            error_regex=r"expected 4D or 5D input \(got 2D input\)",
3351
3352        ),
3353    ]
3354
3355
3356_macos15_or_newer = torch.backends.mps.is_available() and torch.backends.mps.is_macos_or_newer(15, 0)
3357
3358
3359# Database of ModuleInfo entries in alphabetical order.
3360module_db: List[ModuleInfo] = [
3361    ModuleInfo(torch.nn.AdaptiveAvgPool1d,
3362               module_inputs_func=module_inputs_torch_nn_AdaptiveAvgPool1d,
3363               skips=(
3364                   # Fails on MPS backend if input/output sizes are not divisible
3365                   DecorateInfo(skipMPS),)
3366               ),
3367    ModuleInfo(torch.nn.AdaptiveAvgPool2d,
3368               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
3369               module_inputs_func=module_inputs_torch_nn_AdaptiveAvgPool2d,
3370               skips=(
3371                   # Fails on MPS backend if input/output sizes are not divisible
3372                   DecorateInfo(skipMPS),
3373                   # Fails on backward check if output size is 1x1
3374                   DecorateInfo(
3375                       unittest.expectedFailure,
3376                       'TestModule',
3377                       'test_memory_format',
3378                       active_if=operator.itemgetter('training'),
3379                   ),)
3380               ),
3381    ModuleInfo(torch.nn.AdaptiveAvgPool3d,
3382               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
3383               module_inputs_func=module_inputs_torch_nn_AdaptiveAvgPool3d,
3384               skips=(
3385                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
3386                   # not supported on MPS backend
3387                   DecorateInfo(skipMPS),)
3388               ),
3389    ModuleInfo(torch.nn.AdaptiveMaxPool1d,
3390               module_inputs_func=module_inputs_torch_nn_AdaptiveMaxPool1d,
3391               ),
3392    ModuleInfo(torch.nn.AdaptiveMaxPool2d,
3393               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
3394               module_inputs_func=module_inputs_torch_nn_AdaptiveMaxPool2d,
3395               ),
3396    ModuleInfo(torch.nn.AdaptiveMaxPool3d,
3397               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
3398               module_inputs_func=module_inputs_torch_nn_AdaptiveMaxPool3d,
3399               skips=(
3400                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
3401                   # not supported on MPS backend
3402                   DecorateInfo(skipMPS),)
3403               ),
3404    ModuleInfo(torch.nn.AvgPool1d,
3405               module_inputs_func=module_inputs_torch_nn_AvgPool1d,
3406               ),
3407    ModuleInfo(torch.nn.AvgPool2d,
3408               module_inputs_func=module_inputs_torch_nn_AvgPool2d,
3409               skips=(
3410                   # The difference between channels last backward and
3411                   # channels first backward of AvgPool2d on CUDA is too large
3412                   # See https://github.com/pytorch/pytorch/issues/107201
3413                   DecorateInfo(
3414                       unittest.expectedFailure,
3415                       'TestModule',
3416                       'test_memory_format',
3417                       active_if=operator.itemgetter('training'),
3418                       device_type='cuda',
3419                   ),
3420                   # error: input types 'tensor<f32>' and 'tensor<15x10xf16>' are not broadcast compatible
3421                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float16]),),
3422               ),
3423    ModuleInfo(torch.nn.AvgPool3d,
3424               module_inputs_func=module_inputs_torch_nn_AvgPool3d,
3425               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
3426               skips=(
3427                   # No channels_last support for AvgPool1d as it does not take 4D inputs
3428                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
3429                   # not supported on MPS backend
3430                   DecorateInfo(skipMPS),)
3431               ),
3432    ModuleInfo(torch.nn.BatchNorm1d,
3433               train_and_eval_differ=True,
3434               module_inputs_func=module_inputs_torch_nn_BatchNorm1d,
3435               skips=(
3436                   # tracking here rather than in the list in test_aotdispatch.py as eval mode passes
3437                   # RuntimeError: tried to get Double out of SymInt
3438                   DecorateInfo(
3439                       unittest.expectedFailure, 'TestEagerFusionModuleInfo',
3440                       'test_aot_autograd_symbolic_module_exhaustive',
3441                       active_if=operator.itemgetter('training')
3442                   ),
3443                   # torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default
3444                   DecorateInfo(
3445                       unittest.expectedFailure, 'TestEagerFusionModuleInfo',
3446                       'test_aot_autograd_module_exhaustive',
3447                       active_if=operator.itemgetter('training')
3448                   ))
3449               ),
3450    ModuleInfo(torch.nn.BatchNorm2d,
3451               train_and_eval_differ=True,
3452               module_inputs_func=module_inputs_torch_nn_BatchNorm2d,
3453               skips=(
3454                   # See https://github.com/pytorch/pytorch/issues/134580
3455                   DecorateInfo(expectedFailureMPS, 'TestModule', 'test_memory_format', active_if=operator.itemgetter('training')),
3456                   # tracking here rather than in the list in test_aotdispatch.py as eval mode passes
3457                   # RuntimeError: tried to get Double out of SymInt
3458                   DecorateInfo(
3459                       unittest.expectedFailure, 'TestEagerFusionModuleInfo',
3460                       'test_aot_autograd_symbolic_module_exhaustive',
3461                       active_if=operator.itemgetter('training')
3462                   ),
3463                   # torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default
3464                   DecorateInfo(
3465                       unittest.expectedFailure, 'TestEagerFusionModuleInfo',
3466                       'test_aot_autograd_module_exhaustive',
3467                       active_if=operator.itemgetter('training')
3468                   ),)
3469               ),
3470    ModuleInfo(torch.nn.BatchNorm3d,
3471               train_and_eval_differ=True,
3472               module_inputs_func=module_inputs_torch_nn_BatchNorm3d,
3473               skips=(
3474                   # not supported on MPS backend
3475                   DecorateInfo(skipMPS),
3476                   # tracking here rather than in the list in test_aotdispatch.py as eval mode passes
3477                   # RuntimeError: tried to get Double out of SymInt
3478                   DecorateInfo(
3479                       unittest.expectedFailure, 'TestEagerFusionModuleInfo',
3480                       'test_aot_autograd_symbolic_module_exhaustive',
3481                       active_if=operator.itemgetter('training')
3482                   ),
3483                   # torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default
3484                   DecorateInfo(
3485                       unittest.expectedFailure, 'TestEagerFusionModuleInfo',
3486                       'test_aot_autograd_module_exhaustive',
3487                       active_if=operator.itemgetter('training')
3488                   ),)
3489               ),
3490    ModuleInfo(torch.nn.CELU,
3491               module_inputs_func=module_inputs_torch_nn_CELU,
3492               # not MPS specific, will be xfailed for all devices in next PR
3493               skips=(
3494                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_check_inplace',
3495                                device_type='mps', dtypes=[torch.float16]),)
3496               ),
3497    ModuleInfo(torch.nn.Conv1d,
3498               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=False),
3499               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
3500               module_memformat_affects_out=True,
3501               skips=(
3502                   # channels_last support on cuda requires cudnn >= 7603
3503                   DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
3504                   # Failure on ROCM for float32 issue #70125
3505                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
3506                   # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32'
3507                   # xfail does not work due to Fatal Python error: Aborted
3508                   DecorateInfo(skipIfMps, "TestModule", "test_memory_format",
3509                                device_type='mps', dtypes=[torch.float16]),
3510                   DecorateInfo(skipIfMps, "TestModule", "test_non_contiguous_tensors",
3511                                device_type='mps', dtypes=[torch.float16]),
3512               ),
3513               decorators=(
3514                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
3515               )),
3516    ModuleInfo(torch.nn.Conv2d,
3517               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=False),
3518               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
3519               module_memformat_affects_out=True,
3520               skips=(
3521                   # channels_last support on cuda requires cudnn >= 7603
3522                   DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
3523                   # Failure on ROCM for float32 issue #70125
3524                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
3525                   # This was wrongly being skipped before and needs investigation.
3526                   # See https://github.com/pytorch/pytorch/issues/80247
3527                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
3528                                device_type='cuda', dtypes=[torch.float64]),
3529                   # Fails with channels last test on MPS backend
3530                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
3531                                device_type='mps', dtypes=[torch.float32]),
3532                   # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32'
3533                   # xfail does not work due to Fatal Python error: Aborted
3534                   DecorateInfo(skipIfMps, "TestModule", "test_memory_format",
3535                                device_type='mps', dtypes=[torch.float16]),
3536                   DecorateInfo(skipIfMps, "TestModule", "test_non_contiguous_tensors",
3537                                device_type='mps', dtypes=[torch.float16]),
3538               ),
3539               decorators=(
3540                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
3541               )),
3542    ModuleInfo(torch.nn.Conv3d,
3543               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=False),
3544               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
3545               module_memformat_affects_out=True,
3546               skips=(
3547                   # channels_last support on cuda requires cudnn >= 8005
3548                   DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=8005), 'TestModule', 'test_memory_format'),
3549                   # Failure on ROCM for float32 issue #70125
3550                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
3551                   # Conv3d is not supported on MPS backend
3552                   DecorateInfo(skipMPS),
3553                   # This was wrongly being skipped before and needs investigation.
3554                   # See https://github.com/pytorch/pytorch/issues/80247
3555                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
3556               ),
3557               decorators=(
3558                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
3559               )),
3560    ModuleInfo(torch.nn.ConvTranspose1d,
3561               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=False, transposed=True),
3562               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
3563               module_memformat_affects_out=True,
3564               dtypes=floating_and_complex_types_and(torch.chalf),
3565               skips=(
3566                   # channels_last support on cuda requires cudnn >= 7603
3567                   DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
3568                   # Failure on ROCM for float32 issue #70125
3569                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
3570                   # Not implmented for chalf on CPU
3571                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity',
3572                                dtypes=(torch.chalf,), device_type='cuda'),
3573                   # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32'
3574                   # xfail does not work due to Fatal Python error: Aborted
3575                   DecorateInfo(skipIfMps, "TestModule", "test_memory_format",
3576                                device_type='mps', dtypes=[torch.float16]),
3577                   DecorateInfo(skipIfMps, "TestModule", "test_non_contiguous_tensors",
3578                                device_type='mps', dtypes=[torch.float16]),),
3579               decorators=(
3580                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
3581                   DecorateInfo(precisionOverride({torch.chalf: 5e-03}), 'TestModule', 'test_memory_format'),
3582               )),
3583    ModuleInfo(torch.nn.ConvTranspose2d,
3584               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=False, transposed=True),
3585               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
3586               module_memformat_affects_out=True,
3587               dtypes=floating_and_complex_types_and(torch.chalf),
3588               skips=(
3589                   # channels_last support on cuda requires cudnn >= 7603
3590                   DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
3591                   # Failure on ROCM for float32 issue #70125
3592                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
3593                   # Fails on backward check because ViewAsRealBackward apply contiguous for grad
3594                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_memory_format',
3595                                dtypes=(torch.complex32, torch.complex64, torch.complex128)),
3596                   # This was wrongly being skipped before and needs investigation.
3597                   # See https://github.com/pytorch/pytorch/issues/80247
3598                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
3599                                dtypes=[torch.float64, torch.complex128]),
3600                   # Fails with channels last test on MPS backend
3601                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
3602                                device_type='mps', dtypes=[torch.float32]),
3603                   # Not implemented for chalf on CPU
3604                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity',
3605                                dtypes=(torch.chalf,), device_type='cuda'),
3606                   # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32'
3607                   # xfail does not work due to Fatal Python error: Aborted
3608                   DecorateInfo(skipIfMps, "TestModule", "test_memory_format",
3609                                device_type='mps', dtypes=[torch.float16]),
3610                   DecorateInfo(skipIfMps, "TestModule", "test_non_contiguous_tensors",
3611                                device_type='mps', dtypes=[torch.float16]),
3612               ),
3613               decorators=(
3614                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
3615                   DecorateInfo(precisionOverride({torch.chalf: 5e-03}), 'TestModule', 'test_memory_format'),
3616               )),
3617    ModuleInfo(torch.nn.ConvTranspose3d,
3618               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=False, transposed=True),
3619               dtypes=floating_and_complex_types_and(torch.chalf),
3620               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
3621               module_memformat_affects_out=True,
3622               skips=(
3623                   # channels_last support on cuda requires cudnn >= 8005
3624                   DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=8005), 'TestModule', 'test_memory_format'),
3625                   # Failure on ROCM for float32 issue #70125
3626                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
3627                   # ConvTranspose3d is not supported on MPS backend
3628                   DecorateInfo(skipMPS),
3629                   # This was wrongly being skipped before and needs investigation.
3630                   # See https://github.com/pytorch/pytorch/issues/80247
3631                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
3632                   # These fail only on ROCm
3633                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
3634                                dtypes=[torch.complex32, torch.complex64], active_if=TEST_WITH_ROCM),
3635                   # Not implmented for chalf on CPU
3636                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity',
3637                                dtypes=(torch.chalf,), device_type='cuda'),
3638               ),
3639               decorators=(
3640                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
3641                   DecorateInfo(precisionOverride({torch.complex64: 1e-04}), 'TestModule', 'test_cpu_gpu_parity'),
3642                   DecorateInfo(precisionOverride({torch.chalf: 5e-03}), 'TestModule', 'test_memory_format'),
3643               )),
3644    ModuleInfo(torch.nn.CosineEmbeddingLoss,
3645               module_inputs_func=module_inputs_torch_nn_CosineEmbeddingLoss,
3646               skips=(
3647                   # No channels_last support for loss functions.
3648                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
3649               ),
3650    ModuleInfo(torch.nn.ELU,
3651               module_inputs_func=module_inputs_torch_nn_ELU,
3652               # not MPS specific, will be xfailed for all devices in next PR
3653               skips=(
3654                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_check_inplace',
3655                                device_type='mps', dtypes=[torch.float16]),)
3656               ),
3657    ModuleInfo(torch.nn.FractionalMaxPool2d,
3658               module_inputs_func=module_inputs_torch_nn_FractionalMaxPool2d,
3659               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
3660               skips=(
3661                   # not supported on MPS backend
3662                   DecorateInfo(skipMPS),
3663                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
3664               ),
3665    ModuleInfo(torch.nn.FractionalMaxPool3d,
3666               module_inputs_func=module_inputs_torch_nn_FractionalMaxPool3d,
3667               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
3668               skips=(
3669                   # not supported on MPS backend
3670                   DecorateInfo(skipMPS),
3671                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
3672               ),
3673    ModuleInfo(torch.nn.L1Loss,
3674               module_inputs_func=module_inputs_torch_nn_L1Loss,
3675               skips=(
3676                   # No channels_last support for loss functions.
3677                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
3678               ),
3679    ModuleInfo(torch.nn.SmoothL1Loss,
3680               module_inputs_func=module_inputs_torch_nn_SmoothL1Loss,
3681               skips=(
3682                   # No channels_last support for loss functions.
3683                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
3684                   # See #119108: input types 'tensor<f32>' and 'tensor<15x10xf16>' are not broadcast compatible
3685                   DecorateInfo(skipIfMps, 'TestModule', 'test_non_contiguous_tensors', dtypes=[torch.float16]),)
3686               ),
3687    ModuleInfo(torch.nn.LazyConv1d,
3688               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=True),
3689               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
3690               module_memformat_affects_out=True,
3691               skips=(
3692                   # channels_last support on cuda requires cudnn >= 7603
3693                   DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
3694                   # Failure on ROCM for float32 issue #70125
3695                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
3696                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
3697                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
3698                   DecorateInfo(skipMeta),
3699                   # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32'
3700                   # xfail does not work due to Fatal Python error: Aborted
3701                   DecorateInfo(skipIfMps, "TestModule", "test_memory_format",
3702                                device_type='mps', dtypes=[torch.float16]),
3703                   DecorateInfo(skipIfMps, "TestModule", "test_non_contiguous_tensors",
3704                                device_type='mps', dtypes=[torch.float16]),
3705               ),
3706               decorators=(
3707                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
3708               )),
3709    ModuleInfo(torch.nn.LazyConv2d,
3710               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=True),
3711               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
3712               module_memformat_affects_out=True,
3713               skips=(
3714                   # channels_last support on cuda requires cudnn >= 7603
3715                   DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
3716                   # Failure on ROCM for float32 issue #70125
3717                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
3718                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
3719                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
3720                   DecorateInfo(skipMeta),
3721                   # This was wrongly being skipped before and needs investigation.
3722                   # See https://github.com/pytorch/pytorch/issues/80247
3723                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
3724                                device_type='cuda', dtypes=[torch.float64]),
3725                   # Fails with channels last test on MPS backend
3726                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
3727                                device_type='mps', dtypes=[torch.float32]),
3728                   # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32'
3729                   # xfail does not work due to Fatal Python error: Aborted
3730                   DecorateInfo(skipIfMps, "TestModule", "test_memory_format",
3731                                device_type='mps', dtypes=[torch.float16]),
3732                   DecorateInfo(skipIfMps, "TestModule", "test_non_contiguous_tensors",
3733                                device_type='mps', dtypes=[torch.float16]),
3734               ),
3735               decorators=(
3736                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
3737               )),
3738    ModuleInfo(torch.nn.LazyConv3d,
3739               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=True),
3740               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
3741               module_memformat_affects_out=True,
3742               skips=(
3743                   # channels_last support on cuda requires cudnn >= 8005
3744                   DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=8005), 'TestModule', 'test_memory_format'),
3745                   # Failure on ROCM for float32 issue #70125
3746                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
3747                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
3748                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
3749                   DecorateInfo(skipMeta),
3750                   # LazyConv3d is not supported on MPS backend
3751                   DecorateInfo(skipMPS),
3752                   # This was wrongly being skipped before and needs investigation.
3753                   # See https://github.com/pytorch/pytorch/issues/80247
3754                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
3755               ),
3756               decorators=(
3757                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
3758               )),
3759    ModuleInfo(torch.nn.LazyConvTranspose1d,
3760               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=1, lazy=True, transposed=True),
3761               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
3762               module_memformat_affects_out=True,
3763               skips=(
3764                   # channels_last support on cuda requires cudnn >= 7603
3765                   DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
3766                   # Failure on ROCM for float32 issue #70125
3767                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
3768                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
3769                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
3770                   DecorateInfo(skipMeta),
3771                   # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32'
3772                   # xfail does not work due to Fatal Python error: Aborted
3773                   DecorateInfo(skipIfMps, "TestModule", "test_memory_format",
3774                                device_type='mps', dtypes=[torch.float16]),
3775                   DecorateInfo(skipIfMps, "TestModule", "test_non_contiguous_tensors",
3776                                device_type='mps', dtypes=[torch.float16]),
3777               ),
3778               decorators=(
3779                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
3780               )),
3781    ModuleInfo(torch.nn.LazyConvTranspose2d,
3782               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=2, lazy=True, transposed=True),
3783               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
3784               module_memformat_affects_out=True,
3785               skips=(
3786                   # channels_last support on cuda requires cudnn >= 7603
3787                   DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
3788                   # Failure on ROCM for float32 issue #70125
3789                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
3790                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
3791                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
3792                   DecorateInfo(skipMeta),
3793                   # This was wrongly being skipped before and needs investigation.
3794                   # See https://github.com/pytorch/pytorch/issues/80247
3795                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
3796                                dtypes=[torch.float64]),
3797                   # Fails with channels last test on MPS backend
3798                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
3799                                device_type='mps', dtypes=[torch.float32]),
3800                   # See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32'
3801                   # xfail does not work due to Fatal Python error: Aborted
3802                   DecorateInfo(skipIfMps, "TestModule", "test_memory_format",
3803                                device_type='mps', dtypes=[torch.float16]),
3804                   DecorateInfo(skipIfMps, "TestModule", "test_non_contiguous_tensors",
3805                                device_type='mps', dtypes=[torch.float16]),
3806               ),
3807               decorators=(
3808                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
3809               )),
3810    ModuleInfo(torch.nn.LazyConvTranspose3d,
3811               module_inputs_func=partial(module_inputs_torch_nn_ConvNd, N=3, lazy=True, transposed=True),
3812               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
3813               module_memformat_affects_out=True,
3814               skips=(
3815                   # channels_last support on cuda requires cudnn >= 8005
3816                   DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=8005), 'TestModule', 'test_memory_format'),
3817                   # Failure on ROCM for float32 issue #70125
3818                   DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
3819                   # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
3820                   # See https://github.com/pytorch/pytorch/issues/70505 for more info.
3821                   DecorateInfo(skipMeta),
3822                   # LazyConvTranspose3d is not supported on MPS backend
3823                   DecorateInfo(skipMPS),
3824                   # This was wrongly being skipped before and needs investigation.
3825                   # See https://github.com/pytorch/pytorch/issues/80247
3826                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
3827               ),
3828               decorators=(
3829                   DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
3830               )),
3831    ModuleInfo(torch.nn.Linear,
3832               module_inputs_func=module_inputs_torch_nn_Linear,
3833               skips=(
3834                   # No channels_last support for Linear currently.
3835                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
3836               ),
3837    ModuleInfo(torch.nn.Bilinear,
3838               module_inputs_func=module_inputs_torch_nn_Bilinear,
3839               decorators=[
3840                   DecorateInfo(
3841                       toleranceOverride({
3842                           torch.float32: tol(atol=1e-4, rtol=1e-4),
3843                           torch.float64: tol(atol=1e-4, rtol=1e-4)}),
3844                       'TestModule', 'test_forward', device_type='cpu'),
3845               ],
3846               skips=(
3847                   # No channels_last support for Bilinear currently.
3848                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
3849                   # See #119108: tolerance issue
3850                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward",
3851                                device_type='mps', dtypes=[torch.float16]),)
3852               ),
3853    ModuleInfo(torch.nn.LPPool1d,
3854               module_inputs_func=module_inputs_torch_nn_LPPool1d,
3855               skips=(
3856                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'),
3857                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),)
3858               ),
3859    ModuleInfo(torch.nn.LPPool2d,
3860               module_inputs_func=module_inputs_torch_nn_LPPool2d,
3861               skips=(
3862                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'),
3863                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),
3864                   # Fails on backward check on MPS
3865                   # See https://github.com/pytorch/pytorch/issues/107214
3866                   DecorateInfo(
3867                       unittest.expectedFailure,
3868                       'TestModule',
3869                       'test_memory_format',
3870                       active_if=operator.itemgetter('training'),
3871                       device_type='mps',
3872                   ),)
3873               ),
3874    ModuleInfo(torch.nn.LPPool3d,
3875               module_inputs_func=module_inputs_torch_nn_LPPool3d,
3876               skips=(
3877                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'),
3878                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),
3879                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
3880                   DecorateInfo(skipIfMps),)
3881               ),
3882    ModuleInfo(torch.nn.MaxPool1d,
3883               module_inputs_func=module_inputs_torch_nn_MaxPool1d,
3884               ),
3885    ModuleInfo(torch.nn.MaxPool2d,
3886               module_inputs_func=module_inputs_torch_nn_MaxPool2d,
3887               ),
3888    ModuleInfo(torch.nn.MaxPool3d,
3889               module_inputs_func=module_inputs_torch_nn_MaxPool3d,
3890               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
3891               skips=(
3892                   # not supported on MPS backend
3893                   DecorateInfo(skipMPS),)
3894               ),
3895    ModuleInfo(torch.nn.KLDivLoss,
3896               module_inputs_func=module_inputs_torch_nn_KLDivLoss,
3897               skips=(
3898                   # No channels_last support for loss functions.
3899                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
3900                   # https://github.com/pytorch/pytorch/issues/115588
3901                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_cpu_gpu_parity'),
3902                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'),
3903                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),)
3904               ),
3905    ModuleInfo(torch.nn.MSELoss,
3906               module_inputs_func=module_inputs_torch_nn_MSELoss,
3907               skips=(
3908                   # No channels_last support for loss functions.
3909                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
3910                   # See #119108: input types 'tensor<f32>' and 'tensor<15x10xf16>' are not broadcast compatible
3911                   DecorateInfo(skipIfMps, 'TestModule', 'test_non_contiguous_tensors', dtypes=[torch.float16]),
3912                   # See #119108: tolerance issue
3913                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward",
3914                                device_type='mps', dtypes=[torch.float16]),)
3915               ),
3916    ModuleInfo(torch.nn.MarginRankingLoss,
3917               module_inputs_func=module_inputs_torch_nn_MarginRankingLoss,
3918               skips=(
3919                   # No channels_last support for loss functions.
3920                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
3921               ),
3922    ModuleInfo(torch.nn.MultiLabelMarginLoss,
3923               module_inputs_func=module_inputs_torch_nn_MultiLabelMarginLoss,
3924               skips=(
3925                   # No channels_last support for loss functions.
3926                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
3927                   # 'aten::multilabel_margin_loss_forward' is not currently implemented for the MPS device.
3928                   DecorateInfo(skipIfMps, 'TestModule'),
3929                   # derivative for aten::multilabel_margin_loss_backward is not implemented
3930                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),)
3931               ),
3932    ModuleInfo(torch.nn.MultiMarginLoss,
3933               module_inputs_func=module_inputs_torch_nn_MultiMarginLoss,
3934               skips=(
3935                   # No channels_last support for loss functions.
3936                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
3937                   # 'aten::multi_margin_loss' is not currently implemented for the MPS device.
3938                   DecorateInfo(skipIfMps, 'TestModule'),
3939                   # RuntimeError: derivative for aten::multi_margin_loss_backward is not implemented
3940                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),)
3941               ),
3942    ModuleInfo(torch.nn.SoftMarginLoss,
3943               module_inputs_func=module_inputs_torch_nn_SoftMarginLoss,
3944               skips=(
3945                   # No channels_last support for loss functions.
3946                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
3947                   # See #119108: tolerance issue
3948                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward",
3949                                device_type='mps', dtypes=[torch.float16]),)
3950               ),
3951    ModuleInfo(torch.nn.MultiLabelSoftMarginLoss,
3952               module_inputs_func=module_inputs_torch_nn_MultiLabelSoftMarginLoss,
3953               skips=(
3954                   # No channels_last support for loss functions.
3955                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
3956               ),
3957    ModuleInfo(torch.nn.NLLLoss,
3958               module_inputs_func=module_inputs_torch_nn_NLLLoss,
3959               skips=(
3960                   # No channels_last support for loss functions.
3961                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
3962                   # See #119108: tolerance issue
3963                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward",
3964                                device_type='mps', dtypes=[torch.float16]),)
3965               ),
3966    ModuleInfo(torch.nn.GaussianNLLLoss,
3967               module_inputs_func=module_inputs_torch_nn_GaussianNLLLoss,
3968               skips=(
3969                   # No channels_last support for loss functions.
3970                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)),
3971    ModuleInfo(torch.nn.PoissonNLLLoss,
3972               module_inputs_func=module_inputs_torch_nn_PoissonNLLLoss,
3973               skips=(
3974                   # No channels_last support for loss functions.
3975                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)),
3976    ModuleInfo(torch.nn.HingeEmbeddingLoss,
3977               module_inputs_func=module_inputs_torch_nn_HingeEmbeddingLoss,
3978               skips=(
3979                   # No channels_last support for loss functions.
3980                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
3981               ),
3982    ModuleInfo(torch.nn.HuberLoss,
3983               module_inputs_func=module_inputs_torch_nn_HuberLoss,
3984               skips=(
3985                   # No channels_last support for loss functions.
3986                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
3987                   # See #119108: seemingly incorrect output dtype
3988                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward",
3989                                device_type='mps', dtypes=[torch.float16]),)
3990               ),
3991    ModuleInfo(torch.nn.BCELoss,
3992               module_inputs_func=module_inputs_torch_nn_BCELoss,
3993               skips=(
3994                   # No channels_last support for loss functions.
3995                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
3996                   # error: input types 'tensor<f32>' and 'tensor<15x10xf16>' are not broadcast compatible
3997                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float16]),)
3998               ),
3999    ModuleInfo(torch.nn.BCEWithLogitsLoss,
4000               module_inputs_func=module_inputs_torch_nn_BCEWithLogitsLoss,
4001               skips=(
4002                   # No channels_last support for loss functions.
4003                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
4004                   # see #119108: tolerance issue
4005                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float16]),)
4006               ),
4007    ModuleInfo(torch.nn.CrossEntropyLoss,
4008               module_inputs_func=module_inputs_torch_nn_CrossEntropyLoss,
4009               dtypes=get_all_fp_dtypes(include_half=True, include_bfloat16=False),
4010               decorators=(
4011                   # No channels_last support for loss functions.
4012                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_memory_format'),
4013                   DecorateInfo(toleranceOverride({torch.float16: tol(atol=3e-2, rtol=1e-3)}), "TestModule",
4014                                "test_forward", dtypes=[torch.float16], device_type='cpu'),
4015                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_cpu_gpu_parity", dtypes=[torch.float16],
4016                                device_type='cuda'),),
4017               ),
4018    ModuleInfo(torch.nn.CTCLoss,
4019               module_inputs_func=module_inputs_torch_nn_CTCLoss,
4020               skips=(
4021                   # No channels_last support for loss functions.
4022                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
4023                   # The operator aten::_ctc_loss is not currently implemented for the MPS device.
4024                   DecorateInfo(skipIfMps, 'TestModule'),
4025                   # derivative for aten::_ctc_loss_backward is not implemented
4026                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_grad'),
4027                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad'),
4028                   # https://github.com/pytorch/pytorch/issues/115585
4029                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_non_contiguous_tensors'),)
4030               ),
4031    ModuleInfo(torch.nn.GELU,
4032               module_inputs_func=module_inputs_torch_nn_GELU,
4033               skips=(
4034                   # See #119108: tolerance issue
4035                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward",
4036                                device_type='mps', dtypes=[torch.float16]),)
4037               ),
4038    ModuleInfo(torch.nn.GLU,
4039               module_inputs_func=module_inputs_torch_nn_GLU,
4040               ),
4041    ModuleInfo(torch.nn.GroupNorm,
4042               module_inputs_func=module_inputs_torch_nn_GroupNorm,
4043               dtypes=get_all_fp_dtypes(include_bfloat16=True, include_half=True),
4044               skips=(
4045                   # Tracking at https://github.com/pytorch/pytorch/issues/98089
4046                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_cpu_gpu_parity'),
4047                   DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}),
4048                                'TestModule', 'test_memory_format', device_type='cpu'),
4049                   # No channels_last support for GroupNorm currently.
4050                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', device_type='cuda'),
4051                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format', device_type='mps'),
4052                   DecorateInfo(unittest.skip("Skipped!"), "TestModule", "test_grad",
4053                                active_if=TEST_WITH_ROCM, device_type='cuda'),)
4054               ),
4055    ModuleInfo(torch.nn.Hardshrink,
4056               module_inputs_func=module_inputs_torch_nn_Hardshrink,
4057               skips=(
4058                   # not supported on MPS backend
4059                   DecorateInfo(skipMPS),),
4060               ),
4061    ModuleInfo(torch.nn.Hardswish,
4062               module_inputs_func=module_inputs_torch_nn_Hardswish,
4063               skips=None if _macos15_or_newer else (
4064                   # Fails on backward check on MPS
4065                   # See https://github.com/pytorch/pytorch/issues/107214
4066                   DecorateInfo(
4067                       unittest.expectedFailure,
4068                       'TestModule',
4069                       'test_memory_format',
4070                       active_if=operator.itemgetter('training'),
4071                       device_type='mps',
4072                   ),),
4073               supports_gradgrad=False),
4074    ModuleInfo(torch.nn.Hardtanh,
4075               module_inputs_func=module_inputs_torch_nn_Hardtanh,
4076               ),
4077    ModuleInfo(torch.nn.InstanceNorm1d,
4078               module_inputs_func=partial(module_inputs_torch_nn_InstanceNormNd, N=1),
4079               train_and_eval_differ=True,
4080               skips=(
4081                   # No channels_last support for InstanceNorm1d currently.
4082                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
4083               ),
4084    ModuleInfo(torch.nn.InstanceNorm2d,
4085               module_inputs_func=partial(module_inputs_torch_nn_InstanceNormNd, N=2),
4086               train_and_eval_differ=True,
4087               skips=(
4088                   # No channels_last support for InstanceNorm2d currently.
4089                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
4090               ),
4091    ModuleInfo(torch.nn.InstanceNorm3d,
4092               module_inputs_func=partial(module_inputs_torch_nn_InstanceNormNd, N=3),
4093               train_and_eval_differ=True,
4094               skips=(
4095                   # not supported on MPS backend
4096                   DecorateInfo(skipMPS),
4097                   # No channels_last support for InstanceNorm3d currently.
4098                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
4099               ),
4100    ModuleInfo(torch.nn.LocalResponseNorm,
4101               module_inputs_func=module_inputs_torch_nn_LocalResponseNorm,
4102               skips=(
4103                   # uses avg_pool3d which is not supported on MPS backend
4104                   DecorateInfo(skipMPS),)
4105               ),
4106    ModuleInfo(torch.nn.LayerNorm,
4107               module_inputs_func=module_inputs_torch_nn_LayerNorm,
4108               skips=(
4109                   # No channels_last support for LayerNorm currently.
4110                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
4111               ),
4112    ModuleInfo(torch.nn.RMSNorm,
4113               module_inputs_func=module_inputs_torch_nn_RMSNorm,
4114               ),
4115    # TransformerEncoder takes the same inputs as TransformerEncoderLayer
4116    ModuleInfo(torch.nn.TransformerEncoder,
4117               train_and_eval_differ=True,
4118               module_inputs_func=module_inputs_torch_nn_TransformerEncoder,
4119               decorators=[
4120                   # Not implemented for SDPA backward derivative
4121                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad',
4122                                device_type='cpu'),
4123               ],
4124               skips=(
4125                   # No channels_last support for TransformerEncoderLayer currently.
4126                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
4127                   # Doesn't support device / dtype kwargs directly because it is just a
4128                   # container of TransformerEncoderLayers.
4129                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_factory_kwargs'),)
4130               ),
4131    ModuleInfo(torch.nn.TransformerEncoderLayer,
4132               train_and_eval_differ=True,
4133               module_inputs_func=module_inputs_torch_nn_TransformerEncoderLayer,
4134               decorators=[
4135                   DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}),
4136                                'TestModule', 'test_non_contiguous_tensors',
4137                                device_type='cpu', active_if=IS_WINDOWS),
4138                   # Not implemented for SDPA backward derivative
4139                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad',
4140                                device_type='cpu'),
4141               ],
4142               skips=(
4143                   # No channels_last support for TransformerEncoderLayer currently.
4144                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
4145               ),
4146    ModuleInfo(torch.nn.TransformerDecoderLayer,
4147               module_inputs_func=module_inputs_torch_nn_TransformerDecoderLayer,
4148               decorators=[
4149                   # Not implemented for SDPA backward derivative
4150                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad',
4151                                device_type='cpu'),
4152               ],
4153               skips=(
4154                   # No channels_last support for TransformerDecoderLayer currently.
4155                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
4156               ),
4157    ModuleInfo(torch.nn.Transformer,
4158               module_inputs_func=module_inputs_torch_nn_Transformer,
4159               decorators=[
4160                   # Not implemented for SDPA backward derivative
4161                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad',
4162                                device_type='cpu'),
4163               ],
4164               skips=(
4165                   # No channels_last support for Transformer currently.
4166                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
4167               ),
4168    ModuleInfo(torch.nn.MultiheadAttention,
4169               train_and_eval_differ=True,
4170               module_inputs_func=module_inputs_torch_nn_MultiheadAttention,
4171               skips=(
4172                   # No channels_last support for MultiheadAttention currently.
4173                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
4174               ),
4175    ModuleInfo(torch.nn.Embedding,
4176               module_inputs_func=module_inputs_torch_nn_Embedding,
4177               decorators=[
4178                   DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}),
4179                                'TestModule', 'test_non_contiguous_tensors',
4180                                device_type='mps')],
4181               skips=(
4182                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
4183               ),
4184    ModuleInfo(torch.nn.ReLU,
4185               module_inputs_func=module_inputs_torch_nn_ReLU,
4186               skips=None if _macos15_or_newer else (
4187                   # Fails on backward check on MPS
4188                   # See https://github.com/pytorch/pytorch/issues/107214
4189                   DecorateInfo(
4190                       unittest.expectedFailure,
4191                       'TestModule',
4192                       'test_memory_format',
4193                       active_if=operator.itemgetter('training'),
4194                       device_type='mps',
4195                   ),)
4196               ),
4197    ModuleInfo(torch.nn.LeakyReLU,
4198               module_inputs_func=module_inputs_torch_nn_LeakyReLU,
4199               ),
4200    ModuleInfo(torch.nn.ReLU6,
4201               module_inputs_func=module_inputs_torch_nn_ReLU6,
4202               skips=(
4203                   # test fails on MPS backend and is being investigated.
4204                   # See https://github.com/pytorch/pytorch/issues/100914
4205                   DecorateInfo(skipMPS),)
4206               ),
4207    ModuleInfo(torch.nn.PReLU,
4208               module_inputs_func=module_inputs_torch_nn_PReLU,
4209               skips=(
4210                   # test fails on MPS backend and is being investigated.
4211                   # See https://github.com/pytorch/pytorch/issues/100914
4212                   DecorateInfo(skipMPS),)
4213               ),
4214    ModuleInfo(torch.nn.RNNCell,
4215               module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU_Cell, is_rnn=True),
4216               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU_Cell,
4217               ),
4218    ModuleInfo(torch.nn.GRUCell,
4219               module_inputs_func=module_inputs_torch_nn_RNN_GRU_Cell,
4220               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU_Cell,
4221               ),
4222    ModuleInfo(torch.nn.LSTMCell,
4223               module_inputs_func=module_inputs_torch_nn_LSTMCell,
4224               module_error_inputs_func=module_error_inputs_torch_nn_LSTMCell,
4225               ),
4226    ModuleInfo(torch.nn.Sigmoid,
4227               module_inputs_func=module_inputs_torch_nn_Sigmoid,
4228               skips=None if _macos15_or_newer else (
4229                   # Fails on backward check on MPS
4230                   # See https://github.com/pytorch/pytorch/issues/107214
4231                   DecorateInfo(
4232                       unittest.expectedFailure,
4233                       'TestModule',
4234                       'test_memory_format',
4235                       active_if=operator.itemgetter('training'),
4236                       device_type='mps',
4237                   ),)
4238               ),
4239    ModuleInfo(torch.nn.LogSigmoid,
4240               module_inputs_func=module_inputs_torch_nn_LogSigmoid,
4241               skips=(
4242                   # See #119108: tolerance issue
4243                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", device_type='mps', dtypes=[torch.float16]),)
4244               ),
4245    ModuleInfo(torch.nn.SiLU,
4246               module_inputs_func=module_inputs_torch_nn_SiLU,
4247               ),
4248    ModuleInfo(torch.nn.Softmax,
4249               module_inputs_func=module_inputs_torch_nn_Softmax,
4250               ),
4251    ModuleInfo(torch.nn.Softmax2d,
4252               module_inputs_func=module_inputs_torch_nn_Softmax2d,
4253               skips=(
4254                   # no channels last support for Softmax2d currently
4255                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
4256                   # See #119108: tolerance issue
4257                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", device_type='mps', dtypes=[torch.float16]),)
4258               ),
4259    ModuleInfo(torch.nn.LogSoftmax,
4260               module_inputs_func=module_inputs_torch_nn_LogSoftmax,
4261               skips=(
4262                   # no channels last support for LogSoftmax currently
4263                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
4264                   # See #119108: inf nan error
4265                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_forward", device_type='mps', dtypes=[torch.float16]),)
4266               ),
4267    ModuleInfo(torch.nn.Softmin,
4268               module_inputs_func=module_inputs_torch_nn_Softmin,
4269               skips=(
4270                   # no channels last support for Softmin currently
4271                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
4272               ),
4273    ModuleInfo(torch.nn.Softplus,
4274               module_inputs_func=module_inputs_torch_nn_Softplus,
4275               skips=(
4276                   # test fails on MPS backend and is being investigated.
4277                   # See https://github.com/pytorch/pytorch/issues/100914
4278                   DecorateInfo(skipMPS),)
4279               ),
4280    ModuleInfo(torch.nn.Softshrink,
4281               module_inputs_func=module_inputs_torch_nn_Softshrink,
4282               skips=(
4283                   # not supported on MPS backend
4284                   DecorateInfo(skipMPS),)
4285               ),
4286    ModuleInfo(torch.nn.Softsign,
4287               module_inputs_func=module_inputs_torch_nn_Softsign,
4288               ),
4289    ModuleInfo(torch.nn.Tanh,
4290               module_inputs_func=module_inputs_torch_nn_Tanh,
4291               skips=None if _macos15_or_newer else (
4292                   # Fails on backward check on MPS
4293                   # See https://github.com/pytorch/pytorch/issues/107214
4294                   DecorateInfo(
4295                       unittest.expectedFailure,
4296                       'TestModule',
4297                       'test_memory_format',
4298                       active_if=operator.itemgetter('training'),
4299                       device_type='mps',
4300                   ),)
4301               ),
4302    ModuleInfo(torch.nn.Tanhshrink,
4303               module_inputs_func=module_inputs_torch_nn_Tanhshrink,
4304               skips=None if _macos15_or_newer else (
4305                   # Fails on backward check on MPS
4306                   # See https://github.com/pytorch/pytorch/issues/107214
4307                   DecorateInfo(
4308                       unittest.expectedFailure,
4309                       'TestModule',
4310                       'test_memory_format',
4311                       active_if=operator.itemgetter('training'),
4312                       device_type='mps',
4313                   ),)
4314               ),
4315    ModuleInfo(torch.nn.Threshold,
4316               module_inputs_func=module_inputs_torch_nn_Threshold,
4317               skips=(
4318                   # test fails on MPS backend and is being investigated.
4319                   # See https://github.com/pytorch/pytorch/issues/100914
4320                   DecorateInfo(skipMPS),)
4321               ),
4322    ModuleInfo(torch.nn.Mish,
4323               module_inputs_func=module_inputs_torch_nn_Mish,
4324               skips=(
4325                   # not supported on MPS backend
4326                   DecorateInfo(skipMPS),)
4327               ),
4328    ModuleInfo(torch.nn.RNN,
4329               train_and_eval_differ=True,
4330               module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=True),
4331               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU,
4332               decorators=rnn_gru_lstm_module_info_decorators
4333               ),
4334    ModuleInfo(torch.nn.GRU,
4335               train_and_eval_differ=True,
4336               module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=False),
4337               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU,
4338               decorators=rnn_gru_lstm_module_info_decorators),
4339    ModuleInfo(torch.nn.LSTM,
4340               train_and_eval_differ=True,
4341               module_inputs_func=module_inputs_torch_nn_LSTM,
4342               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU,
4343               skips=(
4344                   # LSTM with projections is not currently supported with MPS
4345                   DecorateInfo(skipMPS),),
4346               decorators=rnn_gru_lstm_module_info_decorators),
4347    ModuleInfo(torch.nn.ReflectionPad1d,
4348               module_inputs_func=module_inputs_torch_nn_ReflectionPad1d,
4349               ),
4350    ModuleInfo(torch.nn.ReflectionPad2d,
4351               module_inputs_func=module_inputs_torch_nn_ReflectionPad2d,
4352               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
4353               skips=(
4354                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
4355                                device_type='cuda'),
4356                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
4357                                device_type='mps'),)
4358               ),
4359    ModuleInfo(torch.nn.ReflectionPad3d,
4360               module_inputs_func=module_inputs_torch_nn_ReflectionPad3d,
4361               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
4362               skips=(
4363                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
4364                                device_type='cuda'),
4365                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
4366                                device_type='mps'),)
4367               ),
4368    ModuleInfo(torch.nn.ReplicationPad1d,
4369               module_inputs_func=module_inputs_torch_nn_ReplicationPad1d,
4370               ),
4371    ModuleInfo(torch.nn.ReplicationPad2d,
4372               module_inputs_func=module_inputs_torch_nn_ReplicationPad2d,
4373               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
4374               skips=(
4375                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
4376                                device_type='cuda'),
4377                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
4378                                device_type='mps'),)
4379               ),
4380    ModuleInfo(torch.nn.ReplicationPad3d,
4381               module_inputs_func=module_inputs_torch_nn_ReplicationPad3d,
4382               gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
4383               skips=(
4384                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
4385                                device_type='cuda'),
4386                   DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format',
4387                                device_type='mps'),)
4388               ),
4389    ModuleInfo(torch.nn.SELU,
4390               module_inputs_func=module_inputs_torch_nn_SELU,
4391               skips=(
4392                   # test fails on MPS backend and is being investigated.
4393                   # See https://github.com/pytorch/pytorch/issues/100914
4394                   DecorateInfo(skipMPS),)
4395               ),
4396    ModuleInfo(torch.nn.ZeroPad1d,
4397               module_inputs_func=module_inputs_torch_nn_ZeroPad1d,
4398               ),
4399    ModuleInfo(torch.nn.ZeroPad2d,
4400               module_inputs_func=module_inputs_torch_nn_ZeroPad2d,
4401               skips=(
4402                   # Fails with channels last test on MPS backend
4403                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
4404               ),
4405    ModuleInfo(torch.nn.ZeroPad3d,
4406               module_inputs_func=module_inputs_torch_nn_ZeroPad3d,
4407               skips=(
4408                   # Fails with channels last test on MPS backend
4409                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
4410               ),
4411    ModuleInfo(torch.nn.CircularPad1d,
4412               module_inputs_func=module_inputs_torch_nn_CircularPad1d,
4413               module_error_inputs_func=module_error_inputs_torch_nn_Pad1d,
4414               ),
4415    ModuleInfo(torch.nn.CircularPad2d,
4416               module_inputs_func=module_inputs_torch_nn_CircularPad2d,
4417               module_error_inputs_func=module_error_inputs_torch_nn_Pad2d,
4418               ),
4419    ModuleInfo(torch.nn.CircularPad3d,
4420               module_inputs_func=module_inputs_torch_nn_CircularPad3d,
4421               module_error_inputs_func=module_error_inputs_torch_nn_Pad3d,
4422               skips=(
4423                   # Fails with channels last test on MPS backend
4424                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),)
4425               ),
4426    ModuleInfo(torch.nn.ConstantPad1d,
4427               module_inputs_func=module_inputs_torch_nn_ConstantPad1d,
4428               ),
4429    ModuleInfo(torch.nn.ConstantPad2d,
4430               module_inputs_func=module_inputs_torch_nn_ConstantPad2d,
4431               skips=(
4432                   # Fails with channels last test on MPS backend
4433                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
4434               ),
4435    ModuleInfo(torch.nn.ConstantPad3d,
4436               module_inputs_func=module_inputs_torch_nn_ConstantPad3d,
4437               skips=(
4438                   # Fails with channels last test on MPS backend
4439                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
4440               )
4441]
4442