1# mypy: ignore-errors
2
3from abc import abstractmethod
4import tempfile
5import unittest
6
7from copy import deepcopy
8from functools import reduce, partial
9from itertools import product
10from operator import mul
11
12
13import torch
14import torch.cuda
15import torch.nn as nn
16import torch.nn.functional as F
17from torch.nn import _reduction as _Reduction
18from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \
19    gradcheck, gradgradcheck, set_default_dtype, skipIfTorchDynamo
20from torch.testing._internal.common_cuda import TEST_CUDA, SM90OrLater
21from torch.autograd.gradcheck import _get_numerical_jacobian, _iter_tensors
22from torch.autograd import Variable
23from torch.types import _TensorOrTensors
24import torch.backends.cudnn
25
26from typing import Dict, Callable, Tuple, List, Sequence, Union, Any
27
28TemporaryFile = tempfile.TemporaryFile
29PRECISION = 1e-5
30
31
32def get_reduction(m):
33    result = getattr(m, 'reduction', None)
34    if result is None:
35        result = _Reduction.legacy_get_string(getattr(m, 'sizeAverage', None), True, emit_warning=False)
36    assert result is not None
37    return result
38
39
40def get_weight(m):
41    result = getattr(m, 'weight', None)
42    if result is not None:
43        return result
44    return getattr(m, 'weights', None)
45
46# NOTE [How to check NN module / functional API parity between Python and C++ frontends]
47#
48# The way to check API parity is to add parity tests for the NN module / functional of interest.
49# Here are the detailed steps:
50#
51# For NN module:
52# 1. Make sure you already have a test dict with the module configuration you want to test.
53# 2. Add a `cpp_constructor_args` entry to the test dict, with its value exactly matching
54#    the Python module constructor arguments. For example, if in the test dict we pass
55#    `(10, 8)` to `torch.nn.Linear` constructor, then we should pass `torch::nn::LinearOptions(10, 8)`
56#    as the corresponding C++ constructor argument to `torch::nn::Linear`.
57# 3. If in the process of performing the above step you referenced any variables
58#    in the `cpp_constructor_args` entry, you must add a `cpp_var_map` entry
59#    to the test dict to make sure that those variables are populated with the right Python values.
60#    For example, if the Python constructor call is
61#    `torch.nn.FractionalMaxPool2d(2, output_ratio=0.5, _random_samples=random_samples)`,
62#    the corresponding C++ constructor argument is
63#    `torch::nn::FractionalMaxPool2dOptions(2).output_ratio(0.5)._random_samples(random_samples)`,
64#    and the `cpp_var_map` entry must be
65#    `{'random_samples': random_samples}` in order to populate the C++ variable `random_samples`
66#    used in the C++ constructor argument with the Python tensor value `random_samples`.
67#
68# For NN functional:
69# 1. Make sure you already have a test dict with the functional configuration you want to test.
70# 2. If the test dict's `constructor` entry looks like `wrap_functional(F.some_functional_name, ...)`,
71#    then you must add a `cpp_options_args` entry to the test dict, with its value exactly matching the Python
72#    functional optional arguments. For example, if the test dict's `constructor` entry is
73#    `wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest')`,
74#    then the `cpp_options_args` entry should be
75#    "F::InterpolateFuncOptions().size(std::vector<int64_t>({12})).scale_factor(std::nullopt).mode(torch::kNearest)".
76# 3. Otherwise, if the test dict's `constructor` entry looks like
77#    `wrap_functional(lambda i: F.some_functional_name(...))`,
78#    then you must add a `cpp_function_call` entry to the test dict, with its value exactly matching the Python
79#    functional function call. For example, if the test dict's `constructor` entry is
80#    `wrap_functional(lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none'))`,
81#    then the `cpp_function_call` entry should be
82#    "F::poisson_nll_loss(i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))".
83# 4. If in the process of performing the above two steps you referenced any variables
84#    in the `cpp_options_args` or `cpp_function_call` entry, you must
85#    add a `cpp_var_map` entry to the test dict to make sure that those variables
86#    are populated with the right Python values. For example, if the test dict's `constructor` entry is
87#    `wrap_functional(lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none'))`,
88#    then the `cpp_function_call` entry should be
89#    "F::poisson_nll_loss(i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))".
90#    Notice that there are two variables `i` and `t` that need to have their values provided,
91#    and the way to do so is to add a `cpp_var_map` entry: `cpp_var_map={'i': '_get_input()', 't': t}`.
92#    (Note that for `i`, since we want it to take the Python input value, we pass the '_get_input()' string as its value,
93#    and the C++ parity test mechanism will populate `i` with the Python input value correctly.)
94#
95# There are also a few optional flags in the test dict to control the C++ parity test behavior:
96#
97# - `test_cpp_api_parity`: if `False`, skips the C++ parity test for this test dict. Default: `True`.
98# - `has_parity`: if `False`, expects this test dict to fail the C++ parity test. Default: `True`.
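#
# As a minimal, illustrative sketch (these dicts merely restate entries that appear
# later in this file and are not registered in any test list by themselves), a module
# test dict with C++ parity information looks like:
#
#   dict(
#       module_name='Linear',
#       constructor_args=(10, 8),
#       cpp_constructor_args='torch::nn::LinearOptions(10, 8)',
#       input_size=(4, 10),
#   )
#
# and a functional test dict using `cpp_function_call` together with `cpp_var_map`:
#
#   t = torch.randn(10, 10)
#   dict(
#       fullname='PoissonNLLLoss_no_reduce',
#       constructor=wrap_functional(
#           lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none')),
#       cpp_function_call='F::poisson_nll_loss('
#                         'i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))',
#       cpp_var_map={'i': '_get_input()', 't': t},
#   )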
99
100
101module_tests = [
102    dict(
103        module_name='Linear',
104        constructor_args=(10, 8),
105        cpp_constructor_args='torch::nn::LinearOptions(10, 8)',
106        input_size=(4, 10),
107        reference_fn=lambda i, p, _: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8),
108        with_tf32=True,
109        tf32_precision=0.005,
110        default_dtype=torch.double,
111    ),
112    dict(
113        module_name='Linear',
114        constructor_args=(10, 8, False),
115        cpp_constructor_args='torch::nn::LinearOptions(10, 8).bias(false)',
116        input_size=(4, 10),
117        desc='no_bias',
118        reference_fn=lambda i, p, _: torch.mm(i, p[0].t()),
119        with_tf32=True,
120        tf32_precision=0.005,
121        default_dtype=torch.double,
122    ),
123    dict(
124        module_name='RReLU',
125        input_size=(1, 2, 2),
126        test_cuda=False,
127        default_dtype=torch.double,
128    ),
129    dict(
130        module_name='RReLU',
131        constructor_args=(0.1, 0.9),
132        cpp_constructor_args='torch::nn::RReLUOptions().lower(0.1).upper(0.9)',
133        input_size=(4, 4, 5),
134        desc='with_up_down',
135        test_cuda=False,
136        default_dtype=torch.double,
137    ),
138    dict(
139        module_name='Flatten',
140        input_size=(2, 3, 4, 5),
141        reference_fn=lambda i, *_: torch.flatten(i, 1),
142        default_dtype=torch.double,
143    ),
144    # TODO: reference function
145    dict(
146        module_name='CrossMapLRN2d',
147        constructor_args=(5, 5e-3, 1e-3, 2),
148        cpp_constructor_args='torch::nn::CrossMapLRN2dOptions(5).alpha(5e-3).beta(1e-3).k(2)',
149        input_size=(2, 3, 6, 6),
150        check_gradgrad=False,
151        # TODO(#50743): Figure out the error. "RuntimeError: Unrecognized tensor type ID: Batched"
152        check_batched_grad=False,
153        default_dtype=torch.double,
154    ),
155]
156
157
158# Generates a random tensor with all-distinct values. This ensures that duplicate
159# values won't cause test failures for modules like MaxPooling.
160# The total size should be small; otherwise randperm fails or the element count overflows a long.
161def _rand_tensor_non_equal(*size):
162    total = reduce(mul, size, 1)
163    return torch.randperm(total).view(*size).double()
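# For example, _rand_tensor_non_equal(2, 3) reshapes a random permutation of 0..5 into
# a (2, 3) double tensor, so every element is distinct.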
164
165
166def wrap_functional(fn, **kwargs):
167    class FunctionalModule(nn.Module):
168        def forward(self, *args):
169            return fn(*args, **kwargs)
170    return FunctionalModule
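# For example, wrap_functional(F.relu) returns a Module subclass whose forward just calls
# F.relu, so wrap_functional(F.relu)()(x) is equivalent to F.relu(x); the test dicts below
# rely on this to route functionals through the module-test machinery.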
171
172
173def poissonnllloss_no_reduce_test():
174    t = torch.randn(10, 10)
175    return dict(
176        fullname='PoissonNLLLoss_no_reduce',
177        constructor=wrap_functional(
178            lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none')),
179        cpp_function_call='F::poisson_nll_loss('
180                          'i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))',
181        input_fn=lambda: torch.rand(10, 10),
182        cpp_var_map={'i': '_get_input()', 't': t},
183        reference_fn=lambda i, *_: i.exp() - t.mul(i),
184        pickle=False,
185        default_dtype=torch.double)
186
187
188def bceloss_no_reduce_test():
189    t = Variable(torch.randn(15, 10).gt(0).to(torch.double))
190    return dict(
191        fullname='BCELoss_no_reduce',
192        constructor=wrap_functional(
193            lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')),
194        cpp_function_call='F::binary_cross_entropy('
195                          'i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))',
196        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
197        cpp_var_map={'i': '_get_input()', 't': t},
198        reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()),
199        pickle=False,
200        precision=7e-4,
201        default_dtype=torch.double)
202
203
204def bceloss_no_reduce_scalar_test():
205    t = torch.randn(()).gt(0).to(torch.double)
206    return dict(
207        fullname='BCELoss_no_reduce_scalar',
208        constructor=wrap_functional(
209            lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')),
210        cpp_function_call='F::binary_cross_entropy('
211                          'i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))',
212        input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
213        cpp_var_map={'i': '_get_input()', 't': t},
214        reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()),
215        pickle=False,
216        default_dtype=torch.double)
217
218
219def bceloss_weights_no_reduce_test():
220    t = Variable(torch.randn(15, 10, dtype=torch.double).gt(0).to(torch.double))
221    weights = torch.rand(10, dtype=torch.double)
222    return dict(
223        fullname='BCELoss_weights_no_reduce',
224        constructor=wrap_functional(
225            lambda i: F.binary_cross_entropy(i, t.type_as(i),
226                                             weight=weights.type_as(i), reduction='none')),
227        cpp_function_call='F::binary_cross_entropy('
228                          'i, t.to(i.options()), '
229                          'F::BinaryCrossEntropyFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))',
230        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
231        cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights},
232        reference_fn=lambda i, p, m: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
233        pickle=False,
234        precision=3e-4,
235        default_dtype=torch.double,
236    )
237
238
239def bceloss_weights_no_reduce_scalar_test():
240    t = torch.randn(()).gt(0).to(torch.double)
241    weights = torch.rand((), dtype=torch.double)
242    return dict(
243        fullname='BCELoss_weights_no_reduce_scalar',
244        constructor=wrap_functional(
245            lambda i: F.binary_cross_entropy(i, t.type_as(i),
246                                             weight=weights.type_as(i), reduction='none')),
247        cpp_function_call='''F::binary_cross_entropy(
248            i, t.to(i.options()),
249            F::BinaryCrossEntropyFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''',
250        cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights},
251        input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
252        reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
253        pickle=False,
254        default_dtype=torch.double,
255    )
256
257
258def bce_with_logistic_legacy_enum_test():
259    t = Variable(torch.randn(15, 10).gt(0).to(torch.double))
260    sigmoid = nn.Sigmoid()
261    return dict(
262        fullname='BCEWithLogitsLoss_legacy_enum',
263        constructor=wrap_functional(
264            lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduce=False)),
265        cpp_function_call='''F::binary_cross_entropy_with_logits(
266            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''',
267        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
268        cpp_var_map={'i': '_get_input()', 't': t},
269        reference_fn=lambda i, *_: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
270        check_gradgrad=False,
271        pickle=False,
272        default_dtype=torch.double,
273    )
274
275
276def bce_with_logistic_no_reduce_test():
277    t = Variable(torch.randn(15, 10).gt(0).to(torch.double))
278    sigmoid = nn.Sigmoid()
279    return dict(
280        fullname='BCEWithLogitsLoss_no_reduce',
281        constructor=wrap_functional(
282            lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction='none')),
283        cpp_function_call='''F::binary_cross_entropy_with_logits(
284            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''',
285        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
286        cpp_var_map={'i': '_get_input()', 't': t},
287        reference_fn=lambda i, *_: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
288        check_gradgrad=False,
289        pickle=False,
290        default_dtype=torch.double,
291    )
292
293
294def bce_with_logistic_no_reduce_scalar_test():
295    t = torch.randn(()).gt(0).to(torch.double)
296    sigmoid = nn.Sigmoid()
297    return dict(
298        fullname='BCEWithLogitsLoss_no_reduce_scalar',
299        constructor=wrap_functional(
300            lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction='none')),
301        cpp_function_call='''F::binary_cross_entropy_with_logits(
302            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))''',
303        input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
304        cpp_var_map={'i': '_get_input()', 't': t},
305        reference_fn=lambda i, *_: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
306        check_gradgrad=False,
307        pickle=False,
308        default_dtype=torch.double,
309    )
310
311
312def kldivloss_with_target_no_reduce_test():
313    t = torch.rand(10, 10, dtype=torch.double)
314    return dict(
315        fullname='KLDivLoss_with_target_no_reduce',
316        constructor=wrap_functional(
317            lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
318        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))',
319        input_fn=lambda: torch.rand(10, 10).log(),
320        cpp_var_map={'i': '_get_input()', 't': t},
321        reference_fn=lambda i, *_:
322            loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
323        supports_forward_ad=True,
324        pickle=False,
325        default_dtype=torch.double)
326
327
328def kldivloss_no_reduce_test():
329    t = torch.rand(10, 10, dtype=torch.double)
330    return dict(
331        fullname='KLDivLoss_no_reduce',
332        constructor=wrap_functional(
333            lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
334        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))',
335        input_fn=lambda: torch.rand(10, 10).log(),
336        cpp_var_map={'i': '_get_input()', 't': t},
337        reference_fn=lambda i, *_:
338            loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
339        supports_forward_ad=True,
340        pickle=False,
341        default_dtype=torch.double,
342    )
343
344
345def kldivloss_no_reduce_scalar_test():
346    t = torch.rand((), dtype=torch.double)
347    return dict(
348        fullname='KLDivLoss_no_reduce_scalar',
349        constructor=wrap_functional(
350            lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
351        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))',
352        input_fn=lambda: torch.rand(()).log(),
353        cpp_var_map={'i': '_get_input()', 't': t},
354        reference_fn=lambda i, *_:
355            loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
356        supports_forward_ad=True,
357        pickle=False,
358        default_dtype=torch.double)
359
360
361def kldivloss_with_log_target_no_reduce_test():
362    t = torch.rand(10, 10, dtype=torch.double).log()
363    return dict(
364        fullname='KLDivLoss_with_log_target_no_reduce',
365        constructor=wrap_functional(
366            lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)),
367        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))',
368        input_fn=lambda: torch.rand(10, 10).log(),
369        cpp_var_map={'i': '_get_input()', 't': t},
370        reference_fn=lambda i, *_:
371            loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'),
372        supports_forward_ad=True,
373        pickle=False,
374        default_dtype=torch.double)
375
376
377def kldivloss_no_reduce_log_target_test():
378    t = torch.rand(10, 10, dtype=torch.double).log()
379    return dict(
380        fullname='KLDivLoss_no_reduce_log_target',
381        constructor=wrap_functional(
382            lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)),
383        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))',
384        input_fn=lambda: torch.rand(10, 10).log(),
385        cpp_var_map={'i': '_get_input()', 't': t},
386        reference_fn=lambda i, *_:
387            loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'),
388        supports_forward_ad=True,
389        pickle=False,
390        default_dtype=torch.double,
391    )
392
393
394def kldivloss_no_reduce_scalar_log_target_test():
395    t = torch.rand((), dtype=torch.double).log()
396    return dict(
397        fullname='KLDivLoss_no_reduce_scalar_log_target',
398        constructor=wrap_functional(
399            lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)),
400        cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))',
401        input_fn=lambda: torch.rand(()).log(),
402        cpp_var_map={'i': '_get_input()', 't': t},
403        reference_fn=lambda i, *_:
404            loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'),
405        supports_forward_ad=True,
406        pickle=False,
407        default_dtype=torch.double)
408
409
410def l1loss_no_reduce_test():
411    t = torch.randn(2, 3, 4, dtype=torch.double)
412    return dict(
413        fullname='L1Loss_no_reduce',
414        constructor=wrap_functional(
415            lambda i: F.l1_loss(i, t.type_as(i), reduction='none')),
416        cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))',
417        input_fn=lambda: torch.randn(2, 3, 4),
418        cpp_var_map={'i': '_get_input()', 't': t},
419        reference_fn=lambda i, *_: (i - t.type_as(i)).abs(),
420        supports_forward_ad=True,
421        pickle=False,
422        default_dtype=torch.double)
423
424
425def l1loss_no_reduce_complex_test():
426    t = torch.randn(2, 3, 4, dtype=torch.cdouble)
427    return dict(
428        fullname='L1Loss_no_reduce_complex',
429        constructor=wrap_functional(
430            lambda i: F.l1_loss(i, t.type_as(i), reduction='none')),
431        cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))',
432        input_fn=lambda: torch.randn(2, 3, 4, dtype=torch.cdouble),
433        cpp_var_map={'i': '_get_input()', 't': t},
434        reference_fn=lambda i, *_: (i - t.type_as(i)).abs(),
435        supports_forward_ad=True,
436        pickle=False)
437
438
439def l1loss_no_reduce_scalar_test():
440    t = torch.randn((), dtype=torch.double)
441    return dict(
442        fullname='L1Loss_no_reduce_scalar',
443        constructor=wrap_functional(
444            lambda i: F.l1_loss(i, t.type_as(i), reduction='none')),
445        cpp_function_call='F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))',
446        input_fn=lambda: torch.randn(()),
447        cpp_var_map={'i': '_get_input()', 't': t},
448        reference_fn=lambda i, *_: (i - t.type_as(i)).abs(),
449        supports_forward_ad=True,
450        pickle=False,
451        default_dtype=torch.double)
452
453
454def mseloss_no_reduce_test():
455    input_size = (2, 3, 4, 5)
456    target = torch.randn(*input_size, dtype=torch.double)
457    return dict(
458        fullname='MSELoss_no_reduce',
459        constructor=wrap_functional(
460            lambda i: F.mse_loss(i, target.type_as(i), reduction='none')),
461        cpp_function_call='F::mse_loss(i, target.to(i.options()), F::MSELossFuncOptions().reduction(torch::kNone))',
462        input_size=input_size,
463        cpp_var_map={'i': '_get_input()', 'target': target},
464        reference_fn=lambda i, *_: (i - target).pow(2),
465        supports_forward_ad=True,
466        pickle=False,
467        default_dtype=torch.double)
468
469
470def mseloss_no_reduce_scalar_test():
471    input_size = ()
472    target = torch.randn(input_size, dtype=torch.double)
473    return dict(
474        fullname='MSELoss_no_reduce_scalar',
475        constructor=wrap_functional(
476            lambda i: F.mse_loss(i, target.type_as(i), reduction='none')),
477        cpp_function_call='F::mse_loss(i, target.to(i.options()), F::MSELossFuncOptions().reduction(torch::kNone))',
478        input_size=input_size,
479        cpp_var_map={'i': '_get_input()', 'target': target},
480        reference_fn=lambda i, *_: (i - target).pow(2),
481        supports_forward_ad=True,
482        pickle=False,
483        default_dtype=torch.double)
484
485
486def nllloss_no_reduce_test():
487    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
488    kwargs = {'reduction': 'none'}
489    return dict(
490        fullname='NLLLoss_no_reduce',
491        constructor=wrap_functional(
492            lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])),
493        cpp_function_call='''F::nll_loss(
494            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''',
495        input_fn=lambda: torch.rand(15, 10).log(),
496        cpp_var_map={'i': '_get_input()', 't': t},
497        reference_fn=lambda i, *_:
498            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs),
499        pickle=False,
500        default_dtype=torch.double)
501
502
503def nllloss_no_reduce_ignore_index_test():
504    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
505    kwargs: Dict[str, Union[int, str]] = {'ignore_index': 2, 'reduction': 'none'}
506    return dict(
507        fullname='NLLLoss_no_reduce_ignore_index',
508        constructor=wrap_functional(
509            lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']),
510                                 reduction=str(kwargs['reduction']))),
511        cpp_function_call='''F::nll_loss(
512            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(2).reduction(torch::kNone))''',
513        input_fn=lambda: torch.rand(15, 10).log(),
514        cpp_var_map={'i': '_get_input()', 't': t},
515        reference_fn=lambda i, *_:
516            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs),
517        pickle=False,
518        default_dtype=torch.double)
519
520
521def nllloss_no_reduce_weights_test():
522    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
523    weight = torch.rand(10)
524
525    def kwargs(i):
526        return {'weight': weight.type_as(i), 'reduction': 'none'}
527
528    return dict(
529        fullname='NLLLoss_no_reduce_weights',
530        constructor=wrap_functional(
531            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
532        cpp_function_call='''F::nll_loss(
533            i, t.to(i.options()).to(torch::kLong),
534            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''',
535        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
536        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
537        reference_fn=lambda i, *_:
538            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)),
539        pickle=False,
540        default_dtype=torch.double)
541
542
543def nllloss_no_reduce_weights_ignore_index_test():
544    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
545    weight = torch.rand(10)
546
547    def kwargs(i):
548        return {'weight': weight.type_as(i), 'reduction': 'none',
549                'ignore_index': 2}
550
551    return dict(
552        fullname='NLLLoss_no_reduce_weights_ignore_index',
553        constructor=wrap_functional(
554            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i.data))),
555        cpp_function_call='''F::nll_loss(
556            i, t.to(i.options()).to(torch::kLong),
557            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone).ignore_index(2))''',
558        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
559        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
560        reference_fn=lambda i, *_:
561            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)),
562        pickle=False,
563        default_dtype=torch.double)
564
565
566def nllloss_no_reduce_weights_ignore_index_neg_test():
567    t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
568    weight = torch.rand(10)
569
570    def kwargs(i):
571        return {'weight': weight.type_as(i), 'reduction': 'none',
572                'ignore_index': -1}
573
574    return dict(
575        fullname='NLLLoss_no_reduce_weights_ignore_index_neg',
576        constructor=wrap_functional(
577            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
578        cpp_function_call='''F::nll_loss(
579            i, t.to(i.options()).to(torch::kLong),
580            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone).ignore_index(-1))''',
581        input=torch.rand(15, 10, dtype=torch.double).add(1e-2).log(),
582        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
583        reference_fn=lambda i, *_:
584            loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)),
585        pickle=False,
586        default_dtype=torch.double)
587
588
589def nllloss2d_no_reduce_test():
590    t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
591    kwargs = {'reduction': 'none'}
592    return dict(
593        fullname='NLLLoss2d_no_reduce',
594        constructor=wrap_functional(
595            lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])),
596        cpp_function_call='''F::nll_loss(
597            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''',
598        input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
599        cpp_var_map={'i': '_get_input()', 't': t},
600        reference_fn=lambda i, *_:
601            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
602        pickle=False,
603        default_dtype=torch.double)
604
605
606def nllloss2d_no_reduce_ignore_index_test():
607    t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
608    kwargs: Dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'}
609    return dict(
610        fullname='NLLLoss2d_no_reduce_ignore_index',
611        constructor=wrap_functional(
612            lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']),
613                                 reduction=str(kwargs['reduction']))),
614        cpp_function_call='''F::nll_loss(
615            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))''',
616        input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
617        cpp_var_map={'i': '_get_input()', 't': t},
618        reference_fn=lambda i, *_:
619            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
620        pickle=False,
621        default_dtype=torch.double)
622
623
624def nllloss2d_no_reduce_weights_test():
625    t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
626    weight = torch.rand(3)
627
628    def kwargs(i):
629        return {'weight': weight.type_as(i), 'reduction': 'none'}
630
631    return dict(
632        fullname='NLLLoss2d_no_reduce_weights',
633        constructor=wrap_functional(
634            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
635        cpp_function_call='''F::nll_loss(
636            i, t.to(i.options()).to(torch::kLong),
637            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''',
638        input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
639        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
640        reference_fn=lambda i, *_:
641            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs(i)),
642        pickle=False,
643        default_dtype=torch.double)
644
645
646def nlllossNd_no_reduce_test():
647    t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
648    kwargs = {'reduction': 'none'}
649    return dict(
650        fullname='NLLLossNd_no_reduce',
651        constructor=wrap_functional(
652            lambda i: F.nll_loss(i, t.type_as(i).long(), reduction=kwargs['reduction'])),
653        cpp_function_call='''F::nll_loss(
654            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))''',
655        input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
656        cpp_var_map={'i': '_get_input()', 't': t},
657        reference_fn=lambda i, *_:
658            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
659        pickle=False,
660        default_dtype=torch.double)
661
662
663def nlllossNd_no_reduce_ignore_index_test():
664    t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
665    kwargs: Dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'}
666    return dict(
667        fullname='NLLLossNd_no_reduce_ignore_index',
668        constructor=wrap_functional(
669            lambda i: F.nll_loss(i, t.type_as(i).long(), ignore_index=int(kwargs['ignore_index']),
670                                 reduction=str(kwargs['reduction']))),
671        cpp_function_call='''F::nll_loss(
672            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))''',
673        input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
674        cpp_var_map={'i': '_get_input()', 't': t},
675        reference_fn=lambda i, *_:
676            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
677        pickle=False,
678        default_dtype=torch.double)
679
680
681def nlllossNd_no_reduce_weights_test():
682    t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
683    weight = torch.rand(3)
684
685    def kwargs(i):
686        return {'weight': weight.type_as(i), 'reduction': 'none'}
687
688    return dict(
689        fullname='NLLLossNd_no_reduce_weights',
690        constructor=wrap_functional(
691            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
692        cpp_function_call='''F::nll_loss(
693            i, t.to(i.options()).to(torch::kLong),
694            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))''',
695        input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
696        cpp_var_map={'i': '_get_input()', 't': t, 'weight': weight},
697        reference_fn=lambda i, *_:
698            loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs(i)),
699        pickle=False,
700        default_dtype=torch.double)
701
702
703def smoothl1loss_no_reduce_test():
704    t = torch.randn(2, 3, 4, dtype=torch.double)
705    return dict(
706        fullname='SmoothL1Loss_no_reduce',
707        constructor=wrap_functional(
708            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none')),
709        cpp_function_call='''F::smooth_l1_loss(
710            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone))''',
711        input_fn=lambda: torch.randn(2, 3, 4),
712        cpp_var_map={'i': '_get_input()', 't': t},
713        reference_fn=lambda i, *_:
714            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none'),
715        supports_forward_ad=True,
716        pickle=False,
717        default_dtype=torch.double)
718
719
720def smoothl1loss_no_reduce_scalar_test():
721    t = torch.randn((), dtype=torch.double)
722    return dict(
723        fullname='SmoothL1Loss_no_reduce_scalar',
724        constructor=wrap_functional(
725            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none')),
726        cpp_function_call='''F::smooth_l1_loss(
727            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone))''',
728        input_fn=lambda: torch.randn(()),
729        cpp_var_map={'i': '_get_input()', 't': t},
730        reference_fn=lambda i, *_:
731            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none'),
732        supports_forward_ad=True,
733        pickle=False,
734        default_dtype=torch.double)
735
736
737def smoothl1loss_beta_test():
738    t = torch.randn(2, 3, 4, dtype=torch.double)
739    return dict(
740        fullname='SmoothL1Loss_beta',
741        constructor=wrap_functional(
742            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none', beta=0.5)),
743        cpp_function_call='''F::smooth_l1_loss(
744            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone), 0.5)''',
745        input_fn=lambda: torch.randn(2, 3, 4),
746        cpp_var_map={'i': '_get_input()', 't': t},
747        reference_fn=lambda i, *_:
748            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none', beta=0.5),
749        supports_forward_ad=True,
750        pickle=False,
751        default_dtype=torch.double)
752
753
754def smoothl1loss_zero_beta_test():
755    t = torch.randn(2, 3, 4, dtype=torch.double)
756    return dict(
757        fullname='SmoothL1Loss_zero_beta',
758        constructor=wrap_functional(
759            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none', beta=0)),
760        cpp_function_call='''F::smooth_l1_loss(
761            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone), 0)''',
762        input_fn=lambda: torch.randn(2, 3, 4),
763        cpp_var_map={'i': '_get_input()', 't': t},
764        reference_fn=lambda i, *_:
765            loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none', beta=0),
766        supports_forward_ad=True,
767        pickle=False,
768        default_dtype=torch.double)
769
770
771def huberloss_delta_test():
772    t = torch.randn(2, 3, 4)
773    return dict(
774        fullname='HuberLoss_delta',
775        constructor=wrap_functional(
776            lambda i: F.huber_loss(i, t.type_as(i), reduction='none', delta=0.5)),
777        cpp_function_call='''F::huber_loss(
778            i, t.to(i.options()), F::HuberLossFuncOptions().reduction(torch::kNone).delta(0.5))''',
779        input_fn=lambda: torch.randn(2, 3, 4),
780        cpp_var_map={'i': '_get_input()', 't': t},
781        reference_fn=lambda i, *_:
782            loss_reference_fns['HuberLoss'](i, t.type_as(i), reduction='none', delta=0.5),
783        supports_forward_ad=True,
784        pickle=False,
785        default_dtype=torch.double)
786
787
788def multilabelmarginloss_0d_no_reduce_test():
789    t = torch.zeros(()).long()
790    return dict(
791        fullname='MultiLabelMarginLoss_0d_no_reduce',
792        constructor=wrap_functional(
793            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
794        cpp_function_call='''F::multilabel_margin_loss(
795            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
796        input_fn=lambda: torch.randn(()),
797        cpp_var_map={'i': '_get_input()', 't': t},
798        reference_fn=lambda i, *_:
799            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
800        check_sum_reduction=True,
801        check_gradgrad=False,
802        pickle=False)
803
804
805def multilabelmarginloss_1d_no_reduce_test():
806    t = Variable(torch.rand(10).mul(10).floor().long())
807    return dict(
808        fullname='MultiLabelMarginLoss_1d_no_reduce',
809        constructor=wrap_functional(
810            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
811        cpp_function_call='''F::multilabel_margin_loss(
812            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
813        input_fn=lambda: torch.randn(10),
814        cpp_var_map={'i': '_get_input()', 't': t},
815        reference_fn=lambda i, *_:
816            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
817        check_sum_reduction=True,
818        check_gradgrad=False,
819        pickle=False,
820        default_dtype=torch.double)
821
822
823def multilabelmarginloss_index_neg_test():
824    t = Variable(torch.clamp(torch.rand(5, 10).add(-.5).mul(20).floor().long(), min=-1))
825    return dict(
826        fullname='MultiLabelMarginLoss_index_neg',
827        constructor=wrap_functional(
828            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
829        cpp_function_call='''F::multilabel_margin_loss(
830            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
831        input_fn=lambda: torch.randn(5, 10),
832        cpp_var_map={'i': '_get_input()', 't': t},
833        reference_fn=lambda i, *_:
834            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
835        check_sum_reduction=True,
836        check_gradgrad=False,
837        pickle=False,
838        default_dtype=torch.double)
839
840
841def multilabelmarginloss_no_reduce_test():
842    t = Variable(torch.rand(5, 10).mul(10).floor().long())
843    return dict(
844        fullname='MultiLabelMarginLoss_no_reduce',
845        constructor=wrap_functional(
846            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
847        cpp_function_call='''F::multilabel_margin_loss(
848            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))''',
849        input_fn=lambda: torch.randn(5, 10),
850        cpp_var_map={'i': '_get_input()', 't': t},
851        reference_fn=lambda i, *_:
852            loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
853        check_sum_reduction=True,
854        check_gradgrad=False,
855        pickle=False,
856        default_dtype=torch.double)
857
858
859def hingeembeddingloss_no_reduce_test():
860    t = Variable(torch.randn(10).gt(0).to(torch.double).mul_(2).sub(1))
861    return dict(
862        fullname='HingeEmbeddingLoss_no_reduce',
863        constructor=wrap_functional(
864            lambda i: F.hinge_embedding_loss(i, t.type_as(i), reduction='none')),
865        cpp_function_call='''F::hinge_embedding_loss(
866            i, t.to(i.options()), F::HingeEmbeddingLossFuncOptions().reduction(torch::kNone))''',
867        input_fn=lambda: torch.randn(10),
868        cpp_var_map={'i': '_get_input()', 't': t},
869        reference_fn=lambda i, *_:
870            loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), reduction='none'),
871        check_sum_reduction=True,
872        pickle=False,
873        default_dtype=torch.double)
874
875
876def hingeembeddingloss_margin_no_reduce_test():
877    t = Variable(torch.randn(10).gt(0).to(torch.double).mul_(2).sub(1))
878    return dict(
879        fullname='HingeEmbeddingLoss_margin_no_reduce',
880        constructor=wrap_functional(
881            lambda i: F.hinge_embedding_loss(i, t.type_as(i), margin=0.5, reduction='none')),
882        cpp_function_call='''F::hinge_embedding_loss(
883            i, t.to(i.options()), F::HingeEmbeddingLossFuncOptions().margin(0.5).reduction(torch::kNone))''',
884        input_fn=lambda: torch.randn(10),
885        cpp_var_map={'i': '_get_input()', 't': t},
886        reference_fn=lambda i, *_:
887            loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), margin=0.5, reduction='none'),
888        check_sum_reduction=True,
889        pickle=False,
890        default_dtype=torch.double)
891
892
893def softmarginloss_no_reduce_test():
894    t = torch.randn(5, 5, dtype=torch.double)
895    return dict(
896        fullname='SoftMarginLoss_no_reduce',
897        constructor=wrap_functional(
898            lambda i: F.soft_margin_loss(i, t.type_as(i), reduction='none')),
899        cpp_function_call='''F::soft_margin_loss(
900            i, t.to(i.options()), F::SoftMarginLossFuncOptions().reduction(torch::kNone))''',
901        input_fn=lambda: torch.randn(5, 5),
902        cpp_var_map={'i': '_get_input()', 't': t},
903        reference_fn=lambda i, *_:
904            loss_reference_fns['SoftMarginLoss'](i, t.type_as(i), reduction='none'),
905        supports_forward_ad=True,
906        pickle=False,
907        default_dtype=torch.double)
908
909
910def multilabelsoftmarginloss_no_reduce_test():
911    t = torch.rand(5, 10).mul(2).floor()
912    return dict(
913        fullname='MultiLabelSoftMarginLoss_no_reduce',
914        constructor=wrap_functional(
915            lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i), reduction='none')),
916        cpp_function_call='''F::multilabel_soft_margin_loss(
917            i, t.to(i.options()), F::MultilabelSoftMarginLossFuncOptions().reduction(torch::kNone))''',
918        input_fn=lambda: torch.randn(5, 10),
919        cpp_var_map={'i': '_get_input()', 't': t},
920        reference_fn=lambda i, *_:
921            (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log())).sum(dim=1) / i.size(1),
922        check_gradgrad=False,
923        pickle=False,
924        default_dtype=torch.double)
925
926
927def multilabelsoftmarginloss_weights_no_reduce_test():
928    t = torch.rand(5, 10).mul(2).floor()
929    weights = torch.rand(10)
930    return dict(
931        fullname='MultiLabelSoftMarginLoss_weights_no_reduce',
932        constructor=wrap_functional(
933            lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i),
934                                                    weight=weights.type_as(i), reduction='none')),
935        cpp_function_call='''F::multilabel_soft_margin_loss(
936            i, t.to(i.options()),
937            F::MultilabelSoftMarginLossFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''',
938        input_fn=lambda: torch.randn(5, 10),
939        cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights},
940        reference_fn=lambda i, *_:
941            (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * weights).sum(dim=1) / i.size(1),
942        check_sum_reduction=True,
943        check_gradgrad=False,
944        pickle=False,
945        default_dtype=torch.double)
946
947
948def multimarginloss_no_reduce_test():
949    t = torch.rand(5).mul(8).floor().long()
950    return dict(
951        fullname='MultiMarginLoss_no_reduce',
952        constructor=wrap_functional(
953            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')),
954        cpp_function_call='''F::multi_margin_loss(
955            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''',
956        input_fn=lambda: torch.randn(5, 10),
957        cpp_var_map={'i': '_get_input()', 't': t},
958        reference_fn=lambda i, *_:
959            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
960        check_sum_reduction=True,
961        check_gradgrad=False,
962        pickle=False,
963        default_dtype=torch.double)
964
965
966def multimarginloss_1d_no_reduce_test():
967    t = torch.rand(1).mul(8).floor().long()
968    return dict(
969        fullname='MultiMarginLoss_1d_no_reduce',
970        constructor=wrap_functional(
971            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')),
972        cpp_function_call='''F::multi_margin_loss(
973            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''',
974        input_fn=lambda: torch.randn(10),
975        cpp_var_map={'i': '_get_input()', 't': t},
976        reference_fn=lambda i, *_:
977            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
978        check_sum_reduction=True,
979        check_gradgrad=False,
980        pickle=False,
981        default_dtype=torch.double)
982
983
984def multimarginloss_1d_input_0d_target_no_reduce_test():
985    t = torch.rand(()).mul(8).floor().long()
986    return dict(
987        fullname='multimarginloss_1d_input_0d_target_no_reduce',
988        constructor=wrap_functional(
989            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')),
990        cpp_function_call='''F::multi_margin_loss(
991            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))''',
992        input_fn=lambda: torch.randn(10),
993        cpp_var_map={'i': '_get_input()', 't': t},
994        reference_fn=lambda i, *_:
995            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
996        check_sum_reduction=True,
997        check_gradgrad=False,
998        pickle=False,
999        default_dtype=torch.double)
1000
1001
1002def multimarginloss_p_no_reduce_test():
1003    t = torch.rand(5).mul(8).floor().long()
1004    return dict(
1005        fullname='MultiMarginLoss_p_no_reduce',
1006        constructor=wrap_functional(
1007            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), p=2, reduction='none')),
1008        cpp_function_call='''F::multi_margin_loss(
1009            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().p(2).reduction(torch::kNone))''',
1010        input_fn=lambda: torch.randn(5, 10).clamp_(1e-2, 1 - 1e-2),
1011        cpp_var_map={'i': '_get_input()', 't': t},
1012        reference_fn=lambda i, *_:
1013            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), p=2, reduction='none'),
1014        check_sum_reduction=True,
1015        check_gradgrad=False,
1016        pickle=False,
1017        default_dtype=torch.double)
1018
1019
1020def multimarginloss_margin_no_reduce_test():
1021    t = torch.rand(5).mul(8).floor().long()
1022    return dict(
1023        fullname='MultiMarginLoss_margin_no_reduce',
1024        constructor=wrap_functional(
1025            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), margin=0.5, reduction='none')),
1026        cpp_function_call='''F::multi_margin_loss(
1027            i, t.to(i.options()).to(torch::kLong),
1028            F::MultiMarginLossFuncOptions().margin(0.5).reduction(torch::kNone))''',
1029        input_fn=lambda: torch.randn(5, 10),
1030        cpp_var_map={'i': '_get_input()', 't': t},
1031        reference_fn=lambda i, *_:
1032            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(),
1033                                                  margin=0.5, reduction='none'),
1034        check_sum_reduction=True,
1035        check_gradgrad=False,
1036        pickle=False,
1037        default_dtype=torch.double)
1038
1039
1040def multimarginloss_weights_no_reduce_test():
1041    t = torch.rand(5).mul(8).floor().long()
1042    weights = torch.rand(10, dtype=torch.double)
1043    return dict(
1044        fullname='MultiMarginLoss_weights_no_reduce',
1045        constructor=wrap_functional(
1046            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), weight=weights.type_as(i),
1047                                          reduction='none')),
1048        cpp_function_call='''F::multi_margin_loss(
1049            i, t.to(i.options()).to(torch::kLong),
1050            F::MultiMarginLossFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))''',
1051        input_fn=lambda: torch.randn(5, 10),
1052        cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights},
1053        reference_fn=lambda i, *_:
1054            loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(),
1055                                                  weight=weights, reduction='none'),
1056        check_sum_reduction=True,
1057        check_gradgrad=False,
1058        pickle=False,
1059        default_dtype=torch.double)
1060
1061
1062def single_batch_reference_fn(input, parameters, module):
1063    """Reference function for modules supporting no batch dimensions.
1064
1065    The module is passed the input and target in batched form with a single item.
1066    The output is squeezed to compare with the no-batch input.
1067    """
1068    def unsqueeze_inp(inp):
1069        if isinstance(inp, (list, tuple)):
1070            return [t.unsqueeze(0) for t in inp]
1071        return inp.unsqueeze(0)
1072
1073    single_batch_input = unsqueeze_inp(input)
1074    single_batch_input = [single_batch_input] if isinstance(single_batch_input, torch.Tensor) else single_batch_input
1075    with freeze_rng_state():
1076        return module(*single_batch_input).squeeze(0)
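# For example, for a module m that expects (N, C) inputs, calling
# single_batch_reference_fn(x, params, m) with an unbatched x of shape (C,) evaluates
# m(x.unsqueeze(0)) under a frozen RNG state and squeezes the result back to shape (C,).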
1077
1078
1079new_module_tests = [
1080    poissonnllloss_no_reduce_test(),
1081    bceloss_no_reduce_test(),
1082    bceloss_weights_no_reduce_test(),
1083    bce_with_logistic_legacy_enum_test(),
1084    bce_with_logistic_no_reduce_test(),
1085    bceloss_no_reduce_scalar_test(),
1086    bceloss_weights_no_reduce_scalar_test(),
1087    bce_with_logistic_no_reduce_scalar_test(),
1088    kldivloss_with_target_no_reduce_test(),
1089    kldivloss_no_reduce_test(),
1090    kldivloss_no_reduce_scalar_test(),
1091    kldivloss_with_log_target_no_reduce_test(),
1092    kldivloss_no_reduce_log_target_test(),
1093    kldivloss_no_reduce_scalar_log_target_test(),
1094    l1loss_no_reduce_test(),
1095    l1loss_no_reduce_complex_test(),
1096    l1loss_no_reduce_scalar_test(),
1097    mseloss_no_reduce_test(),
1098    mseloss_no_reduce_scalar_test(),
1099    nllloss_no_reduce_test(),
1100    nllloss_no_reduce_ignore_index_test(),
1101    nllloss_no_reduce_weights_test(),
1102    nllloss_no_reduce_weights_ignore_index_test(),
1103    nllloss_no_reduce_weights_ignore_index_neg_test(),
1104    nllloss2d_no_reduce_test(),
1105    nllloss2d_no_reduce_weights_test(),
1106    nllloss2d_no_reduce_ignore_index_test(),
1107    nlllossNd_no_reduce_test(),
1108    nlllossNd_no_reduce_weights_test(),
1109    nlllossNd_no_reduce_ignore_index_test(),
1110    smoothl1loss_no_reduce_test(),
1111    smoothl1loss_no_reduce_scalar_test(),
1112    smoothl1loss_beta_test(),
1113    smoothl1loss_zero_beta_test(),
1114    huberloss_delta_test(),
1115    multilabelmarginloss_0d_no_reduce_test(),
1116    multilabelmarginloss_1d_no_reduce_test(),
1117    multilabelmarginloss_index_neg_test(),
1118    multilabelmarginloss_no_reduce_test(),
1119    hingeembeddingloss_no_reduce_test(),
1120    hingeembeddingloss_margin_no_reduce_test(),
1121    softmarginloss_no_reduce_test(),
1122    multilabelsoftmarginloss_no_reduce_test(),
1123    multilabelsoftmarginloss_weights_no_reduce_test(),
1124    multimarginloss_no_reduce_test(),
1125    multimarginloss_1d_no_reduce_test(),
1126    multimarginloss_1d_input_0d_target_no_reduce_test(),
1127    multimarginloss_p_no_reduce_test(),
1128    multimarginloss_margin_no_reduce_test(),
1129    multimarginloss_weights_no_reduce_test(),
1130    dict(
1131        module_name='Conv1d',
1132        constructor_args=(4, 5, 3),
1133        cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3)',
1134        input_size=(2, 4, 10),
1135        cudnn=True,
1136        with_tf32=True,
1137        tf32_precision=0.005,
1138        default_dtype=torch.double,
1139    ),
1140    dict(
1141        module_name='Conv1d',
1142        constructor_args=(4, 5, 3, 2),
1143        cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).stride(2)',
1144        input_size=(2, 4, 10),
1145        cudnn=True,
1146        desc='stride',
1147        with_tf32=True,
1148        tf32_precision=0.005,
1149        default_dtype=torch.double,
1150    ),
1151    dict(
1152        module_name='Conv1d',
1153        constructor_args=(4, 5, 3, 1, 1),
1154        cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).stride(1).padding(1)',
1155        input_size=(2, 4, 10),
1156        cudnn=True,
1157        desc='pad1',
1158        with_tf32=True,
1159        tf32_precision=0.01,
1160        default_dtype=torch.double,
1161    ),
1162    dict(
1163        module_name='Conv1d',
1164        constructor_args=(4, 5, 5, 1, 2),
1165        cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 5).stride(1).padding(2)',
1166        input_size=(2, 4, 10),
1167        cudnn=True,
1168        desc='pad2',
1169        with_tf32=True,
1170        tf32_precision=0.005,
1171        default_dtype=torch.double,
1172    ),
1173    dict(
1174        module_name='Conv1d',
1175        constructor_args=(4, 4, 3, 1, 1),
1176        cpp_constructor_args='torch::nn::Conv1dOptions(4, 4, 3).stride(1).padding(1)',
1177        input_size=(1, 4, 1),
1178        cudnn=True,
1179        desc='pad1size1',
1180        with_tf32=True,
1181        tf32_precision=0.005,
1182        default_dtype=torch.double,
1183    ),
1184    dict(
1185        module_name='Conv1d',
1186        constructor_args=(4, 4, 5, 1, 2),
1187        cpp_constructor_args='torch::nn::Conv1dOptions(4, 4, 5).stride(1).padding(2)',
1188        input_size=(1, 4, 1),
1189        cudnn=True,
1190        desc='pad2size1',
1191        with_tf32=True,
1192        tf32_precision=0.005,
1193        default_dtype=torch.double,
1194    ),
1195    dict(
1196        module_name='Conv1d',
1197        constructor_args=(4, 5, 3),
1198        cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3)',
1199        input_size=(0, 4, 10),
1200        cudnn=True,
1201        desc='zero_batch',
1202        with_tf32=True,
1203        tf32_precision=0.005,
1204    ),
1205    dict(
1206        fullname='Conv1d_dilated',
1207        constructor=lambda: nn.Conv1d(4, 5, kernel_size=3, dilation=2),
1208        cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).dilation(2)',
1209        input_size=(2, 4, 10),
1210        with_tf32=True,
1211        tf32_precision=0.005,
1212        default_dtype=torch.double,
1213    ),
1214    dict(
1215        fullname='Conv1d_groups',
1216        constructor=lambda: nn.Conv1d(4, 6, kernel_size=3, groups=2),
1217        cpp_constructor_args='torch::nn::Conv1dOptions(4, 6, 3).groups(2)',
1218        input_size=(2, 4, 6),
1219        cudnn=True,
1220        with_tf32=True,
1221        tf32_precision=0.005,
1222        default_dtype=torch.double,
1223    ),
1224    dict(
1225        fullname='Conv1d_pad_valid',
1226        constructor=lambda: nn.Conv1d(4, 5, 3, padding="valid"),
1227        cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kValid)',
1228        input_size=(2, 4, 10),
1229        cudnn=True,
1230        with_tf32=True,
1231        tf32_precision=0.005,
1232        default_dtype=torch.double,
1233    ),
1234    dict(
1235        fullname='Conv1d_pad_same',
1236        constructor=lambda: nn.Conv1d(4, 5, 3, padding="same"),
1237        cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kSame)',
1238        input_size=(2, 4, 10),
1239        cudnn=True,
1240        with_tf32=True,
1241        tf32_precision=0.005,
1242        default_dtype=torch.double,
1243    ),
1244    dict(
1245        fullname='Conv1d_pad_same2',
1246        constructor=lambda: nn.Conv1d(4, 5, 4, padding="same"),
1247        cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 4).padding(torch::kSame)',
1248        input_size=(2, 4, 10),
1249        cudnn=True,
1250        with_tf32=True,
1251        tf32_precision=0.005,
1252        default_dtype=torch.double,
1253    ),
1254    dict(
1255        fullname='Conv1d_pad_same_dilated',
1256        constructor=lambda: nn.Conv1d(4, 5, 4, padding="same", dilation=2),
1257        cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 4).padding(torch::kSame).dilation(2)',
1258        input_size=(2, 4, 10),
1259        cudnn=True,
1260        with_tf32=True,
1261        tf32_precision=0.005,
1262        default_dtype=torch.double,
1263    ),
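    # ConvTranspose1d entries: stride/padding/output_padding, a no-bias case,
    # a dilated case, and a grouped case.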
1264    dict(
1265        fullname='ConvTranspose1d',
1266        constructor=lambda: nn.ConvTranspose1d(3, 4, kernel_size=3, stride=(3,), padding=1, output_padding=(1,)),
1267        cpp_constructor_args='torch::nn::ConvTranspose1dOptions(3, 4, 3).stride(3).padding(1).output_padding(1)',
1268        cudnn=True,
1269        input_size=(1, 3, 7),
1270        with_tf32=True,
1271        tf32_precision=0.005,
1272        default_dtype=torch.double,
1273    ),
1274    dict(
1275        module_name='ConvTranspose1d',
1276        constructor_args=(3, 4, 3, 2, 1, 1, 1, False),
1277        cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(3, 4, 3)
1278                                .stride(2).padding(1).output_padding(1).groups(1).bias(false)''',
1279        input_size=(1, 3, 6),
1280        cudnn=True,
1281        desc='no_bias',
1282        with_tf32=True,
1283        tf32_precision=0.005,
1284        default_dtype=torch.double,
1285    ),
1286    dict(
1287        module_name='ConvTranspose1d',
1288        constructor_args=(3, 4, 3, 2, 1, 1, 1, True, 2),
1289        cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(3, 4, 3)
1290                                .stride(2).padding(1).output_padding(1).groups(1).bias(true).dilation(2)''',
1291        input_size=(1, 3, 6),
1292        cudnn=True,
1293        desc='dilated',
1294        with_tf32=True,
1295        tf32_precision=0.005,
1296        default_dtype=torch.double,
1297    ),
1298    dict(
1299        fullname='ConvTranspose1d_groups',
1300        constructor=lambda: nn.ConvTranspose1d(4, 6, 3, stride=(3,), padding=1, output_padding=(1,), groups=2),
1301        cpp_constructor_args='''torch::nn::ConvTranspose1dOptions(4, 6, 3)
1302                                .stride(3).padding(1).output_padding(1).groups(2)''',
1303        cudnn=True,
1304        input_size=(2, 4, 7),
1305        with_tf32=True,
1306        tf32_precision=0.005,
1307        default_dtype=torch.double,
1308    ),
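    # Conv2d entries: asymmetric kernels, strided/padded/dilated variants,
    # no-bias, zero-batch, grouped convolutions, and the "valid"/"same"
    # string padding modes.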
1309    dict(
1310        module_name='Conv2d',
1311        constructor_args=(3, 4, (3, 2)),
1312        cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 2})',
1313        input_size=(2, 3, 7, 5),
1314        cudnn=True,
1315        check_with_long_tensor=True,
1316        with_tf32=True,
1317        tf32_precision=0.005,
1318        default_dtype=torch.double,
1319    ),
1320    dict(
1321        module_name='Conv2d',
1322        constructor_args=(3, 4, (3, 3), (2, 2)),
1323        cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2})',
1324        input_size=(2, 3, 6, 6),
1325        cudnn=True,
1326        desc='strided',
1327        check_with_long_tensor=True,
1328        with_tf32=True,
1329        tf32_precision=0.005,
1330        default_dtype=torch.double,
1331    ),
1332    dict(
1333        module_name='Conv2d',
1334        constructor_args=(3, 4, (3, 3), (2, 2), (1, 1)),
1335        cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2}).padding({1, 1})',
1336        input_size=(2, 3, 6, 6),
1337        cudnn=True,
1338        desc='padding',
1339        check_with_long_tensor=True,
1340        with_tf32=True,
1341        tf32_precision=0.005,
1342        default_dtype=torch.double,
1343    ),
1344    dict(
1345        module_name='Conv2d',
1346        constructor_args=(3, 2, (3, 3), (2, 2), (1, 1), (2, 2)),
1347        cpp_constructor_args='torch::nn::Conv2dOptions(3, 2, {3, 3}).stride({2, 2}).padding({1, 1}).dilation({2, 2})',
1348        input_size=(2, 3, 8, 8),
1349        cudnn=True,
1350        desc='dilated',
1351        check_with_long_tensor=True,
1352        with_tf32=True,
1353        tf32_precision=0.005,
1354        default_dtype=torch.double,
1355    ),
1356    dict(
1357        module_name='Conv2d',
1358        constructor_args=(3, 4, (3, 2), 1, 0, 1, 1, False),
1359        cpp_constructor_args='''torch::nn::Conv2dOptions(3, 4, {3, 2})
1360                                .stride(1).padding(0).dilation(1).groups(1).bias(false)''',
1361        input_size=(2, 3, 6, 5),
1362        cudnn=True,
1363        desc='no_bias',
1364        check_with_long_tensor=True,
1365        with_tf32=True,
1366        tf32_precision=0.015,
1367        default_dtype=torch.double,
1368    ),
1369    dict(
1370        module_name='Conv2d',
1371        constructor_args=(3, 4, (3, 2)),
1372        cpp_constructor_args='torch::nn::Conv2dOptions(3, 4, {3, 2})',
1373        input_size=(0, 3, 7, 5),
1374        cudnn=True,
1375        desc='zero_batch',
1376        check_with_long_tensor=True,
1377        with_tf32=True,
1378    ),
1379    dict(
1380        fullname='Conv2d_groups',
1381        constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2),
1382        cpp_constructor_args='torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)',
1383        input_size=(2, 4, 6, 5),
1384        cudnn=True,
1385        check_with_long_tensor=True,
1386        with_tf32=True,
1387        tf32_precision=0.015,
1388        default_dtype=torch.double,
1389    ),
1390    dict(
1391        fullname='Conv2d_groups_thnn',
1392        constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2),
1393        cpp_constructor_args='torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)',
1394        input_size=(2, 4, 6, 5),
1395        check_with_long_tensor=True,
1396        with_tf32=True,
1397        tf32_precision=0.015,
1398        default_dtype=torch.double,
1399    ),
1400    dict(
1401        fullname='Conv2d_pad_valid',
1402        constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="valid"),
1403        cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kValid)',
1404        input_size=(2, 2, 6, 5),
1405        cudnn=True,
1406        with_tf32=True,
1407        tf32_precision=0.005,
1408        default_dtype=torch.double,
1409    ),
1410    dict(
1411        fullname='Conv2d_pad_same',
1412        constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="same"),
1413        cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame)',
1414        input_size=(2, 2, 6, 5),
1415        cudnn=True,
1416        with_tf32=True,
1417        tf32_precision=0.01,
1418        default_dtype=torch.double,
1419    ),
1420    dict(
1421        fullname='Conv2d_pad_same_dilated',
1422        constructor=lambda: nn.Conv2d(2, 4, (3, 4), padding="same", dilation=2),
1423        cpp_constructor_args='torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame).dilation(2)',
1424        input_size=(2, 2, 6, 5),
1425        cudnn=True,
1426        with_tf32=True,
1427        tf32_precision=0.01,
1428        default_dtype=torch.double,
1429    ),
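    # ConvTranspose2d entries: per-dimension stride and output_padding,
    # dilated and no-bias variants, and a grouped case.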
1430    dict(
1431        module_name='ConvTranspose2d',
1432        constructor_args=(3, 4, 3, (3, 2), 1, (1, 1)),
1433        cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3)
1434                                .stride({3, 2}).padding(1).output_padding({1, 1})''',
1435        cudnn=True,
1436        input_size=(1, 3, 7, 6),
1437        check_with_long_tensor=True,
1438        with_tf32=True,
1439        tf32_precision=0.01,
1440        default_dtype=torch.double,
1441    ),
1442    dict(
1443        module_name='ConvTranspose2d',
1444        constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False, (2, 2)),
1445        cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3)
1446                                .stride({2, 3})
1447                                .padding(1)
1448                                .output_padding({1, 1})
1449                                .groups(1)
1450                                .bias(false)
1451                                .dilation({2, 2})''',
1452        input_size=(1, 3, 6, 7),
1453        cudnn=True,
1454        desc='dilated',
1455        check_with_long_tensor=True,
1456        with_tf32=True,
1457        tf32_precision=0.01,
1458        default_dtype=torch.double,
1459    ),
1460    dict(
1461        module_name='ConvTranspose2d',
1462        constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False),
1463        cpp_constructor_args='''torch::nn::ConvTranspose2dOptions(3, 4, 3)
1464                                .stride({2, 3}).padding(1).output_padding({1, 1}).groups(1).bias(false)''',
1465        input_size=(1, 3, 6, 7),
1466        cudnn=True,
1467        desc='no_bias',
1468        check_with_long_tensor=True,
1469        with_tf32=True,
1470        tf32_precision=0.01,
1471        default_dtype=torch.double,
1472    ),
1473    dict(
1474        fullname='ConvTranspose2d_groups',
1475        constructor=lambda: nn.ConvTranspose2d(2, 4, (2, 3), groups=2),
1476        cpp_constructor_args='torch::nn::ConvTranspose2dOptions(2, 4, {2, 3}).groups(2)',
1477        input_size=(1, 2, 4, 5),
1478        cudnn=True,
1479        check_with_long_tensor=True,
1480        with_tf32=True,
1481        tf32_precision=0.01,
1482        default_dtype=torch.double,
1483    ),
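    # Depthwise Conv2d entries (groups == in_channels), including a channel
    # multiplier (out_channels = 2 * in_channels), strided, padded, and
    # dilated variants.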
1484    dict(
1485        fullname='Conv2d_depthwise',
1486        constructor=lambda: nn.Conv2d(4, 4, (3, 3), groups=4),
1487        cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).groups(4)',
1488        input_size=(2, 4, 6, 6),
1489        with_tf32=True,
1490        tf32_precision=0.005,
1491        default_dtype=torch.double,
1492    ),
1493    dict(
1494        fullname='Conv2d_depthwise_with_multiplier',
1495        constructor=lambda: nn.Conv2d(4, 8, (3, 3), groups=4),
1496        cpp_constructor_args='torch::nn::Conv2dOptions(4, 8, {3, 3}).groups(4)',
1497        input_size=(2, 4, 6, 6),
1498        with_tf32=True,
1499        tf32_precision=0.005,
1500        default_dtype=torch.double,
1501    ),
1502    dict(
1503        fullname='Conv2d_depthwise_strided',
1504        constructor=lambda: nn.Conv2d(4, 4, (3, 3), stride=(2, 2), groups=4),
1505        cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).stride({2, 2}).groups(4)',
1506        input_size=(2, 4, 6, 6),
1507        with_tf32=True,
1508        tf32_precision=0.005,
1509        default_dtype=torch.double,
1510    ),
1511    dict(
1512        fullname='Conv2d_depthwise_padded',
1513        constructor=lambda: nn.Conv2d(4, 4, (3, 3), padding=(1, 1), groups=4),
1514        cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {3, 3}).padding({1, 1}).groups(4)',
1515        input_size=(2, 4, 6, 6),
1516        with_tf32=True,
1517        tf32_precision=0.005,
1518        default_dtype=torch.double,
1519    ),
1520    dict(
1521        fullname='Conv2d_depthwise_dilated',
1522        constructor=lambda: nn.Conv2d(4, 4, (2, 2), dilation=(2, 2), groups=4),
1523        cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {2, 2}).dilation({2, 2}).groups(4)',
1524        input_size=(2, 4, 5, 5),
1525        with_tf32=True,
1526        tf32_precision=0.005,
1527        default_dtype=torch.double,
1528    ),
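    # Conv3d entries: asymmetric kernels, no-bias, stride/padding, zero-batch,
    # grouped, dilated, and "valid"/"same" padding variants.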
1529    dict(
1530        module_name='Conv3d',
1531        constructor_args=(2, 3, (2, 3, 2)),
1532        cpp_constructor_args='torch::nn::Conv3dOptions(2, 3, {2, 3, 2})',
1533        input_size=(1, 2, 4, 5, 4),
1534        cudnn=True,
1535        check_with_long_tensor=True,
1536        with_tf32=True,
1537        tf32_precision=0.05,
1538        default_dtype=torch.double,
1539    ),
1540    dict(
1541        module_name='Conv3d',
1542        constructor_args=(2, 3, (2, 3, 4), 1, 0, 1, 1, False),
1543        cpp_constructor_args='''torch::nn::Conv3dOptions(2, 3, {2, 3, 4})
1544                                .stride(1).padding(0).dilation(1).groups(1).bias(false)''',
1545        input_size=(1, 2, 3, 4, 5),
1546        cudnn=True,
1547        desc='no_bias',
1548        check_with_long_tensor=True,
1549        with_tf32=True,
1550        tf32_precision=0.05,
1551        default_dtype=torch.double,
1552    ),
1553    dict(
1554        module_name='Conv3d',
1555        constructor_args=(2, 3, (1, 1, 1), 1, 0, 1, 1, False),
1556        cpp_constructor_args='''torch::nn::Conv3dOptions(2, 3, {1, 1, 1})
1557                                .stride(1).padding(0).dilation(1).groups(1).bias(false)''',
1558        input_size=(1, 2, 3, 4, 5),
1559        cudnn=True,
1560        desc='1x1x1_no_bias',
1561        check_with_long_tensor=False,
1562        with_tf32=True,
1563        tf32_precision=0.05,
1564        default_dtype=torch.double,
1565    ),
1566    dict(
1567        module_name='Conv3d',
1568        constructor_args=(3, 4, 2, 2),
1569        cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).stride(2)',
1570        input_size=(2, 3, 5, 5, 5),
1571        cudnn=True,
1572        desc='stride',
1573        check_with_long_tensor=True,
1574        with_tf32=True,
1575        tf32_precision=0.05,
1576        default_dtype=torch.double,
1577    ),
1578    dict(
1579        module_name='Conv3d',
1580        constructor_args=(3, 4, 2, 2, 1),
1581        cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).stride(2).padding(1)',
1582        input_size=(2, 3, 5, 5, 5),
1583        cudnn=True,
1584        desc='stride_padding',
1585        check_with_long_tensor=True,
1586        with_tf32=True,
1587        tf32_precision=0.05,
1588        default_dtype=torch.double,
1589    ),
1590    dict(
1591        module_name='Conv3d',
1592        constructor_args=(3, 4, (2, 3, 4)),
1593        cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4})',
1594        input_size=(0, 3, 3, 4, 5),
1595        cudnn=True,
1596        check_with_long_tensor=True,
1597        desc='zero_batch',
1598        with_tf32=True,
1599    ),
1600    dict(
1601        fullname='Conv3d_groups',
1602        constructor=lambda: nn.Conv3d(2, 4, kernel_size=3, groups=2),
1603        cpp_constructor_args='torch::nn::Conv3dOptions(2, 4, 3).groups(2)',
1604        input_size=(1, 2, 4, 5, 4),
1605        cudnn=True,
1606        check_with_long_tensor=True,
1607        with_tf32=True,
1608        tf32_precision=0.005,
1609        default_dtype=torch.double,
1610    ),
1611    dict(
1612        fullname='Conv3d_dilated',
1613        constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2),
1614        cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2)',
1615        input_size=(2, 3, 5, 5, 5),
1616        with_tf32=True,
1617        tf32_precision=0.05,
1618        default_dtype=torch.double,
1619    ),
1620    dict(
1621        fullname='Conv3d_dilated_strided',
1622        constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2, stride=2),
1623        cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2).stride(2)',
1624        input_size=(2, 3, 5, 5, 5),
1625        with_tf32=True,
1626        tf32_precision=0.05,
1627        default_dtype=torch.double,
1628    ),
1629    dict(
1630        fullname='Conv3d_pad_valid',
1631        constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="valid"),
1632        cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kValid)',
1633        input_size=(2, 3, 6, 5, 4),
1634        cudnn=True,
1635        with_tf32=True,
1636        tf32_precision=0.05,
1637        default_dtype=torch.double,
1638    ),
1639    dict(
1640        fullname='Conv3d_pad_same',
1641        constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="same"),
1642        cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kSame)',
1643        input_size=(2, 3, 6, 5, 4),
1644        cudnn=True,
1645        with_tf32=True,
1646        tf32_precision=0.05,
1647        default_dtype=torch.double,
1648    ),
1649    dict(
1650        fullname='Conv3d_pad_same_dilated',
1651        constructor=lambda: nn.Conv3d(3, 4, (2, 3, 4), padding="same", dilation=2),
1652        cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kSame).dilation(2)',
1653        input_size=(2, 3, 6, 5, 4),
1654        cudnn=True,
1655        with_tf32=True,
1656        tf32_precision=0.05,
1657        default_dtype=torch.double,
1658    ),
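    # ConvTranspose3d entries: a basic case and a dilated case.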
1659    dict(
1660        module_name='ConvTranspose3d',
1661        constructor_args=(2, 3, (2, 3, 2)),
1662        cpp_constructor_args='torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2})',
1663        cudnn=True,
1664        input_size=(1, 2, 4, 5, 4),
1665        with_tf32=True,
1666        tf32_precision=0.05,
1667        default_dtype=torch.double,
1668    ),
1669    dict(
1670        module_name='ConvTranspose3d',
1671        constructor_args=(2, 3, (2, 3, 2), 1, 0, 0, 1, True, (2, 2, 2)),
1672        cpp_constructor_args='''torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2})
1673                                .stride(1).padding(0).output_padding(0).groups(1).bias(true).dilation({2, 2, 2})''',
1674        cudnn=True,
1675        input_size=(1, 2, 4, 5, 4),
1676        desc='dilated',
1677        with_tf32=True,
1678        tf32_precision=0.05,
1679        default_dtype=torch.double,
1680    ),
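    # ReplicationPad3d entries: batched, no-batch-dim (checked against
    # single_batch_reference_fn), and complex-valued input.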
1681    dict(
1682        module_name='ReplicationPad3d',
1683        constructor_args=((1, 2, 3, 3, 2, 1),),
1684        cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})',
1685        input_size=(2, 3, 2, 2, 2),
1686        default_dtype=torch.double,
1687    ),
1688    dict(
1689        module_name='ReplicationPad3d',
1690        constructor_args=((1, 2, 3, 3, 2, 1),),
1691        cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})',
1692        input_size=(3, 2, 2, 2),
1693        reference_fn=single_batch_reference_fn,
1694        desc='no_batch_dim',
1695        default_dtype=torch.double,
1696    ),
1697    dict(
1698        module_name='ReplicationPad3d',
1699        constructor_args=((1, 2, 3, 3, 2, 1),),
1700        cpp_constructor_args='torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1})',
1701        input_fn=lambda: torch.rand(2, 3, 2, 2, 2, dtype=torch.complex128, requires_grad=True),
1702        skip_half=True,
1703        desc='complex'
1704    ),
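    # Embedding entries: contiguous and discontiguous (expanded) index inputs.
    # Indices are drawn with .random_(4) so they stay inside the embedding
    # table of size 4.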
1705    dict(
1706        module_name='Embedding',
1707        constructor_args=(4, 3),
1708        cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)',
1709        input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
1710        check_gradgrad=False,
1711        default_dtype=torch.double,
1712        decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971")
1713    ),
1714    dict(
1715        module_name='Embedding',
1716        constructor_args=(4, 3),
1717        cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)',
1718        input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512),
1719        check_gradgrad=False,
1720        desc='discontiguous',
1721        default_dtype=torch.double,
1722        decorator=skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/117971")
1723    ),
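    # EmbeddingBag entries: 'mean' (default), 'sum', and 'max' modes,
    # discontiguous input, padding_idx variants, and sparse gradients.
    # Illustrative sketch (not part of the harness):
    #   bag = torch.nn.EmbeddingBag(4, 3, mode='sum')
    #   out = bag(torch.empty(2, 3, dtype=torch.long).random_(4))  # one bag per row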
1724    dict(
1725        module_name='EmbeddingBag',
1726        constructor_args=(4, 3),
1727        cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)',
1728        input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
1729        check_gradgrad=False,
1730        desc='mean',
1731        default_dtype=torch.double,
1732    ),
1733    dict(
1734        module_name='EmbeddingBag',
1735        constructor_args=(4, 3),
1736        cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3)',
1737        input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512),
1738        check_gradgrad=False,
1739        desc='discontiguous',
1740        default_dtype=torch.double,
1741    ),
1742    dict(
1743        module_name='EmbeddingBag',
1744        constructor_args=(4, 3, None, 2., False, 'sum'),
1745        cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
1746                                .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum)''',
1747        input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
1748        check_gradgrad=False,
1749        desc='sum',
1750        default_dtype=torch.double,
1751    ),
1752    dict(
1753        module_name='EmbeddingBag',
1754        constructor_args=(4, 3, None, 2., False, 'max'),
1755        cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
1756                                .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax)''',
1757        input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
1758        check_gradgrad=False,
1759        desc='max',
1760        default_dtype=torch.double,
1761    ),
1762    dict(
1763        fullname='EmbeddingBag_mean_padding_idx',
1764        constructor=lambda: nn.EmbeddingBag(4, 3, padding_idx=1),
1765        cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3).padding_idx(1)',
1766        input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]),
1767        check_gradgrad=False,
1768        default_dtype=torch.double,
1769    ),
1770    dict(
1771        fullname='EmbeddingBag_sum_padding_idx',
1772        constructor=lambda: nn.EmbeddingBag(4, 3, None, 2., False, 'sum', padding_idx=1),
1773        cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
1774                                .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum).padding_idx(1)''',
1775        input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]),
1776        check_gradgrad=False,
1777        default_dtype=torch.double,
1778    ),
1779    dict(
1780        fullname='EmbeddingBag_max_padding_idx',
1781        constructor=lambda: nn.EmbeddingBag(4, 3, None, 2., False, 'max', padding_idx=1),
1782        cpp_constructor_args='''torch::nn::EmbeddingBagOptions(4, 3)
1783                                .max_norm(std::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax).padding_idx(1)''',
1784        input_fn=lambda: torch.stack([torch.randperm(3), torch.randperm(3)]),
1785        check_gradgrad=False,
1786        default_dtype=torch.double,
1787    ),
1788    dict(
1789        fullname='EmbeddingBag_sparse',
1790        constructor=lambda: nn.EmbeddingBag(4, 3, sparse=True, dtype=torch.double),
1791        cpp_constructor_args='torch::nn::EmbeddingBagOptions(4, 3).sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))',
1792        input_fn=lambda: torch.randperm(2).repeat(1, 2),
1793        check_gradgrad=False,
1794        has_sparse_gradients=True,
1795    ),
1796    dict(
1797        constructor=lambda: nn.Embedding(4, 3, dtype=torch.double, sparse=True),
1798        cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3).sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))',
1799        input_fn=lambda: torch.randperm(2).repeat(1, 2),
1800        fullname='Embedding_sparse',
1801        check_gradgrad=False,
1802        has_sparse_gradients=True,
1803    ),
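    # PixelShuffle / PixelUnshuffle: upscale_factor=3 maps (1, 9, 4, 4) to
    # (1, 1, 12, 12); PixelUnshuffle(3) performs the inverse rearrangement.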
1804    dict(
1805        module_name='PixelShuffle',
1806        constructor_args=(3,),
1807        cpp_constructor_args='torch::nn::PixelShuffleOptions(3)',
1808        input_size=(1, 9, 4, 4),
1809        default_dtype=torch.double,
1810    ),
1811    dict(
1812        module_name='PixelUnshuffle',
1813        constructor_args=(3,),
1814        cpp_constructor_args='torch::nn::PixelUnshuffleOptions(3)',
1815        input_size=(1, 1, 12, 12),
1816        default_dtype=torch.double,
1817    ),
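    # F.interpolate functional entries: nearest / linear / bilinear / bicubic /
    # trilinear modes over 1d/2d/3d inputs, driven either by an explicit output
    # size or by a scale factor, with and without align_corners, plus
    # zero-batch inputs.  Illustrative sketch (not part of the harness):
    #   y = F.interpolate(torch.randn(1, 2, 4), size=12, mode='nearest')  # -> (1, 2, 12)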
1818    dict(
1819        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
1820        cpp_options_args='''F::InterpolateFuncOptions()
1821                            .size(std::vector<int64_t>({12})).scale_factor(std::nullopt).mode(torch::kNearest)''',
1822        input_size=(1, 2, 4),
1823        fullname='interpolate_nearest_1d',
1824        pickle=False,
1825        default_dtype=torch.double,
1826    ),
1827    dict(
1828        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
1829        cpp_options_args='''F::InterpolateFuncOptions()
1830                            .size(std::vector<int64_t>({12})).scale_factor(std::nullopt).mode(torch::kNearest)''',
1831        input_size=(0, 2, 4),
1832        fullname='interpolate_nearest_1d_zero_dim',
1833        pickle=False,
1834    ),
1835    dict(
1836        constructor=wrap_functional(F.interpolate, size=(12, ), scale_factor=None, mode='nearest'),
1837        cpp_options_args='''F::InterpolateFuncOptions()
1838                            .size(std::vector<int64_t>({12})).scale_factor(std::nullopt).mode(torch::kNearest)''',
1839        input_size=(1, 2, 3),
1840        fullname='interpolate_nearest_tuple_1d',
1841        pickle=False,
1842        default_dtype=torch.double,
1843    ),
1844    dict(
1845        constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'),
1846        cpp_options_args='''F::InterpolateFuncOptions()
1847                            .size(std::nullopt).scale_factor(std::vector<double>({4.})).mode(torch::kNearest)''',
1848        input_size=(1, 2, 4),
1849        fullname='interpolate_nearest_scale_1d',
1850        pickle=False,
1851        default_dtype=torch.double,
1852    ),
1853    dict(
1854        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=False),
1855        cpp_options_args='''F::InterpolateFuncOptions()
1856                            .size(std::vector<int64_t>({12}))
1857                            .scale_factor(std::nullopt)
1858                            .mode(torch::kLinear)
1859                            .align_corners(false)''',
1860        input_size=(1, 2, 4),
1861        fullname='interpolate_linear_1d',
1862        pickle=False,
1863        default_dtype=torch.double,
1864    ),
1865    dict(
1866        constructor=wrap_functional(F.interpolate, size=(4, ), scale_factor=None, mode='linear', align_corners=False),
1867        cpp_options_args='''F::InterpolateFuncOptions()
1868                            .size(std::vector<int64_t>({4}))
1869                            .scale_factor(std::nullopt)
1870                            .mode(torch::kLinear)
1871                            .align_corners(false)''',
1872        input_size=(1, 2, 3),
1873        fullname='interpolate_linear_tuple_1d',
1874        pickle=False,
1875        default_dtype=torch.double,
1876    ),
1877    dict(
1878        constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='linear', align_corners=False),
1879        cpp_options_args='''F::InterpolateFuncOptions()
1880                            .size(std::nullopt)
1881                            .scale_factor(std::vector<double>({4.}))
1882                            .mode(torch::kLinear)
1883                            .align_corners(false)''',
1884        input_size=(1, 2, 4),
1885        fullname='interpolate_linear_scale_1d',
1886        pickle=False,
1887        default_dtype=torch.double,
1888    ),
1889    dict(
1890        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=False),
1891        cpp_options_args='''F::InterpolateFuncOptions()
1892                            .size(std::vector<int64_t>({12}))
1893                            .scale_factor(std::nullopt)
1894                            .mode(torch::kLinear)
1895                            .align_corners(false)''',
1896        input_size=(0, 2, 4),
1897        fullname='interpolate_linear_1d_zero_dim',
1898        pickle=False,
1899    ),
1900    dict(
1901        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=True),
1902        cpp_options_args='''F::InterpolateFuncOptions()
1903                            .size(std::vector<int64_t>({12}))
1904                            .scale_factor(std::nullopt)
1905                            .mode(torch::kLinear)
1906                            .align_corners(true)''',
1907        input_size=(1, 2, 4),
1908        fullname='interpolate_linear_1d_align_corners',
1909        pickle=False,
1910        default_dtype=torch.double,
1911    ),
1912    dict(
1913        constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='linear', align_corners=True),
1914        cpp_options_args='''F::InterpolateFuncOptions()
1915                            .size(std::nullopt)
1916                            .scale_factor(std::vector<double>({4.}))
1917                            .mode(torch::kLinear)
1918                            .align_corners(true)''',
1919        input_size=(1, 2, 4),
1920        fullname='interpolate_linear_scale_1d_align_corners',
1921        pickle=False,
1922        default_dtype=torch.double,
1923    ),
1924    dict(
1925        constructor=wrap_functional(F.interpolate, size=2, scale_factor=None, mode='nearest'),
1926        cpp_options_args='''F::InterpolateFuncOptions()
1927                            .size(std::vector<int64_t>({2, 2}))
1928                            .scale_factor(std::nullopt)
1929                            .mode(torch::kNearest)''',
1930        input_size=(1, 128, 1, 1),
1931        fullname='interpolate_nearest_2d_launch_configs',
1932        pickle=False,
1933        default_dtype=torch.double,
1934    ),
1935    dict(
1936        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
1937        cpp_options_args='''F::InterpolateFuncOptions()
1938                            .size(std::vector<int64_t>({12, 12}))
1939                            .scale_factor(std::nullopt)
1940                            .mode(torch::kNearest)''',
1941        input_size=(1, 2, 4, 4),
1942        fullname='interpolate_nearest_2d',
1943        pickle=False,
1944        default_dtype=torch.double,
1945    ),
1946    dict(
1947        constructor=wrap_functional(F.interpolate, size=(12, 16), scale_factor=None, mode='nearest'),
1948        cpp_options_args='''F::InterpolateFuncOptions()
1949                            .size(std::vector<int64_t>({12, 16}))
1950                            .scale_factor(std::nullopt)
1951                            .mode(torch::kNearest)''',
1952        input_size=(1, 2, 3, 4),
1953        fullname='interpolate_nearest_tuple_2d',
1954        pickle=False,
1955        default_dtype=torch.double,
1956    ),
1957    dict(
1958        constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'),
1959        cpp_options_args='''F::InterpolateFuncOptions()
1960                            .size(std::nullopt)
1961                            .scale_factor(std::vector<double>({4., 4.}))
1962                            .mode(torch::kNearest)''',
1963        input_size=(1, 2, 4, 4),
1964        fullname='interpolate_nearest_scale_2d',
1965        pickle=False,
1966        default_dtype=torch.double,
1967    ),
1968    dict(
1969        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
1970        cpp_options_args='''F::InterpolateFuncOptions()
1971                            .size(std::vector<int64_t>({12, 12}))
1972                            .scale_factor(std::nullopt)
1973                            .mode(torch::kNearest)''',
1974        input_size=(0, 2, 4, 4),
1975        fullname='interpolate_nearest_2d_zero_dim',
1976        pickle=False,
1977    ),
1978    dict(
1979        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bilinear', align_corners=False),
1980        cpp_options_args='''F::InterpolateFuncOptions()
1981                            .size(std::vector<int64_t>({12, 12}))
1982                            .scale_factor(std::nullopt)
1983                            .mode(torch::kBilinear)
1984                            .align_corners(false)''',
1985        input_size=(1, 2, 4, 4),
1986        fullname='interpolate_bilinear_2d',
1987        pickle=False,
1988        default_dtype=torch.double,
1989    ),
1990    dict(
1991        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bilinear', align_corners=False),
1992        cpp_options_args='''F::InterpolateFuncOptions()
1993                            .size(std::vector<int64_t>({12, 12}))
1994                            .scale_factor(std::nullopt)
1995                            .mode(torch::kBilinear)
1996                            .align_corners(false)''',
1997        input_size=(0, 2, 4, 4),
1998        fullname='interpolate_bilinear_2d_zero_dim',
1999        pickle=False,
2000    ),
2001    dict(
2002        constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None,
2003                                    mode='bilinear', align_corners=False),
2004        cpp_options_args='''F::InterpolateFuncOptions()
2005                            .size(std::vector<int64_t>({4, 6}))
2006                            .scale_factor(std::nullopt)
2007                            .mode(torch::kBilinear)
2008                            .align_corners(false)''',
2009        input_size=(1, 2, 2, 3),
2010        fullname='interpolate_bilinear_tuple_2d',
2011        pickle=False,
2012        default_dtype=torch.double,
2013    ),
2014    dict(
2015        constructor=wrap_functional(F.interpolate, size=None, scale_factor=4.,
2016                                    mode='bilinear', align_corners=False),
2017        cpp_options_args='''F::InterpolateFuncOptions()
2018                            .size(std::nullopt)
2019                            .scale_factor(std::vector<double>({4., 4.}))
2020                            .mode(torch::kBilinear)
2021                            .align_corners(false)''',
2022        input_size=(1, 2, 4, 4),
2023        fullname='interpolate_bilinear_scale_2d',
2024        pickle=False,
2025        default_dtype=torch.double,
2026    ),
2027    dict(
2028        constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 2.),
2029                                    mode='bilinear', align_corners=False),
2030        cpp_options_args='''F::InterpolateFuncOptions()
2031                            .size(std::nullopt)
2032                            .scale_factor(std::vector<double>({2., 2.}))
2033                            .mode(torch::kBilinear)
2034                            .align_corners(false)''',
2035        input_size=(1, 2, 4, 4),
2036        fullname='interpolate_bilinear_scale_tuple_shared_2d',
2037        pickle=False,
2038        default_dtype=torch.double,
2039    ),
2040    dict(
2041        constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
2042                                    mode='bilinear', align_corners=False),
2043        cpp_options_args='''F::InterpolateFuncOptions()
2044                            .size(std::nullopt)
2045                            .scale_factor(std::vector<double>({2., 1.}))
2046                            .mode(torch::kBilinear)
2047                            .align_corners(false)''',
2048        input_size=(1, 2, 4, 4),
2049        fullname='interpolate_bilinear_scale_tuple_skewed_2d',
2050        pickle=False,
2051        default_dtype=torch.double,
2052    ),
2053    dict(
2054        constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bilinear', align_corners=True),
2055        cpp_options_args='''F::InterpolateFuncOptions()
2056                            .size(std::vector<int64_t>({4, 6}))
2057                            .scale_factor(std::nullopt)
2058                            .mode(torch::kBilinear)
2059                            .align_corners(true)''',
2060        input_size=(1, 2, 4, 4),
2061        fullname='interpolate_bilinear_tuple_2d_align_corners',
2062        pickle=False,
2063        default_dtype=torch.double,
2064    ),
2065    dict(
2066        constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
2067                                    mode='bilinear', align_corners=True),
2068        cpp_options_args='''F::InterpolateFuncOptions()
2069                            .size(std::nullopt)
2070                            .scale_factor(std::vector<double>({2., 1.}))
2071                            .mode(torch::kBilinear)
2072                            .align_corners(true)''',
2073        input_size=(1, 2, 4, 4),
2074        fullname='interpolate_bilinear_scale_tuple_skewed_2d_align_corners',
2075        pickle=False,
2076        default_dtype=torch.double,
2077    ),
2078    dict(
2079        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bicubic', align_corners=False),
2080        cpp_options_args='''F::InterpolateFuncOptions()
2081                            .size(std::vector<int64_t>({12, 12}))
2082                            .scale_factor(std::nullopt)
2083                            .mode(torch::kBicubic)
2084                            .align_corners(false)''',
2085        input_size=(1, 2, 4, 4),
2086        fullname='interpolate_bicubic_2d',
2087        pickle=False,
2088        default_dtype=torch.double,
2089    ),
2090    dict(
2091        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bicubic', align_corners=False),
2092        cpp_options_args='''F::InterpolateFuncOptions()
2093                            .size(std::vector<int64_t>({12, 12}))
2094                            .scale_factor(std::nullopt)
2095                            .mode(torch::kBicubic)
2096                            .align_corners(false)''',
2097        input_size=(0, 2, 4, 4),
2098        fullname='interpolate_bicubic_2d_zero_dim',
2099        pickle=False,
2100    ),
2101    dict(
2102        constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None,
2103                                    mode='bicubic', align_corners=False),
2104        cpp_options_args='''F::InterpolateFuncOptions()
2105                            .size(std::vector<int64_t>({4, 6}))
2106                            .scale_factor(std::nullopt)
2107                            .mode(torch::kBicubic)
2108                            .align_corners(false)''',
2109        input_size=(1, 2, 2, 3),
2110        fullname='interpolate_bicubic_tuple_2d',
2111        pickle=False,
2112        default_dtype=torch.double,
2113    ),
2114    dict(
2115        constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='bicubic', align_corners=False),
2116        cpp_options_args='''F::InterpolateFuncOptions()
2117                            .size(std::nullopt)
2118                            .scale_factor(std::vector<double>({4., 4.}))
2119                            .mode(torch::kBicubic)
2120                            .align_corners(false)''',
2121        input_size=(1, 2, 4, 4),
2122        fullname='interpolate_bicubic_scale_2d',
2123        pickle=False,
2124        default_dtype=torch.double,
2125    ),
2126    dict(
2127        constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 2.),
2128                                    mode='bicubic', align_corners=False),
2129        cpp_options_args='''F::InterpolateFuncOptions()
2130                            .size(std::nullopt)
2131                            .scale_factor(std::vector<double>({2., 2.}))
2132                            .mode(torch::kBicubic)
2133                            .align_corners(false)''',
2134        input_size=(1, 2, 4, 4),
2135        fullname='interpolate_bicubic_scale_tuple_shared_2d',
2136        pickle=False,
2137        default_dtype=torch.double,
2138    ),
2139    dict(
2140        constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
2141                                    mode='bicubic', align_corners=False),
2142        cpp_options_args='''F::InterpolateFuncOptions()
2143                            .size(std::nullopt)
2144                            .scale_factor(std::vector<double>({2., 1.}))
2145                            .mode(torch::kBicubic)
2146                            .align_corners(false)''',
2147        input_size=(1, 2, 4, 4),
2148        fullname='interpolate_bicubic_scale_tuple_skewed_2d',
2149        pickle=False,
2150        default_dtype=torch.double,
2151    ),
2152    dict(
2153        constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bicubic', align_corners=True),
2154        cpp_options_args='''F::InterpolateFuncOptions()
2155                            .size(std::vector<int64_t>({4, 6}))
2156                            .scale_factor(std::nullopt)
2157                            .mode(torch::kBicubic)
2158                            .align_corners(true)''',
2159        input_size=(1, 2, 4, 4),
2160        fullname='interpolate_bicubic_tuple_2d_align_corners',
2161        pickle=False,
2162        default_dtype=torch.double,
2163    ),
2164    dict(
2165        constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
2166                                    mode='bicubic', align_corners=True),
2167        cpp_options_args='''F::InterpolateFuncOptions()
2168                            .size(std::nullopt)
2169                            .scale_factor(std::vector<double>({2., 1.}))
2170                            .mode(torch::kBicubic)
2171                            .align_corners(true)''',
2172        input_size=(1, 2, 4, 4),
2173        fullname='interpolate_bicubic_scale_tuple_skewed_2d_align_corners',
2174        pickle=False,
2175        default_dtype=torch.double,
2176    ),
2177    dict(
2178        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
2179        cpp_options_args='''F::InterpolateFuncOptions()
2180                            .size(std::vector<int64_t>({12, 12, 12}))
2181                            .scale_factor(std::nullopt)
2182                            .mode(torch::kNearest)''',
2183        input_size=(1, 2, 4, 4, 4),
2184        fullname='interpolate_nearest_3d',
2185        pickle=False,
2186        default_dtype=torch.double,
2187    ),
2188    dict(
2189        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
2190        cpp_options_args='''F::InterpolateFuncOptions()
2191                            .size(std::vector<int64_t>({12, 12, 12}))
2192                            .scale_factor(std::nullopt)
2193                            .mode(torch::kNearest)''',
2194        input_size=(0, 2, 4, 4, 4),
2195        fullname='interpolate_nearest_3d_zero_dim',
2196        pickle=False,
2197    ),
2198    dict(
2199        constructor=wrap_functional(F.interpolate, size=(12, 16, 16), scale_factor=None, mode='nearest'),
2200        cpp_options_args='''F::InterpolateFuncOptions()
2201                            .size(std::vector<int64_t>({12, 16, 16}))
2202                            .scale_factor(std::nullopt)
2203                            .mode(torch::kNearest)''',
2204        input_size=(1, 2, 3, 4, 4),
2205        fullname='interpolate_nearest_tuple_3d',
2206        pickle=False,
2207        default_dtype=torch.double,
2208    ),
2209    dict(
2210        constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'),
2211        cpp_options_args='''F::InterpolateFuncOptions()
2212                            .size(std::nullopt)
2213                            .scale_factor(std::vector<double>({4., 4., 4.}))
2214                            .mode(torch::kNearest)''',
2215        input_size=(1, 2, 4, 4, 4),
2216        fullname='interpolate_nearest_scale_3d',
2217        pickle=False,
2218        default_dtype=torch.double,
2219    ),
2220    dict(
2221        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='trilinear', align_corners=False),
2222        cpp_options_args='''F::InterpolateFuncOptions()
2223                            .size(std::vector<int64_t>({12, 12, 12}))
2224                            .scale_factor(std::nullopt)
2225                            .mode(torch::kTrilinear)
2226                            .align_corners(false)''',
2227        input_size=(1, 2, 4, 4, 4),
2228        fullname='interpolate_trilinear_3d',
2229        pickle=False,
2230        default_dtype=torch.double,
2231    ),
2232    dict(
2233        constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='trilinear', align_corners=False),
2234        cpp_options_args='''F::InterpolateFuncOptions()
2235                            .size(std::vector<int64_t>({12, 12, 12}))
2236                            .scale_factor(std::nullopt)
2237                            .mode(torch::kTrilinear)
2238                            .align_corners(false)''',
2239        input_size=(0, 2, 4, 4, 4),
2240        fullname='interpolate_trilinear_3d_zero_dim',
2241        pickle=False,
2242    ),
2243    dict(
2244        constructor=wrap_functional(F.interpolate, size=(4, 6, 6),
2245                                    scale_factor=None, mode='trilinear', align_corners=False),
2246        cpp_options_args='''F::InterpolateFuncOptions()
2247                            .size(std::vector<int64_t>({4, 6, 6}))
2248                            .scale_factor(std::nullopt)
2249                            .mode(torch::kTrilinear)
2250                            .align_corners(false)''',
2251        input_size=(1, 2, 2, 3, 3),
2252        fullname='interpolate_trilinear_tuple_3d',
2253        pickle=False,
2254        default_dtype=torch.double,
2255    ),
2256    dict(
2257        constructor=wrap_functional(F.interpolate, size=None, scale_factor=3., mode='trilinear', align_corners=False),
2258        cpp_options_args='''F::InterpolateFuncOptions()
2259                            .size(std::nullopt)
2260                            .scale_factor(std::vector<double>({3., 3., 3.}))
2261                            .mode(torch::kTrilinear)
2262                            .align_corners(false)''',
2263        input_size=(1, 2, 3, 4, 5),
2264        fullname='interpolate_trilinear_scale_3d',
2265        # See https://github.com/pytorch/pytorch/issues/5006
2266        precision=3e-4,
2267        pickle=False,
2268        default_dtype=torch.double,
2269    ),
2270    dict(
2271        constructor=wrap_functional(F.interpolate, size=(4, 6, 6), scale_factor=None,
2272                                    mode='trilinear', align_corners=True),
2273        cpp_options_args='''F::InterpolateFuncOptions()
2274                            .size(std::vector<int64_t>({4, 6, 6}))
2275                            .scale_factor(std::nullopt)
2276                            .mode(torch::kTrilinear)
2277                            .align_corners(true)''',
2278        input_size=(1, 2, 2, 3, 3),
2279        fullname='interpolate_trilinear_tuple_3d_align_corners',
2280        pickle=False,
2281        default_dtype=torch.double
2282    ),
2283    dict(
2284        constructor=wrap_functional(F.interpolate, size=None, scale_factor=3., mode='trilinear', align_corners=True),
2285        cpp_options_args='''F::InterpolateFuncOptions()
2286                            .size(std::nullopt)
2287                            .scale_factor(std::vector<double>({3., 3., 3.}))
2288                            .mode(torch::kTrilinear)
2289                            .align_corners(true)''',
2290        input_size=(1, 2, 3, 4, 4),
2291        fullname='interpolate_trilinear_scale_3d_align_corners',
2292        # See https://github.com/pytorch/pytorch/issues/5006
2293        precision=3e-4,
2294        pickle=False,
2295        default_dtype=torch.double,
2296    ),
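    # F.softmax / F.log_softmax functional entries.  The input_size comments
    # note which CUDA kernel path each shape is meant to exercise (last-dim
    # vs. spatial); scalar inputs and an explicit dtype are also covered.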
2297    dict(
2298        constructor=wrap_functional(F.softmax, dim=-1),
2299        cpp_options_args='F::SoftmaxFuncOptions(-1)',
2300        input_size=(2, 128),  # trigger the last-dim algo in CUDA
2301        fullname='softmax_lastdim',
2302        pickle=False,
2303        default_dtype=torch.double,
2304    ),
2305    dict(
2306        constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64),
2307        cpp_options_args='F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)',
2308        input_size=(2, 128),
2309        fullname='softmax_lastdim_dtype',
2310        pickle=False,
2311        test_cuda=False,
2312        default_dtype=torch.double,
2313    ),
2314    dict(
2315        constructor=wrap_functional(F.softmax, dim=1),
2316        cpp_options_args='F::SoftmaxFuncOptions(1)',
2317        input_size=(2, 128, 2, 2),  # trigger special case of spatial CUDA algo
2318        fullname='softmax_spatial_special',
2319        pickle=False,
2320        default_dtype=torch.double,
2321    ),
2322    dict(
2323        constructor=wrap_functional(F.softmax, dim=1),
2324        cpp_options_args='F::SoftmaxFuncOptions(1)',
2325        input_size=(2, 2, 4, 4),  # regular spatial algorithm
2326        fullname='softmax_spatial',
2327        pickle=False,
2328        default_dtype=torch.double,
2329    ),
2330    dict(
2331        constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64),
2332        cpp_options_args='F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)',
2333        input_size=(2, 2, 4, 4),  # regular spatial algorithm
2334        fullname='softmax_spatial_dtype',
2335        pickle=False,
2336        test_cuda=False,
2337        default_dtype=torch.double,
2338    ),
2339    dict(
2340        constructor=wrap_functional(F.softmax, dim=0),
2341        cpp_options_args='F::SoftmaxFuncOptions(0)',
2342        input_size=(2, 3, 4, 5),
2343        fullname='softmax_functional_dim0',
2344        test_cuda=False,
2345        pickle=False,
2346        default_dtype=torch.double,
2347    ),
2348    dict(
2349        constructor=wrap_functional(F.softmax, dim=3),
2350        cpp_options_args='F::SoftmaxFuncOptions(3)',
2351        input_size=(2, 3, 4, 5),
2352        fullname='softmax_functional_dim3',
2353        test_cuda=False,
2354        pickle=False,
2355        default_dtype=torch.double,
2356    ),
2357    dict(
2358        constructor=wrap_functional(F.softmax, dim=-1),
2359        cpp_options_args='F::SoftmaxFuncOptions(-1)',
2360        input_size=(),
2361        fullname='softmax_functional_scalar',
2362        test_cuda=False,
2363        pickle=False,
2364    ),
2365    dict(
2366        constructor=wrap_functional(F.log_softmax, dim=-1),
2367        cpp_options_args='F::LogSoftmaxFuncOptions(-1)',
2368        input_size=(2, 128),  # trigger the last-dim algo in CUDA
2369        fullname='log_softmax_lastdim',
2370        pickle=False,
2371        default_dtype=torch.double,
2372    ),
2373    dict(
2374        constructor=wrap_functional(F.log_softmax, dim=1),
2375        cpp_options_args='F::LogSoftmaxFuncOptions(1)',
2376        input_size=(2, 128, 2, 2),  # trigger special case of spatial CUDA algo
2377        fullname='log_softmax_spatial_special',
2378        pickle=False,
2379        default_dtype=torch.double,
2380    ),
2381    dict(
2382        constructor=wrap_functional(F.log_softmax, dim=1),
2383        cpp_options_args='F::LogSoftmaxFuncOptions(1)',
2384        input_size=(2, 2, 4, 4),  # regular spatial algorithm
2385        fullname='log_softmax_spatial',
2386        pickle=False,
2387        default_dtype=torch.double,
2388    ),
2389    dict(
2390        constructor=wrap_functional(F.log_softmax, dim=0),
2391        cpp_options_args='F::LogSoftmaxFuncOptions(0)',
2392        input_size=(2, 3, 4, 5),
2393        fullname='log_softmax_dim0',
2394        pickle=False,
2395        default_dtype=torch.double,
2396    ),
2397    dict(
2398        constructor=wrap_functional(F.log_softmax, dim=3),
2399        cpp_options_args='F::LogSoftmaxFuncOptions(3)',
2400        input_size=(2, 3, 4, 5),
2401        fullname='log_softmax_dim3',
2402        pickle=False,
2403        default_dtype=torch.double,
2404    ),
2405    dict(
2406        constructor=wrap_functional(F.log_softmax, dim=0),
2407        cpp_options_args='F::LogSoftmaxFuncOptions(0)',
2408        input_size=(),
2409        fullname='log_softmax_scalar',
2410        pickle=False,
2411    ),
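    # Unfold / Fold entries: tuple and plain-int arguments, plus no-batch-dim
    # Fold inputs.  Shape-wise round trip as an illustration (Fold sums
    # overlapping values, so it is not an exact inverse):
    #   patches = torch.nn.Unfold(2)(torch.randn(2, 4, 3, 3))   # (2, 16, 4)
    #   images  = torch.nn.Fold(3, 2)(patches)                  # (2, 4, 3, 3)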
2412    dict(
2413        fullname='Unfold',
2414        constructor=lambda: nn.Unfold((2, 2), (1, 1), (0, 0), (1, 1)),
2415        cpp_constructor_args='torch::nn::UnfoldOptions({2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})',
2416        input_size=(2, 4, 3, 3),
2417        check_gradgrad=False,
2418        test_cuda=True,
2419        default_dtype=torch.double,
2420    ),
2421    dict(
2422        fullname='Fold',
2423        constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)),
2424        cpp_constructor_args='torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})',
2425        input_size=(2, 16, 4),
2426        check_gradgrad=False,
2427        test_cuda=True,
2428        default_dtype=torch.double,
2429    ),
2430    dict(
2431        fullname='Fold_no_batch_dim_input',
2432        constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)),
2433        cpp_constructor_args='torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})',
2434        input_size=(16, 4),
2435        check_gradgrad=False,
2436        reference_fn=single_batch_reference_fn,
2437        test_cuda=True,
2438        default_dtype=torch.double,
2439    ),
2440    dict(
2441        fullname='Unfold_int_input',
2442        constructor=lambda: nn.Unfold(2, 1, 0, 1),
2443        cpp_constructor_args='torch::nn::UnfoldOptions(2).dilation(1).padding(0).stride(1)',
2444        input_size=(2, 4, 3, 3),
2445        check_gradgrad=False,
2446        test_cuda=True,
2447        default_dtype=torch.double,
2448    ),
2449    dict(
2450        fullname='Fold_int_input',
2451        constructor=lambda: nn.Fold(3, 2, 1, 0, 1),
2452        cpp_constructor_args='torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1)',
2453        input_size=(2, 16, 4),
2454        check_gradgrad=False,
2455        test_cuda=True,
2456        default_dtype=torch.double,
2457    ),
2458    dict(
2459        fullname='Fold_no_batch_dim_int_input',
2460        constructor=lambda: nn.Fold(3, 2, 1, 0, 1),
2461        cpp_constructor_args='torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1)',
2462        input_size=(16, 4),
2463        reference_fn=single_batch_reference_fn,
2464        check_gradgrad=False,
2465        test_cuda=True,
2466        default_dtype=torch.double,
2467    ),
2468    dict(
2469        module_name='RReLU',
2470        constructor_args=(0.1, 0.9),
2471        cpp_constructor_args='torch::nn::RReLUOptions().lower(0.1).upper(0.9)',
2472        input_size=(),
2473        desc='with_up_down_scalar',
2474        test_cuda=False,
2475        default_dtype=torch.double,
2476    ),
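    # PairwiseDistance entries: default arguments, lhs/rhs broadcasting,
    # non-default p/eps/keepdim, and a no-batch-dim case.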
2477    dict(
2478        module_name='PairwiseDistance',
2479        input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)),
2480        default_dtype=torch.double,
2481    ),
2482    dict(
2483        module_name='PairwiseDistance',
2484        input_fn=lambda: (torch.randn(10, 1), torch.randn(10, 8)),
2485        desc='broadcast_lhs',
2486        default_dtype=torch.double,
2487    ),
2488    dict(
2489        module_name='PairwiseDistance',
2490        input_fn=lambda: (torch.randn(10, 8), torch.randn(1, 8)),
2491        desc='broadcast_rhs',
2492        default_dtype=torch.double,
2493    ),
2494    dict(
2495        module_name='PairwiseDistance',
2496        constructor_args=(1.5, 1e-05, True),
2497        cpp_constructor_args='torch::nn::PairwiseDistanceOptions().p(1.5).eps(1e-05).keepdim(true)',
2498        input_fn=lambda: (torch.randn(10, 8), torch.randn(10, 8)),
2499        desc='with_non_default_args',
2500        default_dtype=torch.double,
2501    ),
2502    dict(
2503        module_name='PairwiseDistance',
2504        input_fn=lambda: (torch.randn(8), torch.randn(8)),
2505        reference_fn=single_batch_reference_fn,
2506        desc='no_batch_dim',
2507        default_dtype=torch.double,
2508    ),
2509    dict(
2510        module_name='TransformerEncoderLayer',
2511        constructor_args=(4, 2, 16, 0.0),
2512        cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2)
2513                                .dim_feedforward(16)
2514                                .dropout(0.0)''',
2515        input_size=(2, 3, 4),
2516        desc='relu_activation',
2517        with_tf32=True,
2518        tf32_precision=0.1,
2519        # TODO(#50743): figure out the error
2520        # RuntimeError: The size of tensor a (6) must match the size of tensor b (4)
2521        # at non-singleton dimension 2
2522        check_batched_grad=False,
2523        check_gradgrad=False,
2524        default_dtype=torch.double,
2525    ),
2526    dict(
2527        module_name='TransformerEncoderLayer',
2528        constructor_args=(4, 2, 8, 0.0, F.gelu),
2529        cpp_constructor_args='''torch::nn::TransformerEncoderLayerOptions(4, 2)
2530                                .dim_feedforward(8)
2531                                .dropout(0.0)
2532                                .activation(torch::kGELU)''',
2533        input_size=(2, 3, 4),
2534        check_gradgrad=False,
2535        desc='gelu_activation',
2536        with_tf32=True,
2537        tf32_precision=0.08 if SM90OrLater else 0.05,
2538        default_dtype=torch.double,
2539    ),
2540    dict(
2541        module_name='TransformerDecoderLayer',
2542        constructor_args=(4, 2, 8, 0.0),
2543        cpp_constructor_args='''torch::nn::TransformerDecoderLayerOptions(4, 2)
2544                                .dim_feedforward(8)
2545                                .dropout(0.0)''',
2546        input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4)),
2547        check_gradgrad=False,
2548        desc='relu_activation',
2549        with_tf32=True,
2550        tf32_precision=0.05,
2551        default_dtype=torch.double,
2552    ),
2553    dict(
2554        module_name='TransformerDecoderLayer',
2555        constructor_args=(4, 2, 8, 0.0, F.gelu),
2556        cpp_constructor_args='''torch::nn::TransformerDecoderLayerOptions(4, 2)
2557                                .dim_feedforward(8)
2558                                .dropout(0.0)
2559                                .activation(torch::kGELU)''',
2560        input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4)),
2561        check_gradgrad=False,
2562        desc='gelu_activation',
2563        with_tf32=True,
2564        tf32_precision=0.05,
2565        default_dtype=torch.double,
2566    ),
2567    dict(
2568        module_name='Transformer',
2569        constructor_args=(4, 2, 2, 2, 8, 0.0, F.relu),
2570        cpp_constructor_args='''torch::nn::TransformerOptions()
2571                                .d_model(4)
2572                                .nhead(2)
2573                                .num_encoder_layers(2)
2574                                .num_decoder_layers(2)
2575                                .dim_feedforward(8)
2576                                .dropout(0.0)
2577                                .activation(torch::kReLU)''',
2578        input_fn=lambda: (torch.rand(3, 3, 4), torch.rand(2, 3, 4), torch.rand(3, 3)),
2579        check_gradgrad=False,
2580        desc='multilayer_coder',
2581        with_tf32=True,
2582        tf32_precision=0.05 if SM90OrLater else 0.03,
2583        default_dtype=torch.double,
2584    ),
2585    dict(
2586        module_name='Linear',
2587        constructor_args=(3, 5),
2588        cpp_constructor_args='torch::nn::LinearOptions(3, 5)',
2589        input_fn=lambda: torch.rand(3),
2590        reference_fn=lambda i, p, _: torch.mm(i.view(1, -1), p[0].t()).view(-1) + p[1],
2591        desc="no_batch_dim",
2592        with_tf32=True,
2593        tf32_precision=0.005,
2594        default_dtype=torch.double,
2595    ),
2596    dict(
2597        module_name='Flatten',
2598        cpp_constructor_args='torch::nn::FlattenOptions().start_dim(-3).end_dim(-1)',
2599        constructor_args=(-3, -1),
2600        input_size=(3, 4, 5),
2601        reference_fn=single_batch_reference_fn,
2602        desc="no_batch_dim",
2603        default_dtype=torch.double,
2604    ),
2605    dict(
2606        module_name='Unflatten',
2607        cpp_constructor_args='torch::nn::UnflattenOptions(-2, {2, 2})',
2608        constructor_args=(-2, torch.Size([2, 2])),
2609        input_size=(3, 4, 5),
2610        reference_fn=single_batch_reference_fn,
2611        desc="no_batch_dim",
2612        default_dtype=torch.double,
2613    ),
2614    dict(
2615        module_name='LayerNorm',
2616        constructor_args=([56, 56, 56], 1e-5, False),
2617        cpp_constructor_args='torch::nn::LayerNormOptions({56, 56, 56}).eps(1e-5).elementwise_affine(false)',
2618        input_size=(4, 56, 56, 56),
2619        cudnn=True,
2620        check_eval=True,
2621        gradcheck_fast_mode=True,
2622        check_half=True,
2623        desc='3d_no_affine_large_feature',
2624    ),
2625]
2626
2627# add conv padding mode tests:
2628for padding_mode, cpp_padding_mode in zip(
2629        ['reflect', 'circular', 'replicate', 'zeros'],
2630        ['torch::kReflect', 'torch::kCircular', 'torch::kReplicate', 'torch::kZeros']):
2631    # conv signature:
2632    #     in_channels, out_channels, kernel_size, stride=1,
2633    #     padding=0, dilation=1, groups=1,
2634    #     bias=True, padding_mode='zeros'
2635    for d in (1, 2, 3):
2636        if d == 3 and padding_mode == 'reflect':
2637            # FIXME: remove after implementing reflection pad 3d
2638            #        https://github.com/pytorch/pytorch/issues/27655
2639            continue
2640        padding = tuple(range(1, d + 1))
2641        cpp_padding = '{' + ', '.join(map(str, padding)) + '}'
2642        input_size = (2, 2) + (4,) * d
2643        output_size = (2, 3) + tuple(p + 1 for p in padding)  # simplified from `(4 + 2 * p - 3) // 2 + 1`
2644        new_module_tests.append(
2645            dict(
2646                module_name=f'Conv{d}d',
2647                constructor_args=(2, 3, 3, 2, padding, 1, 1, True, padding_mode),
2648                cpp_constructor_args=f'''torch::nn::Conv{d}dOptions(2, 3, 3)
2649                                        .stride(2)
2650                                        .padding({cpp_padding})
2651                                        .dilation(1)
2652                                        .groups(1)
2653                                        .bias(true)
2654                                        .padding_mode({cpp_padding_mode})''',
2655                input_size=input_size,
2656                output_size=output_size,
2657                cudnn=True,
2658                desc=f'{padding_mode}_stride2_pad2',
2659                with_tf32=True,
2660                tf32_precision=0.05,
2661                default_dtype=torch.double,
2662            ),
2663        )
2664
# Check that non-linear activations work with no batch dimensions
2666non_linear_activations_no_batch = [
2667    'ELU', 'Hardshrink', 'Hardsigmoid', 'Hardtanh', 'Hardswish', 'LeakyReLU',
2668    'LogSigmoid', 'PReLU', 'ReLU', 'ReLU6', 'RReLU', 'SELU', 'CELU', 'GELU', 'GLU',
2669    'Sigmoid', 'SiLU', 'Mish', 'Softplus', 'Softshrink', 'Softsign', 'Tanh',
2670    'Tanhshrink', 'Threshold'
2671]
2672non_linear_activations_extra_info: Dict[str, dict] = {
2673    'CELU': {'constructor_args': (2.,), 'default_dtype': torch.double},
2674    'Threshold': {'constructor_args': (2., 1.)},
2675    'Hardsigmoid': {'check_gradgrad': False, 'check_jit': False, 'default_dtype': torch.double},
2676    'Hardswish': {'check_gradgrad': False, 'check_jit': False, 'default_dtype': torch.double},
    # For RReLU, tests that compare CPU and GPU results fail because the RNG
    # behaves differently on CPU and GPU
2679    'RReLU': {'test_cuda': False, 'default_dtype': torch.double},
2680    'ELU': {'default_dtype': torch.double},
2681    'GELU': {'default_dtype': torch.double},
2682    'GLU': {'default_dtype': torch.double},
2683    'Hardshrink': {'default_dtype': torch.double},
2684    'Hardtanh': {'default_dtype': torch.double},
2685    'LeakyReLU': {'default_dtype': torch.double},
2686    'LogSigmoid': {'default_dtype': torch.double},
2687    'Mish': {'default_dtype': torch.double},
2688    'PReLU': {'default_dtype': torch.double},
2689    'ReLU6': {'default_dtype': torch.double},
2690    'ReLU': {'default_dtype': torch.double},
2691    'SELU': {'default_dtype': torch.double},
2692    'SiLU': {'default_dtype': torch.double},
2693    'Sigmoid': {'default_dtype': torch.double},
2694    'Softplus': {'default_dtype': torch.double},
2695    'Softshrink': {'default_dtype': torch.double},
2696    'Softsign': {'default_dtype': torch.double},
2697    'Tanh': {'default_dtype': torch.double},
2698    'Tanhshrink': {'default_dtype': torch.double},
2699}
2700for non_linear_activation in non_linear_activations_no_batch:
2701    activation_test_info = dict(
2702        module_name=non_linear_activation,
2703        input_size=(4,),
2704        reference_fn=single_batch_reference_fn,
2705        desc='no_batch_dim',
2706        test_cpp_api_parity=False,
2707    )
2708    extra_info = non_linear_activations_extra_info.get(non_linear_activation, {})
2709    activation_test_info.update(extra_info)
2710    new_module_tests.append(activation_test_info)
2711
2712
2713def kldivloss_reference(input, target, reduction='mean', log_target=False):
2714    if log_target:
2715        result = torch.exp(target) * (target - input)
2716    else:
2717        result = target * (target.log() - input)
2718    if reduction == 'mean':
2719        return result.mean()
2720    elif reduction == 'sum':
2721        return result.sum()
2722    elif reduction == 'batchmean' and result.dim() != 0:
2723        return result.sum() / result.size(0)
2724    return result
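
# Rough usage sketch for the reference above (hypothetical values; `input` is expected in
# log-space and `target` in probability space unless log_target=True):
#   inp = torch.rand(3).log()
#   tgt = torch.rand(3)
#   kldivloss_reference(inp, tgt, reduction='batchmean')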
2725
2726
2727def nlllossNd_reference(input, target, weight=None, ignore_index=-100,
2728                        reduction='mean'):
2729    assert input.dim() >= 3
2730    N = input.size(0)
2731    C = input.size(1)
2732    out_size = (N,) + input.size()[2:]
2733    output = torch.zeros(out_size).type_as(input)
2734
2735    if weight is None:
2736        weight = torch.ones(C).type_as(input)
2737    total_weight = 0
2738    for tup in product(*[range(size) for size in out_size]):
2739        t_nx = target[tup]
2740        norm = 0. if ignore_index == t_nx else weight[t_nx].item()
2741        input_index = list(tup)
2742        input_index.insert(1, t_nx)
2743        output[tup] = -input[tuple(input_index)] * norm
2744        total_weight += norm
2745
2746    if reduction == 'mean':
2747        return output.sum() / total_weight
2748    elif reduction == 'sum':
2749        return output.sum()
2750    return output
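
# Shape convention assumed by nlllossNd_reference above (mirrors F.nll_loss for >= 3-d inputs):
#   input:  (N, C, d1, d2, ...)  log-probabilities over C classes
#   target: (N, d1, d2, ...)     class indices in [0, C) or ignore_index
#   output: scalar for 'mean'/'sum' reduction, otherwise shape (N, d1, d2, ...)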
2751
2752
2753def cross_entropy_loss_prob_target_reference(input, target, weight=None, reduction='mean',
2754                                             label_smoothing=0.0):
2755    assert input.dim() >= 2
2756
2757    input = torch.log_softmax(input, 1)
2758    C = input.size(1)
2759    if weight is None:
2760        weight = torch.ones(C).type_as(input)
2761    weight = weight.view(1, C, *(1 for _ in input.shape[2:]))
2762
2763    if label_smoothing > 0.0:
2764        assert label_smoothing <= 1.0
2765        target = (target * (1 - label_smoothing) + label_smoothing / C)
2766
2767    output = -(input * target * weight).sum(dim=1)
2768    if reduction == 'mean':
2769        return output.mean()
2770    elif reduction == 'sum':
2771        return output.sum()
2772    return output
2773
2774
2775def cross_entropy_loss_indices_target_reference(input, target, weight=None, ignore_index=-100,
2776                                                reduction='mean', label_smoothing=0.0):
2777    log_softmax_input = torch.log_softmax(input, 1)
2778    nllloss = F.nll_loss(
2779        log_softmax_input,
2780        target,
2781        weight,
2782        ignore_index=ignore_index,
2783        reduction=reduction)
2784
2785    if label_smoothing == 0.0:
2786        return nllloss
2787
2788    assert 0.0 < label_smoothing <= 1.0
2789
2790    input = torch.log_softmax(input, 1)
2791    C = input.size(1)
2792    if weight is not None:
2793        input = input * weight.view(1, C, *(1 for _ in input.shape[2:]))
2794
2795    smooth_loss = -torch.sum(input, 1)
2796
2797    ignore_mask = target == ignore_index
2798    smooth_loss.masked_fill_(ignore_mask, 0.0)
2799
2800    if reduction == 'mean':
2801        if weight is not None:
            # TODO: This code path can be removed if #61309 is resolved
2803            # loss is normalized by the weights to be consistent with nll_loss_nd
2804            ret = torch.sum(smooth_loss) / weight.gather(0, target.masked_select(ignore_mask.logical_not()).flatten()).sum()
2805        else:
2806            ret = torch.mean(smooth_loss.masked_select(ignore_mask.logical_not()))
2807    elif reduction == 'sum':
2808        ret = torch.sum(smooth_loss)
2809    else:
2810        ret = smooth_loss
2811
2812    return (1 - label_smoothing) * nllloss + ret * (label_smoothing / C)
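
# With label smoothing eps and C classes, the reference above combines the two terms as
#   loss = (1 - eps) * nll_loss + (eps / C) * smooth_loss
# where smooth_loss sums -log_softmax(input) (optionally weighted) over the class dimension,
# matching the return statement just above.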
2813
2814
2815def cross_entropy_loss_reference(input, target, weight=None, ignore_index=-100, reduction='mean',
2816                                 label_smoothing=0.0):
2817    if input.shape == target.shape:
2818        return cross_entropy_loss_prob_target_reference(
2819            input,
2820            target,
2821            weight=weight,
2822            reduction=reduction,
2823            label_smoothing=label_smoothing)
2824    else:
2825        return cross_entropy_loss_indices_target_reference(
2826            input, target, weight=weight, reduction=reduction,
2827            ignore_index=ignore_index, label_smoothing=label_smoothing
2828        )
2829
2830
2831def nllloss_reference(input, target, weight=None, ignore_index=-100,
2832                      reduction='mean'):
2833
2834    def nll_loss_helper(input, target, weight, ignore_index):
2835        if target == ignore_index:
2836            return (0, 0)
2837        norm = 1 if weight is None else weight[target]
2838        result = -input[target] * norm
2839        return (result, norm)
2840
2841    losses_and_weights = [nll_loss_helper(i, t, weight, ignore_index)
2842                          for i, t in zip(input, target)]
2843    losses, weights = zip(*losses_and_weights)
2844    losses_tensor = input.new_tensor(losses)
2845    if reduction == 'mean':
2846        return sum(losses_tensor) / sum(weights)
2847    elif reduction == 'sum':
2848        return sum(losses_tensor)
2849    else:
2850        return losses_tensor
2851
2852
2853def smoothl1loss_reference(input, target, reduction='mean', beta=1.0):
2854    abs_diff = (input - target).abs()
2855    ge_beta_mask = (abs_diff >= beta).type_as(abs_diff)
2856    lt_beta_mask = (abs_diff < beta).type_as(abs_diff)
    # when beta == 0, smooth L1 loss reduces to plain L1 loss (this also avoids dividing by zero below)
2858    if beta == 0:
2859        output = abs_diff
2860    else:
2861        output = ge_beta_mask * (abs_diff - 0.5 * beta) + lt_beta_mask * 0.5 * (abs_diff ** 2) / beta
2862    if reduction == 'mean':
2863        return output.mean()
2864    elif reduction == 'sum':
2865        return output.sum()
2866    return output
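
# Piecewise form computed above (informal), with x = input - target:
#   loss(x) = 0.5 * x**2 / beta      if |x| < beta
#           = |x| - 0.5 * beta       otherwise
# and the beta == 0 branch degenerates to plain L1 loss.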
2867
2868
2869def huberloss_reference(input, target, reduction='mean', delta=1.0):
2870    abs_diff = (input - target).abs()
2871    ge_delta_mask = (abs_diff >= delta)
2872    lt_delta_mask = (abs_diff < delta)
2873    output = ge_delta_mask * delta * (abs_diff - 0.5 * delta) + lt_delta_mask * 0.5 * (abs_diff ** 2)
2874    if reduction == 'mean':
2875        return output.mean()
2876    elif reduction == 'sum':
2877        return output.sum()
2878    return output
2879
2880
2881def _multilabelmarginloss_reference(input, target):
2882    targets = []
2883    for target_index in target:
2884        if target_index < 0:
2885            break
2886        targets.append(target_index)
2887
    total = 0  # avoid shadowing the builtin `sum`
    for target_index in targets:
        for i in range(0, len(input)):
            if i not in targets:
                total += max(0, 1 - input[target_index] + input[i])

    return total
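
# Worked example for the helper above (informal): with input = [0.1, 0.2, 0.4] and
# target = [2, -1, ...], only class 2 is a positive label, so the result is
#   max(0, 1 - 0.4 + 0.1) + max(0, 1 - 0.4 + 0.2) = 0.7 + 0.8 = 1.5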
2895
2896
2897def multilabelmarginloss_reference(input, target, reduction='mean'):
2898    # make everything 2-dimensional
2899    input_dim = input.dim()
2900    if input.dim() < 2:
2901        assert target.dim() < 2
2902        input = input.unsqueeze(0) if input.dim() == 1 else input.unsqueeze(0).unsqueeze(0)
2903        target = target.unsqueeze(0) if target.dim() == 1 else target.unsqueeze(0).unsqueeze(0)
2904
2905    n = input.size(0)
2906    dim = input.size(1)
2907    output = input.new(n).zero_()
2908    for i in range(0, n):
2909        output[i] = _multilabelmarginloss_reference(input[i], target[i])
2910
2911    if reduction == 'mean':
2912        return output.mean() / dim
2913    elif reduction == 'sum':
2914        return output.sum() / dim
2915    elif input_dim < 2:
        # we know we have (1, C) x (1, C) -> (1,), so squeeze will get us
        # back to the correct dimensionality
2918        return output.squeeze() / dim
2919    else:
2920        return output / dim
2921
2922
2923def hingeembeddingloss_reference(input, target, margin=1.0, reduction='mean'):
2924    margin_clamp = (margin - input).clamp(min=0).type_as(input)
2925    output = torch.where(target == 1, input, margin_clamp)
2926
2927    if reduction == 'mean':
2928        return output.mean()
2929    elif reduction == 'sum':
2930        return output.sum()
2931    return output
2932
2933
2934def softmarginloss_reference(input, target, reduction='mean'):
2935    output = (1 + (-input * target).exp()).log()
2936
2937    if reduction == 'mean':
2938        return output.mean()
2939    elif reduction == 'sum':
2940        return output.sum()
2941    return output
2942
2943
2944def _multimarginloss_reference(input, target_idx, p, margin, weight):
2945    if weight is None:
2946        weight = input.new(len(input)).fill_(1)
2947
2948    output = 0
2949    for i in range(0, len(input)):
2950        if i != target_idx:
2951            output += weight[target_idx] * (max(0, (margin - input[target_idx] + input[i])) ** p)
2952    return output
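
# Per-sample term computed above (informal), for target class y and weights w:
#   sum_{i != y} w[y] * max(0, margin - input[y] + input[i]) ** p
# multimarginloss_reference below then divides by the number of classes (and averages
# over the batch for reduction='mean').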
2953
2954
2955def multimarginloss_reference(input, target, p=1, margin=1, weight=None, reduction='mean'):
2956    if input.dim() < 2:
2957        input = input.unsqueeze(0) if input.dim() == 1 else input.unsqueeze(0).unsqueeze(0)
2958
2959    target_dim = target.dim()
2960    if target.dim() == 0:
2961        target = target.unsqueeze(0)
2962
2963    n = input.size(0)
2964    dim = input.size(1)
2965    output = input.new(n)
2966    for x in range(0, n):
2967        output[x] = _multimarginloss_reference(input[x], target[x], p, margin, weight)
2968
2969    if reduction == 'mean':
2970        return output.mean() / dim
2971    elif reduction == 'sum':
2972        return output.sum() / dim
2973    elif target_dim == 0:
2974        return output.squeeze(0) / dim
2975    return output / dim
2976
2977
2978def cosineembeddingloss_reference(input1, input2, target, margin=0, reduction='mean'):
2979    def _cos(a, b):
2980        cos = a.new(a.size(0))
2981        for i in range(0, a.size(0)):
2982            cos[i] = (a[i] * b[i]).sum() / ((((a[i] * a[i]).sum() + 1e-12) * ((b[i] * b[i]).sum() + 1e-12)) ** 0.5)
2983        return cos
2984
2985    output = torch.where(target == 1, 1 - _cos(input1, input2), (_cos(input1, input2) - margin).clamp(min=0))
2986
2987    if reduction == 'mean':
2988        return output.mean()
2989    elif reduction == 'sum':
2990        return output.sum()
2991    return output
2992
2993
2994def tripletmarginloss_reference(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False,
2995                                reduction='mean'):
2996    d_p = torch.pairwise_distance(anchor, positive, p, eps)
2997    d_n = torch.pairwise_distance(anchor, negative, p, eps)
2998    if swap:
2999        d_s = torch.pairwise_distance(positive, negative, p, eps)
3000        d_n = torch.min(d_n, d_s)
3001
3002    output = torch.clamp(margin + d_p - d_n, min=0.0)
3003    if reduction == 'mean':
3004        return output.mean()
3005    elif reduction == 'sum':
3006        return output.sum()
3007    return output
3008
3009
3010def marginrankingloss_reference(input1, input2, target, margin=0, reduction='mean'):
3011    output = (-target * (input1 - input2) + margin).clamp(min=0)
3012    if reduction == 'mean':
3013        return output.mean()
3014    elif reduction == 'sum':
3015        return output.sum()
3016    return output
3017
3018
# This directly follows Graves et al.'s paper; in contrast to the production implementation, it does not use log-space
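# Informally, for each batch element the function below runs the CTC forward algorithm over the
# extended target sequence l' (targets interleaved with blanks, length 2 * target_length + 1):
#   alpha_t(s) = probs[t, l'_s] * (alpha_{t-1}(s) + alpha_{t-1}(s-1) + [l'_s != l'_{s-2}] * alpha_{t-1}(s-2))
# with out-of-range alpha terms treated as zero; the per-sample loss is
# -log(alpha_T(S) + alpha_T(S-1)), i.e. `-alpha[-2:].sum().log()`.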
3020def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean'):
3021    input_lengths = torch.as_tensor(input_lengths, dtype=torch.long)
3022    target_lengths = torch.as_tensor(target_lengths, dtype=torch.long)
3023    dt = log_probs.dtype
3024    log_probs = log_probs.double()  # we need the accuracy as we are not in logspace
3025    targets = targets.long()
3026    cum_target_lengths = target_lengths.cumsum(0)
3027    losses = []
3028    for i in range(log_probs.size(1)):
3029        input_length = input_lengths[i].item()
3030        target_length = target_lengths[i].item()
3031        cum_target_length = cum_target_lengths[i].item()
3032        targets_prime = targets.new_full((2 * target_length + 1,), blank)
3033        if targets.dim() == 2:
3034            targets_prime[1::2] = targets[i, :target_length]
3035        else:
3036            targets_prime[1::2] = targets[cum_target_length - target_length:cum_target_length]
3037        probs = log_probs[:input_length, i].exp()
3038        alpha = log_probs.new_zeros((target_length * 2 + 1,))
3039        alpha[0] = probs[0, blank]
3040        alpha[1] = probs[0, targets_prime[1]]
3041        mask_third = (targets_prime[:-2] != targets_prime[2:])
3042        for t in range(1, input_length):
3043            alpha_next = alpha.clone()
3044            alpha_next[1:] += alpha[:-1]
3045            alpha_next[2:] += torch.where(mask_third, alpha[:-2], alpha.new_zeros(1))
3046            alpha = probs[t, targets_prime] * alpha_next
3047        losses.append(-alpha[-2:].sum().log()[None])
3048    output = torch.cat(losses, 0)
3049    if reduction == 'mean':
3050        output = (output / target_lengths.to(dtype=output.dtype, device=output.device)).mean()
3051    elif reduction == 'sum':
3052        output = output.sum()
3053    output = output.to(dt)
3054    return output
3055
3056
loss_reference_fns: Dict[str, Callable] = {
3058    'KLDivLoss': kldivloss_reference,
3059    'KLDivLoss_log_target': partial(kldivloss_reference, log_target=True),
3060    'NLLLoss': nllloss_reference,
3061    'NLLLossNd': nlllossNd_reference,
3062    'SmoothL1Loss': smoothl1loss_reference,
3063    'HuberLoss': huberloss_reference,
3064    'MultiLabelMarginLoss': multilabelmarginloss_reference,
3065    'HingeEmbeddingLoss': hingeembeddingloss_reference,
3066    'SoftMarginLoss': softmarginloss_reference,
3067    'MultiMarginLoss': multimarginloss_reference,
3068    'CosineEmbeddingLoss': cosineembeddingloss_reference,
3069    'TripletMarginLoss': tripletmarginloss_reference,
3070    'MarginRankingLoss': marginrankingloss_reference,
3071    'CTCLoss': ctcloss_reference,
3072    'CrossEntropyLoss': cross_entropy_loss_reference
3073}
3074
3075
3076criterion_tests = []
3077
3078
3079def single_batch_reference_criterion_fn(*args):
3080    """Reference function for criterion supporting no batch dimensions.
3081
3082    The criterion is passed the input and target in batched form with a single item.
3083    The output is squeezed to compare with the no-batch input.
3084    """
3085    criterion = args[-1]
3086
3087    def unsqueeze_inp(inp):
3088        if isinstance(inp, (list, tuple)):
3089            return [t.unsqueeze(0) for t in inp]
3090        return inp.unsqueeze(0)
3091
3092    def flatten(xs):
3093        result = []
3094        if isinstance(xs, (list, tuple)):
3095            for x in xs:
3096                result.extend(flatten(x))
3097        else:
3098            result.append(xs)
3099        return result
3100
3101    single_batch_input_args = flatten([unsqueeze_inp(input) for input in args[:-1]])
3102
3103    output = criterion(*single_batch_input_args)
3104    reduction = get_reduction(criterion)
3105
3106    if reduction == 'none':
3107        return output.squeeze(0)
3108    # reduction is 'sum' or 'mean' which results in a scalar
3109    return output
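
# Rough sketch of how the test harness uses this (hypothetical): for a 1-d `input`/`target` pair,
#   single_batch_reference_criterion_fn(input, target, criterion)
# unsqueezes the inputs to add a batch dimension of size 1, applies the criterion, and for
# reduction='none' squeezes the batch dimension away again before comparing.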
3110
3111
# Check that regression criteria work with no batch dimensions
3113regression_criterion_no_batch = [
3114    'L1Loss', 'MSELoss', 'PoissonNLLLoss', 'HuberLoss', 'SmoothL1Loss'
3115]
3116reductions = ['none', 'mean', 'sum']
3117for name, reduction in product(regression_criterion_no_batch, reductions):
3118    regression_test_info = dict(
3119        fullname=f"{name}_no_batch_dim_{reduction}",
        # bind `reduction` eagerly; a plain closure would capture the last value of the loop variable
        constructor=lambda *args, name=name, reduction=reduction: getattr(nn, name)(reduction=reduction),
3121        input_size=(3, ),
3122        target_size=(3, ),
3123        reference_fn=single_batch_reference_criterion_fn,
3124        test_cpp_api_parity=False,
3125        default_dtype=torch.double,
3126    )
3127    criterion_tests.append(regression_test_info)
3128
3129
3130for reduction in reductions:
3131    regression_test_info = dict(
3132        fullname=f"KLDivLoss_no_batch_dim_{reduction}",
        # bind `reduction` eagerly, as in the loop above
        constructor=lambda reduction=reduction: nn.KLDivLoss(reduction=reduction),
3134        input_fn=lambda: torch.rand((3,)).log(),
3135        target_fn=lambda: torch.rand((3,)),
3136        reference_fn=single_batch_reference_criterion_fn,
3137        test_cpp_api_parity=False,
3138        default_dtype=torch.double,
3139    )
3140    criterion_tests.append(regression_test_info)
3141
3142
# Check that classification criteria work with no batch dimensions
3144# List of tuples of (name, input_fn, target_fn)
3145classification_criterion_no_batch = [
3146    (
3147        'BCELoss',
3148        lambda: torch.sigmoid(torch.randn(9, dtype=torch.double)),
3149        lambda: torch.randn(9, dtype=torch.double).gt(0).to(torch.double)
3150    ),
3151    ('BCEWithLogitsLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.randn(9, dtype=torch.double)),
3152    ('HingeEmbeddingLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.tensor([-1, 1, 1] * 3)),
3153    ('MultiLabelMarginLoss', lambda: torch.randn(4, dtype=torch.double), lambda: torch.tensor([3, 0, -1, 1])),
3154    ('SoftMarginLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.tensor([-1, 1, 1] * 3)),
3155    ('NLLLoss', lambda: F.log_softmax(torch.randn(3, dtype=torch.double), dim=0), lambda: torch.tensor(1)),
3156    (
3157        'CosineEmbeddingLoss',
3158        lambda: (torch.randn(9, dtype=torch.double), torch.randn(9, dtype=torch.double)),
3159        lambda: torch.tensor(1, dtype=torch.double)
3160    ),
3161    # For MarginRankingLoss, input_fn : (x1, x2) and target_fn : target
3162    ('MarginRankingLoss', lambda: (torch.randn(()), torch.randn(())), lambda: torch.randn(()).sign()),
3163    # For TripletMarginLoss, input_fn : (anchor, positive) and target_fn : negative
3164    (
3165        'TripletMarginLoss',
3166        lambda: (torch.randn(9, dtype=torch.double), torch.randn(9, dtype=torch.double)),
3167        lambda: torch.randn(9, dtype=torch.double)
3168    ),
3169    ('MultiLabelSoftMarginLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.randn(9)),
3170]
3171classification_criterion_no_batch_extra_info: Dict[str, dict] = {
3172    'MultiLabelMarginLoss': {'check_gradgrad': False},
3173}
# TODO: Fix these discrepancies
3175classification_cpp_parity = {
3176    'BCELoss': False,
3177    'BCEWithLogitsLoss': False,
3178    'HingeEmbeddingLoss': False,
3179    'NLLLoss': False,
3180    'SoftMarginLoss': False,
3181}
3182reductions = ['none', 'mean', 'sum']
3183for (name, input_fn, target_fn), reduction in product(classification_criterion_no_batch,
3184                                                      reductions):
3185    classification_test_info = dict(
3186        fullname=f"{name}_no_batch_dim_{reduction}",
        # bind `reduction` eagerly, as in the loops above
        constructor=lambda *args, name=name, reduction=reduction: getattr(nn, name)(reduction=reduction),
3188        input_fn=lambda f=input_fn: f(),
3189        target_fn=lambda f=target_fn: f(),
3190        reference_fn=single_batch_reference_criterion_fn,
3191        test_cpp_api_parity=True,
3192        has_parity=classification_cpp_parity.get(name, True)
3193    )
3194    extra_info = classification_criterion_no_batch_extra_info.get(name, {})
3195    classification_test_info.update(extra_info)
3196    criterion_tests.append(classification_test_info)
3197
3198
3199class NNTestCase(TestCase):
3200
3201    # _forward is defined in classes inheriting from NNTestCase
3202    @abstractmethod
3203    def _forward(self, *args, **kwargs):
3204        raise NotImplementedError
3205
3206    @abstractmethod
3207    def _get_parameters(self, module: nn.Module) -> Tuple[List[nn.Parameter], List[nn.Parameter]]:
3208        raise NotImplementedError
3209
3210    @abstractmethod
3211    def _zero_grad_parameters(self, module: nn.Module) -> None:
3212        raise NotImplementedError
3213
3214    @abstractmethod
3215    def _backward(self, module: nn.Module,
3216                  input: _TensorOrTensors, output: torch.Tensor,
3217                  grad_output: Union[torch.Tensor, Sequence[torch.Tensor]],
3218                  create_graph: bool = False):
3219        raise NotImplementedError
3220
3221    def _jacobian(self, input, num_out):
3222        if isinstance(input, tuple):
3223            return tuple(self._jacobian(elem, num_out) for elem in input)
3224        elif isinstance(input, list):
3225            return [self._jacobian(elem, num_out) for elem in input]
3226        else:
3227            return torch.zeros(input.nelement(), num_out)
3228
3229    def _flatten_tensors(self, x):
3230        if isinstance(x, torch.Tensor):
3231            if x.is_sparse:
3232                return x.to_dense().view(-1)
3233            else:
3234                return x.view(-1)
3235        else:
3236            return tuple(self._flatten_tensors(a) for a in x)
3237
3238    def _zero_grad_input(self, input):
3239        if isinstance(input, torch.Tensor):
3240            if input.requires_grad and input.grad is not None:
3241                input.grad.zero_()
3242                input.grad.detach_()
3243        else:
3244            for i in input:
3245                self._zero_grad_input(i)
3246
3247    def _analytical_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True, jacobian_parameters=True):
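        # The Jacobian is built column by column: for each flat output index i we backprop a
        # one-hot grad_output (flat_d_out[i] = 1 below) and record the resulting input and
        # parameter gradients as column i of the corresponding Jacobian matrices.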
3248        output = self._forward(module, input)
3249        output_size = output.nelement()
3250
3251        if jacobian_input:
3252            jacobian_inp = self._jacobian(input, output_size)
3253            flat_jacobian_input = list(_iter_tensors(jacobian_inp))
3254
3255        if jacobian_parameters:
3256            num_param = sum(p.numel() for p in self._get_parameters(module)[0])
3257            jacobian_param = torch.zeros(num_param, output_size)
3258
3259        for i in range(output_size):
3260            param, d_param = self._get_parameters(module)
            # substitute zeros for parameters that received no gradient
3262            d_param = [torch.zeros_like(p) if d is None else d for (p, d) in zip(param, d_param)]
3263
3264            d_out = torch.zeros_like(output)
3265            flat_d_out = d_out.view(-1)
3266            flat_d_out[i] = 1
3267
3268            if jacobian_parameters:
3269                self._zero_grad_parameters(module)
3270            # Tensors will accumulate gradient from multiple steps
3271            if jacobian_input:
3272                self._zero_grad_input(input)
3273            d_input = self._backward(module, input, output, d_out)
3274
3275            if jacobian_input:
3276                for jacobian_x, d_x in zip(flat_jacobian_input, _iter_tensors(d_input)):
3277                    jacobian_x[:, i] = d_x.contiguous().view(-1)
3278            if jacobian_parameters:
3279                jacobian_param[:, i] = torch.cat(self._flatten_tensors(d_param), 0)
3280
3281        res: Tuple[torch.Tensor, ...] = ()
3282        if jacobian_input:
3283            res += jacobian_inp,
3284        if jacobian_parameters:
3285            res += jacobian_param,
3286
3287        return res
3288
3289    def _numerical_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True, jacobian_parameters=True):
3290        def fw(*input):
3291            return self._forward(module, input).detach()
3292
3293        res: Tuple[torch.Tensor, ...] = ()
3294        if jacobian_input:
3295            res += _get_numerical_jacobian(fw, input, eps=1e-6),
3296        if jacobian_parameters:
3297            param, _ = self._get_parameters(module)
3298            to_cat = []
3299            for p in param:
3300                jacobian = _get_numerical_jacobian(fw, input, target=p, eps=1e-6)
3301                # get_numerical_jacobian returns a list of tuples but we require a tensor
3302                to_cat.append(jacobian[0][0])
3303            res += (torch.cat(to_cat, 0),)
3304        return res
3305
3306    def check_jacobian(self, module, input: _TensorOrTensors, jacobian_input=True):
3307        jacobian_parameters = bool(self._get_parameters(module)[0])
3308        analytical = self._analytical_jacobian(module, input, jacobian_input, jacobian_parameters)
3309        numerical = self._numerical_jacobian(module, input, jacobian_input, jacobian_parameters)
3310        analytical_t = list(_iter_tensors(analytical))
3311        numerical_t = list(_iter_tensors(numerical))
3312
3313        differences = []
3314        for a, n in zip(analytical_t, numerical_t):
3315            if a.numel() != 0:
3316                differences.append(a.add(n, alpha=-1).abs().max())
3317            # TODO: compare structure (ensure analytic jacobian has correct shape)
3318        if len(differences) > 0:
3319            self.assertLessEqual(max(differences), PRECISION)  # type: ignore[type-var]
3320
3321
3322class TestBase:
3323
3324    _required_arg_names = {'constructor_args', 'input', 'extra_args'}
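    # Each required arg can be given directly (e.g. `input`), lazily via a callable
    # (`input_fn`), or as a size from which a random tensor is generated (`input_size`);
    # see `_get_arg` below.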
3325
3326    def __init__(self, constructor, desc='', reference_fn=None, fullname=None, **kwargs):
3327        self.desc = desc
3328        self.fullname = fullname
3329        self.constructor = constructor
3330        self.reference_fn = reference_fn
3331        for name in self._required_arg_names:
3332            if name not in kwargs and name + '_fn' not in kwargs and name + '_size' not in kwargs:
3333                if name in {'constructor_args', 'extra_args'}:
3334                    kwargs[name] = ()
3335                else:
3336                    raise ValueError(f"{self.get_name()}: Specify {name} by a value, a function to generate it, or it's size!")
3337        self._extra_kwargs = kwargs
3338        self._arg_cache = {}
3339
3340    def get_name(self):
3341        if self.fullname is not None:
3342            return 'test_' + self.fullname
3343
3344        test_name = 'test_' + self.constructor.__name__
3345        if self.desc:
3346            test_name += '_' + self.desc
3347        return test_name
3348
3349    def _unpack(self, value):
3350        if isinstance(value, torch.Tensor):
3351            return value
3352        elif is_iterable(value):
3353            return type(value)(self._unpack(v) for v in value)
3354        else:
3355            return value
3356
3357    @property
3358    def constructor_args(self):
3359        return self._get_arg('constructor_args', True)
3360
3361    @property
3362    def extra_args(self):
3363        return self._get_arg('extra_args', True)
3364
3365    def _get_arg(self, name, unpack):
3366        assert name in self._required_arg_names
3367
3368        if name not in self._arg_cache:
3369            fn_name = name + '_fn'
3370            size_name = name + '_size'
3371
3372            if name in self._extra_kwargs:
3373                self._arg_cache[name] = self._extra_kwargs[name]
3374            elif fn_name in self._extra_kwargs:
3375                self._arg_cache[name] = self._extra_kwargs[fn_name]()
3376            else:
3377                assert size_name in self._extra_kwargs, \
3378                    f"Missing `{name}`, `{size_name}` or `{fn_name}` for {self.get_name()}"
3379
3380                def map_tensor_sizes(sizes):
3381                    if isinstance(sizes, list):
3382                        return [map_tensor_sizes(s) for s in sizes]
3383                    elif isinstance(sizes, torch.Tensor):
3384                        return sizes.double()
3385                    else:
3386                        return torch.randn(sizes)
3387
3388                self._arg_cache[name] = map_tensor_sizes(self._extra_kwargs[size_name])
3389
3390        return self._unpack(self._arg_cache[name]) if unpack else self._arg_cache[name]
3391
3392    def _get_input(self, unpack=True):
3393        return self._get_arg('input', unpack)
3394
3395    def __call__(self, test_case):
3396        raise NotImplementedError
3397
3398
3399class ModuleTest(TestBase):
3400
3401    @abstractmethod
3402    def _do_test(self, test_case: Any, module: nn.Module, input: Any) -> Any:
3403        raise NotImplementedError
3404
3405    def __init__(self, *args, **kwargs):
3406        super().__init__(*args, **kwargs)
3407        self.jacobian_input = kwargs.get('jacobian_input', True)
3408        self.should_test_cuda = kwargs.get('test_cuda', True)
3409        self.should_test_pickle = kwargs.get('pickle', True)
3410        self.check_gradgrad = kwargs.get('check_gradgrad', True)
3411        self.FIXME_no_cuda_gradgrad_comparison = \
3412            kwargs.get('FIXME_no_cuda_gradgrad_comparison', False)
3413        self.precision = kwargs.get('precision', 2e-4)
3414        self.check_forward_only = kwargs.get('check_forward_only', False)
3415        self.default_dtype = kwargs.get('default_dtype', None)
3416        if self.default_dtype is None:
3417            self.default_dtype = torch.get_default_dtype()
3418
3419    def __call__(self, test_case):
3420        with set_default_dtype(self.default_dtype):
3421            module = self.constructor(*self.constructor_args)
3422            input = self._get_input()
3423
3424            if self.reference_fn is not None:
3425                out = test_case._forward(module, input)
3426                ref_input = deepcopy(input)
3427                ref_module = deepcopy(module)
3428                expected_out = self.reference_fn(ref_input, test_case._get_parameters(module)[0], ref_module)
3429                test_case.assertEqual(out, expected_out, exact_dtype=False)
3430            if self.check_forward_only:
3431                return
3432            self.test_noncontig(test_case, module, input)
3433
3434            if self.should_test_pickle:
                # TODO: do this with in-memory files once torch.save supports it
3436                with tempfile.TemporaryFile() as f:
3437                    test_case._forward(module, input)
3438                    torch.save(module, f)
3439                    f.seek(0)
3440                    # weights_only=False as this is legacy code that saves the model
3441                    module_copy = torch.load(f, weights_only=False)
3442                    test_case.assertEqual(test_case._forward(module, input), test_case._forward(module_copy, input))
3443
3444            self._do_test(test_case, module, input)
3445
3446    def noncontiguize(self, obj):
3447        if isinstance(obj, list):
3448            return [self.noncontiguize(o) for o in obj]
3449        elif isinstance(obj, tuple):
3450            return tuple(self.noncontiguize(o) for o in obj)
3451        tensor = obj
3452        ndim = tensor.dim()
        # Always making only the last dimension non-contiguous makes it easy to hide
        # bugs, because .view(-1) will still work. So try to find a dim with size
        # > 1 and make that one non-contiguous, i.e., stack + select on the
        # dimension directly after it.
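        # E.g. for a (2, 3) tensor, dim 0 is the first dim with size > 1, so we stack along
        # dim 1 and select index 1 there: the result is again (2, 3) but has strides (6, 1)
        # into a (2, 2, 3) buffer, hence non-contiguous while holding the same values.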
3457        dim = ndim
3458        for d in range(ndim):
3459            if tensor.size(d) > 1:
3460                dim = d + 1
3461                break
3462        noncontig = torch.stack([torch.empty_like(tensor), tensor], dim).select(dim, 1).detach()
3463        assert noncontig.numel() == 1 or noncontig.numel() == 0 or not noncontig.is_contiguous()
3464        noncontig.requires_grad = tensor.requires_grad
3465        return noncontig
3466
3467    def test_noncontig(self, test_case, module, input):
        # skip scalar inputs; a 0-dim tensor cannot be made non-contiguous
3469        if isinstance(input, torch.Tensor) and input.dim() == 0:
3470            return
3471        if any(i.dim() == 0 for i in input if isinstance(i, torch.Tensor)):
3472            return
3473
3474        test_case._zero_grad_parameters(module)
3475        test_case._zero_grad_input(input)
3476        with freeze_rng_state():
3477            output = test_case._forward(module, input)
3478            if getattr(module, "return_indices", False):
3479                output = output[0]
3480            grad_output = output.new(output.shape).normal_()
3481            output = output.clone()
3482            d_input = deepcopy(test_case._backward(module, input, output, grad_output))
3483            d_param = deepcopy(test_case._get_parameters(module)[1])
3484
3485        nc_input = self.noncontiguize(input)
3486        nc_grad_output = self.noncontiguize(grad_output)
3487        for contig_i, contig_g in product((True, False), repeat=2):
3488            i = input if contig_i else nc_input
3489            # Some ops, e.g., nn.Flatten, return gradient that shares
3490            # storage with the grad_output. Hence we copy here.
3491            go = deepcopy(grad_output if contig_g else nc_grad_output)
3492            test_case._zero_grad_parameters(module)
3493            test_case._zero_grad_input(i)
3494            with freeze_rng_state():
3495                out = test_case._forward(module, i)
3496                if getattr(module, "return_indices", False):
3497                    out = out[0]
3498                grad = test_case._backward(module, i, out, go)
3499
3500                test_case.assertEqual(out, output)
3501                test_case.assertEqual(grad, d_input, atol=1e-4, rtol=0)
3502                test_case.assertEqual(test_case._get_parameters(module)[1], d_param)
3503
3504    def test_cuda(self, test_case):
3505        if not TEST_CUDA or not self.should_test_cuda:
3506            raise unittest.SkipTest('Excluded from CUDA tests')
3507
3508        with set_default_dtype(self.default_dtype):
3509            cpu_input = self._get_input()
3510
3511            type_map = {torch.double: torch.float}
3512            cpu_input_tuple = cpu_input if isinstance(cpu_input, tuple) else (cpu_input,)
3513
3514            is_any_input_complex = any(isinstance(t, torch.Tensor) and t.dtype.is_complex for t in cpu_input_tuple)
3515
3516            gpu_input_tuple = to_gpu(cpu_input_tuple, type_map=type_map)
3517
3518            cpu_module = self.constructor(*self.constructor_args)
3519            gpu_module = self.constructor(*self.constructor_args).float().cuda()
3520            cpu_param = test_case._get_parameters(cpu_module)
3521            gpu_param = test_case._get_parameters(gpu_module)
3522            for cpu_p, gpu_p in zip(cpu_param[0], gpu_param[0]):
3523                gpu_p.data.copy_(cpu_p)
3524
3525            test_case._zero_grad_input(cpu_input_tuple)
3526            test_case._zero_grad_input(gpu_input_tuple)
3527            test_case._zero_grad_parameters(cpu_module)
3528            test_case._zero_grad_parameters(gpu_module)
3529            cpu_output = test_case._forward(cpu_module, cpu_input_tuple)
3530            gpu_output = test_case._forward(gpu_module, gpu_input_tuple)
3531            if getattr(cpu_module, "return_indices", False):
3532                cpu_output = cpu_output[0]
3533                gpu_output = gpu_output[0]
3534            test_case.assertEqual(cpu_output, gpu_output, atol=self.precision, rtol=0, exact_dtype=False)
3535
3536            # Run backwards on CPU and GPU and compare results
3537            for _ in range(5):
3538                cpu_gradOutput = cpu_output.clone().normal_()
3539                gpu_gradOutput = cpu_gradOutput.type_as(gpu_output)
3540                cpu_gradInput = test_case._backward(cpu_module, cpu_input_tuple, cpu_output, cpu_gradOutput)
3541                gpu_gradInput = test_case._backward(gpu_module, gpu_input_tuple, gpu_output, gpu_gradOutput)
3542                test_case.assertEqual(cpu_gradInput, gpu_gradInput, atol=self.precision, rtol=0, exact_dtype=False)
3543                for cpu_d_p, gpu_d_p in zip(cpu_param[1], gpu_param[1]):
3544                    test_case.assertEqual(cpu_d_p, gpu_d_p, atol=self.precision, rtol=0)
3545
3546            # Run double-backwards on CPU and GPU and compare results
3547            if self.check_gradgrad and not self.FIXME_no_cuda_gradgrad_comparison:
3548                cpu_output = cpu_module(*cpu_input_tuple)
3549                gpu_output = gpu_module(*gpu_input_tuple)
3550                if getattr(cpu_module, "return_indices", False):
3551                    cpu_output = cpu_output[0]
3552                    gpu_output = gpu_output[0]
3553
3554                cpu_gradOutput = torch.randn_like(cpu_output, requires_grad=True)
3555                gpu_gradOutput = cpu_gradOutput.type_as(gpu_output).detach()
3556                gpu_gradOutput.requires_grad = True
3557
3558                cpu_gradInputs = torch.autograd.grad(
3559                    cpu_output,
3560                    cpu_input_tuple + tuple(cpu_module.parameters()),
3561                    cpu_gradOutput,
3562                    create_graph=True)
3563                gpu_gradInputs = torch.autograd.grad(
3564                    gpu_output,
3565                    gpu_input_tuple + tuple(gpu_module.parameters()),
3566                    gpu_gradOutput,
3567                    create_graph=True)
3568
3569                for cpu_d_i, gpu_d_i in zip(cpu_gradInputs, gpu_gradInputs):
3570                    test_case.assertEqual(cpu_d_i, gpu_d_i, atol=self.precision, rtol=0, exact_dtype=False)
3571
                # We mix the output into the second backward computation so that
                # torch.autograd.grad doesn't complain that some inputs
                # are unreachable (which can happen if you differentiate
                # only with respect to the gradient).
3576                if is_any_input_complex:
3577                    outputs_cpu = cpu_output.sum().abs() + sum(x.sum().abs() for x in cpu_gradInputs)
3578                    outputs_gpu = gpu_output.sum().abs() + sum(x.sum().abs() for x in gpu_gradInputs)
3579                else:
3580                    outputs_cpu = cpu_output.sum() + sum(x.sum() for x in cpu_gradInputs)
3581                    outputs_gpu = gpu_output.sum() + sum(x.sum() for x in gpu_gradInputs)
3582
3583                cpu_gg = torch.autograd.grad(
3584                    outputs_cpu,
3585                    cpu_input_tuple + (cpu_gradOutput,) + tuple(cpu_module.parameters()),
3586                    retain_graph=True)
3587                gpu_gg = torch.autograd.grad(
3588                    outputs_gpu,
3589                    gpu_input_tuple + (gpu_gradOutput,) + tuple(gpu_module.parameters()),
3590                    retain_graph=True)
3591                test_case.assertEqual(cpu_gradInput, gpu_gradInput, atol=self.precision, rtol=0, exact_dtype=False)
3592                for cpu_d_p, gpu_d_p in zip(cpu_gg, gpu_gg):
3593                    test_case.assertEqual(cpu_d_p, gpu_d_p, atol=self.precision, rtol=0, exact_dtype=False)
3594
3595            self.test_noncontig(test_case, gpu_module, gpu_input_tuple)
3596
3597
3598class InputVariableMixin:
3599    def _get_input(self):
3600        input = TestBase._get_input(self, False)  # type: ignore[arg-type]
3601
3602        def map_variables(i):
3603            if isinstance(i, torch.Tensor):
3604                if i.is_floating_point() or i.is_complex():
3605                    i.requires_grad = True
3606                return i
3607            else:
3608                return type(i)(map_variables(elem) for elem in i)
3609
3610        return map_variables(input)
3611
3612
3613class NewModuleTest(InputVariableMixin, ModuleTest):  # type: ignore[misc]
3614    def __init__(self, *args, **kwargs):
3615        super().__init__(*args, **kwargs)
3616        self.cudnn = kwargs.get('cudnn', False)
3617        self.check_inplace = kwargs.get('check_inplace', False)
3618        self.check_gradgrad = kwargs.get('check_gradgrad', True)
3619        self.skip_double = kwargs.get('skip_double', False)
3620        self.skip_half = kwargs.get('skip_half', False)
3621        self.with_tf32 = kwargs.get('with_tf32', False)
3622        self.tf32_precision = kwargs.get('tf32_precision', 0.001)
3623        self.test_cpu = kwargs.get('test_cpu', True)
3624        self.has_sparse_gradients = kwargs.get('has_sparse_gradients', False)
3625        self.check_batched_grad = kwargs.get('check_batched_grad', True)
3626        self.gradcheck_fast_mode = kwargs.get('gradcheck_fast_mode', None)
3627        self.supports_forward_ad = kwargs.get('supports_forward_ad', False)
3628        self.supports_fwgrad_bwgrad = kwargs.get('supports_fwgrad_bwgrad', False)
3629
3630    def _check_gradients(self, test_case, module, input_tuple):
3631        params = tuple(x for x in module.parameters())
3632        num_inputs = len(input_tuple)
3633
3634        def fn_to_gradcheck(*inputs_and_params, **kwargs):
3635            assert not kwargs
3636            return test_case._forward(module, inputs_and_params[:num_inputs])
3637
        # gradcheck doesn't support operators that take dense inputs but
        # produce sparse gradients for their parameters. This only happens in the case of nn.Embedding
3640        # and nn.EmbeddingBag. Instead, we call `self.check_jacobian`, which
3641        # is a slightly different version of gradcheck that can handle this.
3642        if self.has_sparse_gradients:
3643            assert num_inputs == 1
3644            test_input_jacobian = torch.is_floating_point(input_tuple[0])
3645            test_case.check_jacobian(module, input_tuple[0], test_input_jacobian)
3646        else:
3647            test_case.assertTrue(gradcheck(fn_to_gradcheck, input_tuple + params,
3648                                           check_batched_grad=self.check_batched_grad,
3649                                           fast_mode=self.gradcheck_fast_mode,
3650                                           check_forward_ad=self.supports_forward_ad))
3651
3652        if self.check_gradgrad:
3653            test_case.assertTrue(gradgradcheck(fn_to_gradcheck, input_tuple + params,
3654                                               check_batched_grad=self.check_batched_grad,
3655                                               fast_mode=self.gradcheck_fast_mode,
3656                                               check_fwd_over_rev=self.supports_fwgrad_bwgrad))
3657
3658    def _do_test(self, test_case, module, input):
3659        num_threads = torch.get_num_threads()
3660        torch.set_num_threads(1)
3661        input_tuple = input if isinstance(input, tuple) else (input,)
3662
3663        self._check_gradients(test_case, module, input_tuple)
3664
3665        # check if module can be printed
3666        module.__repr__()
3667
3668        if self.check_inplace:
3669            # check if the inplace variant of the module gives the same result
3670            # as the out-of-place
3671
3672            # check_inplace doesn't support multiple input tensors, since we don't have any modules
3673            # that modify the inputs in-place and that accept more than one input
3674            assert len(input_tuple) == 1
3675            input = input_tuple[0]
3676
3677            module_ip = self.constructor(*self.constructor_args, inplace=True)
3678
3679            input_version = input._version
3680            with freeze_rng_state():
3681                output = module(input)
3682            test_case.assertEqual(input._version, input_version)
3683
3684            input_ip = deepcopy(input)
3685            input_ip_clone = input_ip.clone()
3686            with freeze_rng_state():
3687                output_ip = module_ip(input_ip_clone)
3688            test_case.assertNotEqual(input_ip_clone._version, input_version)
3689            test_case.assertEqual(output, output_ip)
3690            grad = output.data.clone().normal_()
3691            if input.grad is not None:
3692                with torch.no_grad():
3693                    input.grad.zero_()
3694            if input_ip.grad is not None:
3695                with torch.no_grad():
3696                    input_ip.grad.zero_()
3697            output.backward(grad)
3698            output_ip.backward(grad)
3699            test_case.assertEqual(input.grad, input_ip.grad)
3700
3701        def assert_module_parameters_are(tensor_type, device_id=None):
3702            for p in module.parameters():
3703                test_case.assertIsInstance(p, tensor_type)
3704                if device_id is not None:
3705                    test_case.assertEqual(p.get_device(), device_id)
3706
3707        if all(isinstance(t, torch.LongTensor) for t in input_tuple) and TEST_CUDA:
3708            # check that cuda() moves module parameters to correct GPU device,
3709            # and that float() casts parameters correctly
3710            input_tuple = tuple(t.cuda() for t in input_tuple)
3711            module.float().cuda()
3712            module(*input_tuple)
3713            assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]
3714
3715            if torch.cuda.device_count() > 1:
3716                input_tuple = tuple(t.cuda(1) for t in input_tuple)
3717                module.cuda(1)
3718                with torch.cuda.device(1):
3719                    module(*input_tuple)
3720                assert_module_parameters_are(torch.cuda.FloatTensor, 1)  # type: ignore[attr-defined]
3721        else:
            # check that the float()/double() casts work correctly
3723            def to_type(tensor, real, complex):
3724                if tensor.is_complex():
3725                    return tensor.to(complex)
3726                elif tensor.is_floating_point():
3727                    return tensor.to(real)
3728                else:
3729                    return tensor
3730
3731            def to_half(x):
3732                # TODO: torch.complex32 when properly supported
3733                return to_type(x, torch.float16, None)
3734
3735            def to_single(x):
3736                return to_type(x, torch.float32, torch.complex64)
3737
3738            def to_double(x):
3739                return to_type(x, torch.float64, torch.complex128)
3740
3741            # to float
3742            input_tuple = tuple(to_single(t) for t in input_tuple)
3743            module.float()
3744            module(*input_tuple)
3745            assert_module_parameters_are(torch.FloatTensor)
3746
3747            # and back to double
3748            input_tuple = tuple(to_double(t) for t in input_tuple)
3749            module.double()
3750            module(*input_tuple)
3751            assert_module_parameters_are(torch.DoubleTensor)
3752
3753            if TEST_CUDA and self.should_test_cuda:
3754                # check that cuda() moves module parameters to correct GPU device,
3755                # and that float() casts parameters correctly
3756
3757                # to GPU0
3758                input_tuple = tuple(to_single(t).cuda() for t in input_tuple)
3759                module.float().cuda()
3760                module(*input_tuple)
3761                assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]
3762
3763                # to CPU
3764                input_tuple = tuple(t.cpu() for t in input_tuple)
3765                module.cpu()
3766                module(*input_tuple)
3767                assert_module_parameters_are(torch.FloatTensor)
3768
3769                # back to GPU0
3770                input_tuple = tuple(t.cuda() for t in input_tuple)
3771                module.cuda()
3772                module(*input_tuple)
3773                assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]
3774
3775                # test that the module's forward pass runs correctly with cuDNN disabled
3776                if self.cudnn:
3777                    with torch.backends.cudnn.flags(enabled=False):
3778                        module(*input_tuple)
3779                        assert_module_parameters_are(torch.cuda.FloatTensor, 0)  # type: ignore[attr-defined]
3780
3781                if torch.cuda.device_count() >= 2:
3782                    # test that cross-GPU transfer works
3783                    # to GPU1
3784                    input_tuple = tuple(t.cuda(1) for t in input_tuple)
3785                    module.cuda(1)
3786                    with torch.cuda.device(1):
3787                        module(*input_tuple)
3788                    assert_module_parameters_are(torch.cuda.FloatTensor, 1)  # type: ignore[attr-defined]
3789
3790                if not self.skip_double:
3791                    # test double()
3792                    input_tuple = tuple(to_double(t).cuda() for t in input_tuple)
3793                    module.double().cuda()
3794                    module(*input_tuple)
3795                    assert_module_parameters_are(torch.cuda.DoubleTensor, 0)  # type: ignore[attr-defined]
3796
3797                # test half()
3798                if not self.skip_half:
3799                    input_tuple = tuple(to_half(t).cuda() for t in input_tuple)
3800                    module.half().cuda()
3801                    module(*input_tuple)
3802                    assert_module_parameters_are(torch.cuda.HalfTensor, 0)  # type: ignore[attr-defined]
3803        torch.set_num_threads(num_threads)
3804
3805    def _get_target(self):
3806        return self._get_arg('target', False)
3807
3808    @property
3809    def constructor_args(self):
3810        return self._get_arg('constructor_args', False)
3811
3812
3813class CriterionTest(InputVariableMixin, TestBase):  # type: ignore[misc]
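    """Test harness for criterion (loss) modules.

    Builds the criterion from ``constructor_args``, feeds it ``input`` and
    ``target``, optionally compares the output against ``reference_fn``, and,
    unless ``check_forward_only`` is set, runs ``gradcheck``/``gradgradcheck``
    over the inputs and module parameters.
    """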
3814    # TODO: check that criterions don't ignore grad_output
3815
3816    _required_arg_names = TestBase._required_arg_names.union({'target'})
3817
3818    def __init__(self, *args, **kwargs):
3819        super().__init__(*args, **kwargs)
3820        self.should_test_cuda = kwargs.get('test_cuda', True)
3821        self.check_forward_only = kwargs.get('check_forward_only', False)
3822        self.check_gradgrad = kwargs.get('check_gradgrad', True)
3823        self.check_half = kwargs.get('check_half', True)
3824        self.check_bfloat16 = kwargs.get('check_bfloat16', False)
3825        self.check_complex = kwargs.get('check_complex', False)
3826        self.test_cpu = kwargs.get('test_cpu', True)
3827        self.with_tf32 = kwargs.get('with_tf32', True)
3828        self.tf32_precision = kwargs.get('tf32_precision', 0.001)
3829        self.check_batched_grad = kwargs.get('check_batched_grad', True)
3830        self.default_dtype = kwargs.get('default_dtype', None)
3831        if self.default_dtype is None:
3832            self.default_dtype = torch.get_default_dtype()
3833
3834    def __call__(self, test_case):
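        """Run the forward, reference-function, and gradient checks for this criterion."""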
3835        with set_default_dtype(self.default_dtype):
3836            module = self.constructor(*self.constructor_args)
3837            input = self._get_input()
3838
3839            # Check that these methods don't raise errors
3840            module.__repr__()
3841            str(module)
3842
3843            target = self._get_target()
3844
3845            if self.reference_fn is not None:
3846                out = test_case._forward_criterion(module, input, target, extra_args=self.extra_args)
3847                ref_args = (deepcopy(input), deepcopy(target)) + self.extra_args + (module,)
3848                expected_out = self.reference_fn(*ref_args)
3849                test_case.assertEqual(out, expected_out)
3850
3851            if self.check_forward_only:
3852                return
3853
3854            params = tuple(module.parameters())
3855            if not isinstance(input, tuple):
3856                inputs = (input,) + params + (target,)
3857
3858                def apply_fn(input, target, *params):
3859                    return module(input, target)
3860            else:
3861                inputs = input + params + (target,)
3862
3863                def apply_fn(input1, input2, target, *params):  # type: ignore[misc]
3864                    return module(input1, input2, target)
3865
3866            gradcheck(apply_fn, inputs, check_batched_grad=self.check_batched_grad)
3867
3868            if self.check_gradgrad:
3869                gradgradcheck(apply_fn, inputs, check_batched_grad=self.check_batched_grad)
3870
3871    def test_cuda(self, test_case, dtype, extra_args=None):
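        """Compare the criterion's forward and backward results on CPU and CUDA at ``dtype``."""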
3872        def convert_dtype(obj, dtype, requires_grad=False):
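            # Recursively detach tensors (and tuples of tensors), cast them to ``dtype``,
            # and re-enable requires_grad where requested; other objects pass through unchanged.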
3873            if isinstance(obj, torch.Tensor):
3874                return obj.detach().to(dtype=dtype).requires_grad_(requires_grad)
3875            elif isinstance(obj, tuple):
3876                return tuple(convert_dtype(o, dtype, requires_grad) for o in obj)
3877            else:
3878                return obj
3879
3880        if not TEST_CUDA or not self.should_test_cuda:
3881            raise unittest.SkipTest('Excluded from CUDA tests')
3882
3883        with set_default_dtype(self.default_dtype):
3884            cpu_input = self._get_input()
3885            cpu_target = self._get_target()
3886            cpu_module = self.constructor(*self.constructor_args)
3887            gpu_module = self.constructor(*self.constructor_args)
3888
3889            # Convert input, target and module parameters to dtype
3890            cpu_input = convert_dtype(cpu_input, dtype, True)
3891            if cpu_target.is_floating_point() or cpu_target.is_complex():
3892                cpu_target = convert_dtype(cpu_target, dtype)
3893            cpu_module.type(dtype)
3894            gpu_module.type(dtype)
3895
3896            # GPU setup
3897            gpu_input = to_gpu(cpu_input)
3898            gpu_target = to_gpu(cpu_target)
3899            gpu_module.cuda()
3900
3901            # Half/BFloat16 tensors don't support most CPU operations, so convert the CPU side back to the default dtype
3902            if dtype in {torch.half, torch.bfloat16}:
3903                cpu_input = self._get_input()
3904                cpu_target = self._get_target()
3905                # Loss modules with weights require consistent input/module weight types
3906                cpu_module = self.constructor(*self.constructor_args)
3907
3908            cpu_output = test_case._forward_criterion(cpu_module, cpu_input, cpu_target, extra_args=extra_args)
3909            gpu_output = test_case._forward_criterion(gpu_module, gpu_input, gpu_target, extra_args=extra_args)
3910            # dtype could previously be None, so the tolerance is set inline here instead of via a per-dtype precision map
3911            test_case.assertEqual(cpu_output, gpu_output,
3912                                  atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4, rtol=0, exact_dtype=False)
3913
3914            cpu_gradInput = test_case._backward_criterion(
3915                cpu_module, cpu_input, cpu_output, cpu_target, extra_args=extra_args)
3916            gpu_gradInput = test_case._backward_criterion(
3917                gpu_module, gpu_input, gpu_output, gpu_target, extra_args=extra_args)
3918            # dtype could previously be None, so the tolerance is set inline here instead of via a per-dtype precision map
3919            test_case.assertEqual(cpu_gradInput, gpu_gradInput,
3920                                  atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4, rtol=0, exact_dtype=False)
3921
3922    def _get_target(self):
3923        return self._get_arg('target', False)
3924
3925    @property
3926    def constructor_args(self):
3927        return self._get_arg('constructor_args', False)
3928
3929    @property
3930    def extra_args(self):
3931        return self._get_arg('extra_args', False)
3932
3933
3934def _test_bfloat16_ops(test_case, op, device, inp_dims=(), prec=1e-2, scale_factor=None):
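    """Compare ``op``'s forward and backward results in float32 and bfloat16.

    The float32 run serves as the reference; the bfloat16 run reuses the same
    input and incoming gradient, and both the outputs and the input gradients
    must agree within ``prec``. A hypothetical call (names are illustrative
    only) might look like::

        _test_bfloat16_ops(self, nn.GELU(), 'cpu', inp_dims=(4, 8), prec=1e-2)
    """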
3935    # fp32 compute
3936    input1 = torch.randn(inp_dims, dtype=torch.float32, device=device, requires_grad=True)
3937    if scale_factor is not None:
3938        input1 = (torch.rand(inp_dims, dtype=torch.bfloat16, device=device) * scale_factor).float().requires_grad_()
3939    out1 = op(input1)
3940    grad_input1 = torch.randn_like(out1, device=device)
3941    out1.backward(grad_input1)
3942
3943    # bfloat16 compute
3944    op_bfp16 = op.bfloat16()
3945    input2 = input1.detach().bfloat16().requires_grad_()
3946    grad_input2 = grad_input1.bfloat16()
3947    out2 = op_bfp16(input2)
3948    out2.backward(grad_input2)
3949
3950    test_case.assertEqual(out1, out2, atol=prec, rtol=prec, exact_dtype=False)
3951    test_case.assertEqual(input1.grad.data, input2.grad.data, atol=prec, rtol=prec, exact_dtype=False)
3952
3953def _test_module_empty_input(test_case, module, inp, check_size=True, inference=False):
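    """Run ``module`` on ``inp`` (typically an empty tensor) and sanity-check the result.

    Optionally verifies that the output shape matches the input shape and, when
    not in inference mode, that the backward pass leaves all parameter and input
    gradients at zero.
    """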
3954    if not inference:
3955        inp.requires_grad_(True)
3956    out = module(inp)
3957    if not inference:
3958        gO = torch.rand_like(out)
3959        out.backward(gO)
3960    if check_size:
3961        test_case.assertEqual(out.size(), inp.size())
3962    if not inference:
3963        for p in module.parameters():
3964            if p.requires_grad:
3965                test_case.assertEqual(p.grad, torch.zeros_like(p.grad))
3966        test_case.assertEqual(inp.grad, torch.zeros_like(inp))
3967
3968
3969def _create_basic_net():
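    """Return a small parameter/buffer hierarchy for tests: a bare ``Layer``, a
    ``Net`` wrapping one ``Layer``, and an ``nn.Sequential`` that contains the
    same ``Net`` instance twice.
    """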
3970    class Layer(nn.Module):
3971        def __init__(self) -> None:
3972            super().__init__()
3973            self.layer_dummy_param = nn.Parameter(torch.empty(3, 5))
3974            self.layer_dummy_buf = nn.Buffer(torch.zeros(1, 3, 3, 7))
3975
3976    class Net(nn.Module):
3977        def __init__(self) -> None:
3978            super().__init__()
3979            self.l1 = Layer()
3980            self.dummy_param = nn.Parameter(torch.empty(3, 5))
3981            self.dummy_buf = nn.Buffer(torch.zeros(7, 3, 3, 1))
3982
3983    l = Layer()
3984    n = Net()
3985    s = nn.Sequential(n, n)
3986
3987    return l, n, s
3988