# Owner(s): ["module: nn"]

import contextlib
import math
import random
import unittest
import io
import itertools
import warnings
import pickle
import re
from copy import deepcopy
from itertools import product
from functools import partial
from collections import OrderedDict
from unittest import SkipTest

import torch
from torch import inf, nan
import torch.autograd.forward_ad as fwAD
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn_utils
from torch.nn.utils import clip_grad_norm_, clip_grad_value_
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from torch.nn.utils.fusion import fuse_conv_bn_weights
from torch.nn.utils.fusion import fuse_linear_bn_weights
from torch.nn import Buffer, Parameter
from torch.nn.parallel._functions import Broadcast
from torch.testing._internal.common_dtype import integral_types, get_all_math_dtypes, floating_types
from torch.testing._internal.common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \
    TEST_NUMPY, TEST_SCIPY, TEST_WITH_CROSSREF, TEST_WITH_ROCM, \
    download_file, get_function_arglist, load_tests, skipIfMps, \
    IS_PPC, \
    parametrize as parametrize_test, subtest, instantiate_parametrized_tests, \
    skipIfTorchDynamo, gcIfJetson, set_default_dtype
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \
    module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \
    ctcloss_reference, new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input
from torch.testing._internal.common_device_type import dtypesIfMPS, instantiate_device_type_tests, dtypes, \
    dtypesIfCUDA, precisionOverride, skipCUDAIfCudnnVersionLessThan, onlyCUDA, onlyCPU, \
    skipCUDAIfRocm, skipCUDAIf, skipCUDAIfNotRocm, \
    onlyNativeDeviceTypes, deviceCountAtLeast, largeTensorTest, expectedFailureMeta, expectedFailureMPS, \
    skipMeta, get_all_device_types

from hypothesis import given
import torch.testing._internal.hypothesis_utils as hu
from torch.testing._internal.common_utils import _assertGradAndGradgradChecks, gradcheck, gradgradcheck, \
    GRADCHECK_NONDET_TOL
from torch.testing._internal.common_utils import dtype2prec_DONTUSE
from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32, tf32_off, tf32_on
from torch.types import _TensorOrTensors
from torch.testing._internal.common_mkldnn import bf32_on_and_off

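# True on ROCm, or on NVIDIA hardware (Ampere and newer) where TF32 matmuls are not
# bit-identical to FP32.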
AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

if TEST_SCIPY:
    import scipy.signal
    import scipy.ndimage

if TEST_NUMPY:
    import numpy as np


# WARNING: If you add a new top-level test case to this file, you MUST
# update test/run_test.py to list it, otherwise it will NOT be run in
# CI.

class TestNN(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

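    # Harness hooks: the shared module/criterion test machinery in
    # torch.testing._internal.common_nn (NNTestCase / NewModuleTest / CriterionTest)
    # drives forward and backward passes through the helpers below.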
    def _forward(self, module, input: _TensorOrTensors):
        with freeze_rng_state():
            if isinstance(input, tuple):
                return module(*input)
            else:
                return module(input)

    def _backward(self, module, input: _TensorOrTensors, output, grad_output, create_graph=False):
        output.backward(grad_output, retain_graph=True, create_graph=create_graph)
        if isinstance(input, tuple):
            return tuple(i.grad.data if i.grad is not None else None for i in input)
        else:
            return input.grad.data if input.grad is not None else None

    def _forward_criterion(self, criterion, input, target, extra_args=None):
        if extra_args is None:
            extra_args = ()
        if isinstance(input, tuple):
            args = input + (target,) + extra_args
            output = criterion(*args)
        else:
            output = criterion(input, target, *extra_args)
        return output

    def _backward_criterion(self, criterion, input, output, target, gradOutput=None, extra_args=None):
        if extra_args is None:
            extra_args = ()
        input_tuple = input if isinstance(input, tuple) else (input,)
        output_tuple = output if isinstance(output, tuple) else (output,)
        for i in input_tuple:
            if i.grad is not None:
                i.grad.data.zero_()
        args = input_tuple + (target,) + extra_args
        if gradOutput is None:
            gradOutput = torch.ones(())
        criterion(*args).backward(gradOutput.to(output_tuple[0]))
        if isinstance(input, tuple):
            return tuple(i.grad.data for i in input)
        else:
            return input.grad.data

    def _zero_grad_parameters(self, module):
        for p in module.parameters():
            if p.grad is not None:
                with torch.no_grad():
                    p.grad.zero_()
                p.grad.detach_()

    def _get_parameters(self, module):
        params = []
        d_params = []
        for p in module.parameters():
            params.append(p)
            d_params.append(p.grad)
        return params, d_params

    def test_parse_to(self):
        # Test for buggy use of THPMemoryFormat_New
        self.assertEqual(
            repr(torch._C._nn._parse_to(memory_format=torch.contiguous_format)[3]),
            "torch.contiguous_format"
        )

    def test_requires_grad_(self):
        m = _create_basic_net()[-1]
        assert len(list(m.buffers())) > 0, 'invalid test'
        assert all(not b.requires_grad for b in m.buffers()), 'invalid test'
        assert len(list(m.parameters())) > 0, 'invalid test'
        assert all(p.requires_grad for p in m.parameters()), 'invalid test'
        for requires_grad in (False, True):
            self.assertIs(m.requires_grad_(requires_grad), m)
            for p in m.parameters():
                self.assertEqual(p.requires_grad, requires_grad)
            for b in m.buffers():
                self.assertFalse(b.requires_grad)

    def test_module_backcompat(self):
        from torch.serialization import SourceChangeWarning
        path = download_file('https://download.pytorch.org/test_data/linear.pt')
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', SourceChangeWarning)
            # weights_only=False: this is a legacy checkpoint that pickles the full model object
            m = torch.load(path, weights_only=False)
        input = torch.randn(2, 3, dtype=torch.float)
        self.assertEqual(m(input).size(), (2, 5))

    def test_module_super_init(self):
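        """nn.Module.__init__ only forwards to super().__init__() when call_super_init
        is enabled, so a mixin that follows nn.Module in the MRO is initialized only
        when the flag is set globally or on the subclass."""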
        class MyMixin:
            def __init__(self, *a, **kw):
                super().__init__(*a, **kw)
                self.mixin_init = True

        class MyModuleWithMixinBefore(MyMixin, nn.Module):
            pass

        class MyModuleWithMixinAfter(nn.Module, MyMixin):
            pass

        self.assertTrue(hasattr(MyModuleWithMixinBefore(), 'mixin_init'))
        self.assertFalse(hasattr(MyModuleWithMixinAfter(), 'mixin_init'))

        nn.Module.call_super_init = True
        self.assertTrue(hasattr(MyModuleWithMixinBefore(), 'mixin_init'))
        self.assertTrue(hasattr(MyModuleWithMixinAfter(), 'mixin_init'))
        nn.Module.call_super_init = False

        MyModuleWithMixinBefore.call_super_init = True
        MyModuleWithMixinAfter.call_super_init = True
        self.assertTrue(hasattr(MyModuleWithMixinBefore(), 'mixin_init'))
        self.assertTrue(hasattr(MyModuleWithMixinAfter(), 'mixin_init'))
        MyModuleWithMixinBefore.call_super_init = False
        MyModuleWithMixinAfter.call_super_init = False

    def test_share_memory(self):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.p = nn.Parameter(torch.eye(5))
                self.par = nn.ParameterList()
                self.par.append(nn.Parameter(torch.randn(10)))

            def forward(self, inp):
                # NB: dead code
                return inp.clone()

        net = Net()
        for p in net.parameters():
            self.assertFalse(p.storage().is_shared())
        for b in net.buffers():
            self.assertFalse(b.storage().is_shared())
        net.share_memory()
        for p in net.parameters():
            self.assertTrue(p.storage().is_shared())
        for b in net.buffers():
            self.assertTrue(b.storage().is_shared())

    def test_to(self):
        m = nn.Linear(3, 5)
        self.assertIs(m, m.to('cpu'))
        self.assertIs(m, m.to('cpu', dtype=torch.float32))
        self.assertEqual(m.double(), m.to(torch.float64))
        self.assertRaises(RuntimeError, lambda: m.to('cpu', copy=True))

        if torch.cuda.is_available():
            for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
                m2 = m.cuda(device=cuda)
                self.assertIs(m2, m2.to(cuda))
                self.assertEqual(m, m2.to('cpu'))
                self.assertEqual(m2, m.to(cuda))
                self.assertIs(m2, m2.to(dtype=torch.float32))
                self.assertEqual(m2.double(), m2.to(dtype=torch.float64))

    def test_zero_grad(self):
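        """zero_grad() sets grads to None by default; set_to_none=False zero-fills them in place."""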
        i = torch.randn(2, 5, requires_grad=True)
        module = nn.Linear(5, 5)
        for p in module.parameters():
            p.requires_grad = False
        module.zero_grad()

        module.weight.requires_grad = True
        module.zero_grad()
        self.assertIsNone(module.weight.grad)  # uninitialized grad

        module(i).sum().backward()
        self.assertIsNotNone(module.weight.grad)
        self.assertGreater(module.weight.grad.data.abs().sum(), 0)
        module.zero_grad()
        self.assertIsNone(module.weight.grad)

        module.bias.requires_grad = True
        module.zero_grad()
        self.assertIsNone(module.weight.grad)
        self.assertIsNone(module.bias.grad)
        module(i).sum().backward()
        self.assertIsNotNone(module.weight.grad)
        self.assertIsNotNone(module.bias.grad)
        self.assertGreater(module.weight.grad.data.abs().sum(), 0)
        self.assertGreater(module.bias.grad.data.abs().sum(), 0)
        module.zero_grad(set_to_none=False)   # Force set to zeros.
        self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
        self.assertEqual(module.bias.grad.data, module.bias.data.clone().zero_())

        module.zero_grad()
        self.assertIsNone(module.weight.grad)
        self.assertIsNone(module.bias.grad)

    def test_no_grad(self):
        for dtype in [torch.bfloat16, torch.float, torch.double]:
            module = nn.Conv2d(2, 5, kernel_size=3, padding=1).to(dtype)
            input = torch.randn(1, 2, 10, 10).to(dtype)
            x = input
            y = input.clone()

            output = module(x)
            self.assertTrue(output.requires_grad)
            output.backward(torch.ones(1, 5, 10, 10))

            with torch.no_grad():
                output2 = module(y)
                self.assertFalse(output2.requires_grad)
                self.assertRaises(RuntimeError, lambda: output2.backward(torch.ones(1, 5, 10, 10)))

    def test_parameters_and_named_parameters(self):
        def names(named_parameters):
            return [k for k, _ in named_parameters]

        l, n, s = _create_basic_net()

        self.assertEqual(len(list(l.parameters())), 1)
        self.assertEqual(
            names(l.named_parameters()),
            ['layer_dummy_param'])

        self.assertEqual(len(list(n.parameters())), 2)
        self.assertEqual(
            names(n.named_parameters()),
            ['dummy_param', 'l1.layer_dummy_param'])

        self.assertEqual(len(list(n.parameters(recurse=False))), 1)
        self.assertEqual(
            names(n.named_parameters(recurse=False)),
            ['dummy_param'])

        self.assertEqual(len(list(s.parameters())), 2)
        self.assertEqual(
            names(s.named_parameters()),
            ['0.dummy_param', '0.l1.layer_dummy_param'])

    def test_named_parameters_remove_duplicate(self):
        def names(named_parameters):
            return [k for k, _ in named_parameters]

        class M1(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param1 = nn.Parameter(torch.empty(3, 3))
                self.param2 = self.param1

        m1 = M1()
        self.assertEqual(names(m1.named_parameters()),
                         ["param1"])
        self.assertEqual(names(m1.named_parameters(remove_duplicate=False)),
                         ["param1", "param2"])

        class M2(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mod1 = nn.Linear(3, 4, bias=False)
                self.mod2 = self.mod1

        m2 = M2()
        self.assertEqual(names(m2.named_parameters()),
                         ["mod1.weight"])
        self.assertEqual(names(m2.named_parameters(remove_duplicate=False)),
                         ["mod1.weight", "mod2.weight"])

    def test_buffers_and_named_buffers(self):
        def names(named_buffers):
            return [k for k, _ in named_buffers]

        l, n, s = _create_basic_net()

        self.assertEqual(len(list(l.buffers())), 1)
        self.assertEqual(
            names(l.named_buffers()),
            ['layer_dummy_buf'])

        self.assertEqual(len(list(n.buffers())), 2)
        self.assertEqual(
            names(n.named_buffers()),
            ['dummy_buf', 'l1.layer_dummy_buf'])

        self.assertEqual(len(list(n.buffers(recurse=False))), 1)
        self.assertEqual(
            names(n.named_buffers(recurse=False)),
            ['dummy_buf'])

        self.assertEqual(len(list(s.buffers())), 2)
        self.assertEqual(
            names(s.named_buffers()),
            ['0.dummy_buf', '0.l1.layer_dummy_buf'])

        # test remove_duplicate
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer1 = Buffer(torch.empty(3, 5))
                self.buffer2 = self.buffer1

        m = M()
        self.assertEqual(names(m.named_buffers()),
                         ["buffer1"])
        self.assertEqual(names(m.named_buffers(remove_duplicate=False)),
                         ["buffer1", "buffer2"])

    def test_buffer_bad_module_subclass(self):
        class MyBadModule(nn.Linear):
            def __init__(self) -> None:
                super().__init__(2, 2)
                self.bar = Buffer(torch.rand(2, 2))

            def register_buffer(self, name, value):
                # persistent is explicitly missing!
                super().register_buffer(name, value, True)

        foo = MyBadModule()
        self.assertIsNotNone(foo.bar)

    def test_call_supports_python_dict_output(self):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l1 = nn.Linear(10, 20)
                self.register_backward_hook(self.hook)
                self.check_backward_hook_flag = False

            def hook(self, module, grad_input, grad_output):
                self.check_backward_hook_flag = True

            def forward(self, inputs):
                return {"output": self.l1(inputs).sum()}

        net = Net()
        model_output = net(torch.randn([5, 10]))
        model_output["output"].backward()
        self.assertTrue(net.check_backward_hook_flag)

    def test_children(self):
        l1 = nn.Linear(2, 2)
        l2 = nn.Linear(2, 2)
        l3 = nn.Linear(2, 2)
        l4 = nn.Linear(2, 2)
        subnet = nn.Sequential(l3, l4)
        s = nn.Sequential(l1, l2, l1, l2, subnet)
        self.assertEqual(list(s.children()), [l1, l2, subnet])

    def test_train_errors_for_invalid_mode(self):
        class SubclassNet(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l1 = nn.Linear(2, 2)

            def forward(self, inputs):
                return self.l1(inputs)

        subclass_net = SubclassNet()
        sequential_net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))

        error_modes = ["invalid_str", torch.device('cpu')]
        modules_to_check = [subclass_net, sequential_net]

        for error_mode, module in itertools.product(error_modes, modules_to_check):
            with self.assertRaises(ValueError):
                module.train(error_mode)

    def test_dir(self):
        linear = nn.Linear(2, 2)
        linear._test_submodule = nn.Linear(2, 2)
        linear._test_parameter = Parameter(torch.empty(2, 2))
        linear._test_buffer = Buffer(torch.empty(2, 2))
        keys = dir(linear)
        self.assertIn('_test_submodule', keys)
        self.assertIn('_test_parameter', keys)
        self.assertIn('_test_buffer', keys)

        for key in keys:
            self.assertTrue(hasattr(linear, key))

    def test_repr(self):
        # no extra information or sub-modules
        empty_sequential = nn.Sequential()
        expected_repr_empty = 'Sequential()'
        self.assertEqual(repr(empty_sequential), expected_repr_empty)

        # one liner extra information
        linear = nn.Linear(1, 1)
        expected_repr_linear = 'Linear(in_features=1, out_features=1, bias=True)'
        self.assertEqual(repr(linear), expected_repr_linear)

        # sub-modules repr
        sequential = nn.Sequential(linear)
        expected_repr_sequential = 'Sequential(\n' \
            '  (0): Linear(in_features=1, out_features=1, bias=True)\n' \
            ')'
        self.assertEqual(repr(sequential), expected_repr_sequential)

    def test_dir_digit(self):
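        """Numeric child names from Sequential (e.g. '0') must not leak into dir()."""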
        model = nn.Sequential(nn.Linear(2, 2))
        keys = dir(model)
        self.assertNotIn('0', keys)

    def test_named_children(self):
        l1 = nn.Linear(2, 2)
        l2 = nn.Linear(2, 2)
        l3 = nn.Linear(2, 2)
        l4 = nn.Linear(2, 2)
        subnet = nn.Sequential(l3, l4)
        s = nn.Sequential()
        with self.assertRaises(KeyError):
            s.add_module('', l1)
        with self.assertRaises(KeyError):
            s.add_module('name.with.dot', l1)
        s.add_module('layer1', l1)
        s.add_module('layer2', l2)
        s.add_module('layer3', l1)
        s.add_module('layer4', l2)
        s.add_module('subnet', subnet)
        self.assertEqual(list(s.named_children()), [('layer1', l1), ('layer2', l2), ('subnet', subnet)])

    def test_modules(self):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l1 = l
                self.l2 = l
                self.param = torch.empty(3, 5)

        l = nn.Linear(10, 20)
        n = Net()
        s = nn.Sequential(n, n, n, n)
        self.assertEqual(list(s.modules()), [s, n, l])

    def test_named_modules(self):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l1 = l
                self.l2 = l
                self.param = torch.empty(3, 5)
                self.block = block
        l = nn.Linear(10, 20)
        l1 = nn.Linear(10, 20)
        l2 = nn.Linear(10, 20)
        block = nn.Sequential()
        block.add_module('linear1', l1)
        block.add_module('linear2', l2)
        n = Net()
        s = nn.Sequential(n, n)
        self.assertEqual(list(s.named_modules()), [('', s), ('0', n), ('0.l1', l),
                                                   ('0.block', block), ('0.block.linear1', l1),
                                                   ('0.block.linear2', l2)])
        # test the option to not remove duplicate module instances
        self.assertEqual(list(s.named_modules(remove_duplicate=False)), [
            ('', s), ('0', n), ('0.l1', l), ('0.l2', l),
            ('0.block', block), ('0.block.linear1', l1),
            ('0.block.linear2', l2),
            ('1', n), ('1.l1', l), ('1.l2', l),
            ('1.block', block), ('1.block.linear1', l1),
            ('1.block.linear2', l2)])

    def test_register_buffer_raises_error_if_name_is_not_string(self):
        m = nn.Module()
        expected_error = 'buffer name should be a string. Got '
        with self.assertRaisesRegex(TypeError, expected_error + 'int'):
            m.register_buffer(1, torch.rand(5))
        with self.assertRaisesRegex(TypeError, expected_error + 'NoneType'):
            m.register_buffer(None, torch.rand(5))

    def test_register_buffer_raises_error_if_attr_exists(self):
        m = nn.Module()
        m.attribute_name = 5
        with self.assertRaises(KeyError):
            m.register_buffer('attribute_name', torch.rand(5))

        with self.assertRaises(KeyError):
            m.attribute_name = Buffer(torch.rand(5))

        del m.attribute_name
        m.register_parameter('attribute_name', nn.Parameter())
        with self.assertRaises(KeyError):
            m.register_buffer('attribute_name', torch.rand(5))

        del m.attribute_name
        m.add_module('attribute_name', nn.Module())
        with self.assertRaises(KeyError):
            m.register_buffer('attribute_name', torch.rand(5))

    def test_register_buffer_raises_error_if_not_tensor(self):
        m = nn.Module()
        with self.assertRaises(TypeError):
            m.register_buffer('attribute_name', 5)

    def test_register_buffer_allows_overwriting_with_same_name(self):
        m = nn.Module()
        buffer1 = torch.rand(5)
        buffer2 = buffer1 + 5
        buffer3 = None
        m.register_buffer('buffer_name', buffer1)
        self.assertEqual(m.buffer_name, buffer1)
        m.register_buffer('buffer_name', buffer2)
        self.assertEqual(m.buffer_name, buffer2)
        m.register_buffer('buffer_name', buffer3)
        self.assertEqual(m.buffer_name, buffer3)
        m.buffer_name = Buffer(buffer1)
        self.assertEqual(m.buffer_name, Buffer(buffer1))
        m.buffer_name = Buffer(buffer2)
        self.assertEqual(m.buffer_name, Buffer(buffer2))
        m.buffer_name = Buffer(buffer3)
        self.assertEqual(m.buffer_name, Buffer(buffer3))

    def test_get_buffer(self):
        m = nn.Module()
        buffer1 = torch.randn(2, 3)
        buffer2 = torch.randn(4, 5)
        m.foo = Buffer(buffer1)
        m.register_buffer('bar', buffer2)
        self.assertEqual(buffer1, m.get_buffer('foo'))
        self.assertEqual(buffer2, m.get_buffer('bar'))

    def test_get_buffer_from_submodules(self):
        class MyModule(nn.Module):
            def __init__(self, foo, bar):
                super().__init__()
                self.sub = Sub(foo, bar)

        class Sub(nn.Module):
            def __init__(self, foo, bar):
                super().__init__()
                self.foo = Buffer(foo)
                self.subsub = SubSub(bar)

        class SubSub(nn.Module):
            def __init__(self, bar):
                super().__init__()
                self.bar = Buffer(bar)

        foo = torch.randn(2, 3)
        bar = torch.randn(4, 5)
        m = MyModule(foo, bar)
        self.assertEqual(foo, m.get_buffer('sub.foo'))
        self.assertEqual(bar, m.get_buffer('sub.subsub.bar'))

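    # The test_buffer_not_persistent* cases below verify that non-persistent buffers
    # are reported by buffers() but never appear in (or are expected by) the state_dict.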
611*da0073e9SAndroid Build Coastguard Worker    def test_buffer_not_persistent(self):
612*da0073e9SAndroid Build Coastguard Worker        m = nn.Module()
613*da0073e9SAndroid Build Coastguard Worker        m.buf = nn.Buffer(torch.rand(5), persistent=False)
614*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(list(m.buffers())) == 1)
615*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(m.state_dict()) == 0)
616*da0073e9SAndroid Build Coastguard Worker
617*da0073e9SAndroid Build Coastguard Worker    def test_buffer_not_persistent_del(self):
618*da0073e9SAndroid Build Coastguard Worker        m = nn.Module()
619*da0073e9SAndroid Build Coastguard Worker        m.buf = nn.Buffer(torch.rand(5), persistent=False)
620*da0073e9SAndroid Build Coastguard Worker        del m.buf
621*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(list(m.buffers())) == 0)
622*da0073e9SAndroid Build Coastguard Worker
623*da0073e9SAndroid Build Coastguard Worker    def test_buffer_not_persistent_overwrite(self):
624*da0073e9SAndroid Build Coastguard Worker        m = nn.Module()
625*da0073e9SAndroid Build Coastguard Worker        m.buf = nn.Buffer(torch.rand(5), persistent=False)
626*da0073e9SAndroid Build Coastguard Worker        m.buf = nn.Buffer(torch.rand(5))
627*da0073e9SAndroid Build Coastguard Worker
628*da0073e9SAndroid Build Coastguard Worker        # can we overwrite a non-persistent buffer with a persistent one?
629*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(list(m.buffers())) == 1)
630*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(m.state_dict()) == 1)
631*da0073e9SAndroid Build Coastguard Worker
632*da0073e9SAndroid Build Coastguard Worker        # can we overwrite a persistent buffer with a non-persistent one?
633*da0073e9SAndroid Build Coastguard Worker        m.buf = nn.Buffer(torch.rand(5), persistent=False)
634*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(list(m.buffers())) == 1)
635*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(m.state_dict()) == 0)
636*da0073e9SAndroid Build Coastguard Worker
637*da0073e9SAndroid Build Coastguard Worker    def test_buffer_not_persistent_assign(self):
638*da0073e9SAndroid Build Coastguard Worker        m = nn.Module()
639*da0073e9SAndroid Build Coastguard Worker        m.buf = nn.Buffer(torch.rand(5), persistent=False)
640*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(list(m.buffers())) == 1)
641*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(m.state_dict()) == 0)
642*da0073e9SAndroid Build Coastguard Worker
643*da0073e9SAndroid Build Coastguard Worker        # Assigning None removes the buffer but if we then assign a new Tensor
644*da0073e9SAndroid Build Coastguard Worker        # to the same property, it should still be marked as a buffer.
645*da0073e9SAndroid Build Coastguard Worker        m.buf = None
646*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(list(m.buffers())) == 0)
647*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(m.state_dict()) == 0)
648*da0073e9SAndroid Build Coastguard Worker        m.buf = torch.rand(5)
649*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(list(m.buffers())) == 1)
650*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(m.state_dict()) == 0)
651*da0073e9SAndroid Build Coastguard Worker
652*da0073e9SAndroid Build Coastguard Worker        # Assigning a Parameter removes the buffer.
653*da0073e9SAndroid Build Coastguard Worker        m.buf = nn.Parameter(torch.rand(5))
654*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(list(m.buffers())) == 0)
655*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(m.state_dict()) == 1)
656*da0073e9SAndroid Build Coastguard Worker
657*da0073e9SAndroid Build Coastguard Worker    def test_buffer_not_persistent_load(self):
658*da0073e9SAndroid Build Coastguard Worker        m = nn.Module()
659*da0073e9SAndroid Build Coastguard Worker        m.buf = nn.Buffer(torch.rand(5), persistent=False)
660*da0073e9SAndroid Build Coastguard Worker        m.load_state_dict({})
661*da0073e9SAndroid Build Coastguard Worker
662*da0073e9SAndroid Build Coastguard Worker    def test_register_parameter_raises_error_if_name_is_not_string(self):
663*da0073e9SAndroid Build Coastguard Worker        m = nn.Module()
664*da0073e9SAndroid Build Coastguard Worker        expected_error = 'parameter name should be a string. Got '
665*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, expected_error + 'int'):
666*da0073e9SAndroid Build Coastguard Worker            m.register_parameter(1, nn.Parameter())
667*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, expected_error + 'NoneType'):
668*da0073e9SAndroid Build Coastguard Worker            m.register_parameter(None, nn.Parameter())
669*da0073e9SAndroid Build Coastguard Worker
670*da0073e9SAndroid Build Coastguard Worker    def test_register_parameter_raises_error_if_attr_exists(self):
671*da0073e9SAndroid Build Coastguard Worker        m = nn.Module()
672*da0073e9SAndroid Build Coastguard Worker        m.attribute_name = 5
673*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(KeyError):
674*da0073e9SAndroid Build Coastguard Worker            m.register_parameter('attribute_name', nn.Parameter())
675*da0073e9SAndroid Build Coastguard Worker
676*da0073e9SAndroid Build Coastguard Worker        del m.attribute_name
677*da0073e9SAndroid Build Coastguard Worker        m.register_buffer('attribute_name', torch.rand(5))
678*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(KeyError):
679*da0073e9SAndroid Build Coastguard Worker            m.register_parameter('attribute_name', nn.Parameter())
680*da0073e9SAndroid Build Coastguard Worker
681*da0073e9SAndroid Build Coastguard Worker        del m.attribute_name
682*da0073e9SAndroid Build Coastguard Worker        m.attribute_name = Buffer(torch.rand(5))
683*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(KeyError):
684*da0073e9SAndroid Build Coastguard Worker            m.register_parameter('attribute_name', nn.Parameter())
685*da0073e9SAndroid Build Coastguard Worker
686*da0073e9SAndroid Build Coastguard Worker        del m.attribute_name
687*da0073e9SAndroid Build Coastguard Worker        m.add_module('attribute_name', nn.Module())
688*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(KeyError):
689*da0073e9SAndroid Build Coastguard Worker            m.register_parameter('attribute_name', nn.Parameter())
690*da0073e9SAndroid Build Coastguard Worker
691*da0073e9SAndroid Build Coastguard Worker    def test_register_parameter_allows_overwriting_with_same_name(self):
692*da0073e9SAndroid Build Coastguard Worker        m = nn.Module()
693*da0073e9SAndroid Build Coastguard Worker        param1 = nn.Parameter(torch.rand(5))
694*da0073e9SAndroid Build Coastguard Worker        param2 = nn.Parameter(param1.data + 5)
695*da0073e9SAndroid Build Coastguard Worker        param3 = None
696*da0073e9SAndroid Build Coastguard Worker        m.register_parameter('param_name', param1)
697*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m.param_name, param1)
698*da0073e9SAndroid Build Coastguard Worker        m.register_parameter('param_name', param2)
699*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m.param_name, param2)
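        # Registering None is also allowed and overwrites the existing parameter.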
700*da0073e9SAndroid Build Coastguard Worker        m.register_parameter('param_name', param3)
701*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m.param_name, param3)
702*da0073e9SAndroid Build Coastguard Worker
703*da0073e9SAndroid Build Coastguard Worker    def test_add_module_raises_error_if_attr_exists(self):
704*da0073e9SAndroid Build Coastguard Worker        methods_to_test = ['add_module', 'register_module']
705*da0073e9SAndroid Build Coastguard Worker        for fn in methods_to_test:
706*da0073e9SAndroid Build Coastguard Worker            m = nn.Module()
707*da0073e9SAndroid Build Coastguard Worker            m.attribute_name = 5
708*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(KeyError):
709*da0073e9SAndroid Build Coastguard Worker                getattr(m, fn)('attribute_name', nn.Module())
710*da0073e9SAndroid Build Coastguard Worker
711*da0073e9SAndroid Build Coastguard Worker            del m.attribute_name
712*da0073e9SAndroid Build Coastguard Worker            m.register_buffer('attribute_name', torch.rand(5))
713*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(KeyError):
714*da0073e9SAndroid Build Coastguard Worker                getattr(m, fn)('attribute_name', nn.Module())
715*da0073e9SAndroid Build Coastguard Worker
716*da0073e9SAndroid Build Coastguard Worker            del m.attribute_name
717*da0073e9SAndroid Build Coastguard Worker            m.register_parameter('attribute_name', nn.Parameter())
718*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(KeyError):
719*da0073e9SAndroid Build Coastguard Worker                getattr(m, fn)('attribute_name', nn.Module())
720*da0073e9SAndroid Build Coastguard Worker
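    # Expected to fail: the AttributeError raised inside the property appears to be
    # masked by Module.__getattr__, which reports 'some_property' as missing instead.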
721*da0073e9SAndroid Build Coastguard Worker    @unittest.expectedFailure
722*da0073e9SAndroid Build Coastguard Worker    def test_getattr_with_property(self):
723*da0073e9SAndroid Build Coastguard Worker        class Model(nn.Module):
724*da0073e9SAndroid Build Coastguard Worker            @property
725*da0073e9SAndroid Build Coastguard Worker            def some_property(self):
726*da0073e9SAndroid Build Coastguard Worker                return self.something_that_doesnt_exist
727*da0073e9SAndroid Build Coastguard Worker
728*da0073e9SAndroid Build Coastguard Worker        model = Model()
729*da0073e9SAndroid Build Coastguard Worker
730*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
731*da0073e9SAndroid Build Coastguard Worker                AttributeError,
732*da0073e9SAndroid Build Coastguard Worker                r"'Model' object has no attribute 'something_that_doesnt_exist'"):
733*da0073e9SAndroid Build Coastguard Worker            model.some_property
734*da0073e9SAndroid Build Coastguard Worker
735*da0073e9SAndroid Build Coastguard Worker    def test_Sequential_getitem(self):
736*da0073e9SAndroid Build Coastguard Worker        l1 = nn.Linear(10, 20)
737*da0073e9SAndroid Build Coastguard Worker        l2 = nn.Linear(20, 30)
738*da0073e9SAndroid Build Coastguard Worker        l3 = nn.Linear(30, 40)
739*da0073e9SAndroid Build Coastguard Worker        l4 = nn.Linear(40, 50)
740*da0073e9SAndroid Build Coastguard Worker        n = nn.Sequential(l1, l2, l3, l4)
741*da0073e9SAndroid Build Coastguard Worker        self.assertIs(n[0], l1)
742*da0073e9SAndroid Build Coastguard Worker        self.assertIs(n[1], l2)
743*da0073e9SAndroid Build Coastguard Worker        self.assertIs(n[2], l3)
744*da0073e9SAndroid Build Coastguard Worker        self.assertIs(n[3], l4)
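        # 0-dim integer tensors are accepted as indices as well.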
745*da0073e9SAndroid Build Coastguard Worker        self.assertIs(n[torch.tensor(3, dtype=torch.int64)], l4)
746*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(n[1:], nn.Sequential(l2, l3, l4))
747*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(n[3:], nn.Sequential(l4))
748*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(n[:-1], nn.Sequential(l1, l2, l3))
749*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(n[:-3], nn.Sequential(l1))
750*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(n[::-1], nn.Sequential(l4, l3, l2, l1))
751*da0073e9SAndroid Build Coastguard Worker
752*da0073e9SAndroid Build Coastguard Worker    def test_Sequential_setitem(self):
753*da0073e9SAndroid Build Coastguard Worker        l1 = nn.Linear(10, 20)
754*da0073e9SAndroid Build Coastguard Worker        l2 = nn.Linear(20, 30)
755*da0073e9SAndroid Build Coastguard Worker        l3 = nn.Linear(30, 40)
756*da0073e9SAndroid Build Coastguard Worker        l4 = nn.Linear(40, 50)
757*da0073e9SAndroid Build Coastguard Worker        n = nn.Sequential(l1, l2, l3)
758*da0073e9SAndroid Build Coastguard Worker        n[0] = l4
759*da0073e9SAndroid Build Coastguard Worker        n[-1] = l4
760*da0073e9SAndroid Build Coastguard Worker        n[torch.tensor(1, dtype=torch.int16)] = l1
761*da0073e9SAndroid Build Coastguard Worker        self.assertIs(n[0], l4)
762*da0073e9SAndroid Build Coastguard Worker        self.assertIs(n[1], l1)
763*da0073e9SAndroid Build Coastguard Worker        self.assertIs(n[2], l4)
764*da0073e9SAndroid Build Coastguard Worker
765*da0073e9SAndroid Build Coastguard Worker    def test_Sequential_setitem_named(self):
766*da0073e9SAndroid Build Coastguard Worker        l1 = nn.Linear(10, 20)
767*da0073e9SAndroid Build Coastguard Worker        l2 = nn.Linear(20, 30)
768*da0073e9SAndroid Build Coastguard Worker        l3 = nn.Linear(30, 40)
769*da0073e9SAndroid Build Coastguard Worker        l4 = nn.Linear(40, 50)
770*da0073e9SAndroid Build Coastguard Worker        n = nn.Sequential(OrderedDict([
771*da0073e9SAndroid Build Coastguard Worker            ('linear1', l1),
772*da0073e9SAndroid Build Coastguard Worker            ('linear2', l2),
773*da0073e9SAndroid Build Coastguard Worker            ('linear3', l3),
774*da0073e9SAndroid Build Coastguard Worker        ]))
775*da0073e9SAndroid Build Coastguard Worker
776*da0073e9SAndroid Build Coastguard Worker        n[0] = l4
777*da0073e9SAndroid Build Coastguard Worker        n[-1] = l4
778*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(n.linear1, l4)
779*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(n.linear3, l4)
780*da0073e9SAndroid Build Coastguard Worker
781*da0073e9SAndroid Build Coastguard Worker    def test_Sequential_delitem(self):
782*da0073e9SAndroid Build Coastguard Worker        l1 = nn.Linear(10, 20)
783*da0073e9SAndroid Build Coastguard Worker        l2 = nn.Linear(20, 30)
784*da0073e9SAndroid Build Coastguard Worker        l3 = nn.Linear(30, 40)
785*da0073e9SAndroid Build Coastguard Worker        l4 = nn.Linear(40, 50)
786*da0073e9SAndroid Build Coastguard Worker        n = nn.Sequential(l1, l2, l3, l4)
787*da0073e9SAndroid Build Coastguard Worker        del n[-1]
788*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(n, nn.Sequential(l1, l2, l3))
789*da0073e9SAndroid Build Coastguard Worker        del n[1::2]
790*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(n, nn.Sequential(l1, l3))
791*da0073e9SAndroid Build Coastguard Worker
792*da0073e9SAndroid Build Coastguard Worker    def test_Sequential_add(self):
793*da0073e9SAndroid Build Coastguard Worker        l1 = nn.Linear(1, 2)
794*da0073e9SAndroid Build Coastguard Worker        l2 = nn.Linear(2, 3)
795*da0073e9SAndroid Build Coastguard Worker        l3 = nn.Linear(3, 4)
796*da0073e9SAndroid Build Coastguard Worker        l4 = nn.Linear(4, 5)
797*da0073e9SAndroid Build Coastguard Worker        n = nn.Sequential(l1, l2)
798*da0073e9SAndroid Build Coastguard Worker        other = nn.Sequential(l3, l4)
799*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(n + other, nn.Sequential(l1, l2, l3, l4))
800*da0073e9SAndroid Build Coastguard Worker
801*da0073e9SAndroid Build Coastguard Worker    def test_Sequential_iadd(self):
802*da0073e9SAndroid Build Coastguard Worker        l1 = nn.Linear(10, 20)
803*da0073e9SAndroid Build Coastguard Worker        l2 = nn.Linear(20, 30)
804*da0073e9SAndroid Build Coastguard Worker        l3 = nn.Linear(30, 40)
805*da0073e9SAndroid Build Coastguard Worker        l4 = nn.Linear(40, 50)
806*da0073e9SAndroid Build Coastguard Worker        n = nn.Sequential(l1, l2, l3)
807*da0073e9SAndroid Build Coastguard Worker        n2 = nn.Sequential(l4)
808*da0073e9SAndroid Build Coastguard Worker        n += n2
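        # n already contains l4 after the first +=, so n2 += n yields (l4, l1, l2, l3, l4).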
809*da0073e9SAndroid Build Coastguard Worker        n2 += n
810*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(n, nn.Sequential(l1, l2, l3, l4))
811*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(n2, nn.Sequential(l4, l1, l2, l3, l4))
812*da0073e9SAndroid Build Coastguard Worker
813*da0073e9SAndroid Build Coastguard Worker    def test_Sequential_mul(self):
814*da0073e9SAndroid Build Coastguard Worker        l1 = nn.Linear(10, 20)
815*da0073e9SAndroid Build Coastguard Worker        l2 = nn.Linear(20, 30)
816*da0073e9SAndroid Build Coastguard Worker        l3 = nn.Linear(30, 40)
817*da0073e9SAndroid Build Coastguard Worker        l4 = nn.Linear(40, 50)
818*da0073e9SAndroid Build Coastguard Worker        n = nn.Sequential(l1, l2, l3, l4)
819*da0073e9SAndroid Build Coastguard Worker        n2 = n * 2
820*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(n2, nn.Sequential(l1, l2, l3, l4, l1, l2, l3, l4))
821*da0073e9SAndroid Build Coastguard Worker
822*da0073e9SAndroid Build Coastguard Worker    def test_Sequential_rmul(self):
823*da0073e9SAndroid Build Coastguard Worker        l1 = nn.Linear(10, 20)
824*da0073e9SAndroid Build Coastguard Worker        l2 = nn.Linear(20, 30)
825*da0073e9SAndroid Build Coastguard Worker        l3 = nn.Linear(30, 40)
826*da0073e9SAndroid Build Coastguard Worker        l4 = nn.Linear(40, 50)
827*da0073e9SAndroid Build Coastguard Worker        n = nn.Sequential(l1, l2, l3, l4)
828*da0073e9SAndroid Build Coastguard Worker        n2 = 2 * n
829*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(n2, nn.Sequential(l1, l2, l3, l4, l1, l2, l3, l4))
830*da0073e9SAndroid Build Coastguard Worker
831*da0073e9SAndroid Build Coastguard Worker    def test_Sequential_imul(self):
832*da0073e9SAndroid Build Coastguard Worker        l1 = nn.Linear(10, 20)
833*da0073e9SAndroid Build Coastguard Worker        l2 = nn.Linear(20, 30)
834*da0073e9SAndroid Build Coastguard Worker        l3 = nn.Linear(30, 40)
835*da0073e9SAndroid Build Coastguard Worker        l4 = nn.Linear(40, 50)
836*da0073e9SAndroid Build Coastguard Worker        n = nn.Sequential(l1, l2, l3, l4)
837*da0073e9SAndroid Build Coastguard Worker        n *= 2
838*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(n, nn.Sequential(l1, l2, l3, l4, l1, l2, l3, l4))
839*da0073e9SAndroid Build Coastguard Worker        n *= 2
840*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
841*da0073e9SAndroid Build Coastguard Worker            n,
842*da0073e9SAndroid Build Coastguard Worker            nn.Sequential(l1, l2, l3, l4, l1, l2, l3, l4, l1, l2, l3, l4, l1, l2, l3, l4)
843*da0073e9SAndroid Build Coastguard Worker        )
844*da0073e9SAndroid Build Coastguard Worker
845*da0073e9SAndroid Build Coastguard Worker    def test_Sequential_append(self):
846*da0073e9SAndroid Build Coastguard Worker        l1 = nn.Linear(10, 20)
847*da0073e9SAndroid Build Coastguard Worker        l2 = nn.Linear(20, 30)
848*da0073e9SAndroid Build Coastguard Worker        l3 = nn.Linear(30, 40)
849*da0073e9SAndroid Build Coastguard Worker        l4 = nn.Linear(40, 50)
850*da0073e9SAndroid Build Coastguard Worker        n = nn.Sequential(l1, l2, l3)
851*da0073e9SAndroid Build Coastguard Worker        n2 = n.append(l4)
852*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(n, nn.Sequential(l1, l2, l3, l4))
853*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(n2, nn.Sequential(l1, l2, l3, l4))
854*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nn.Sequential(l1).append(l2).append(l4), nn.Sequential(l1, l2, l4))
855*da0073e9SAndroid Build Coastguard Worker
856*da0073e9SAndroid Build Coastguard Worker    def test_Sequential_pop(self):
857*da0073e9SAndroid Build Coastguard Worker        l1 = nn.Linear(1, 2)
858*da0073e9SAndroid Build Coastguard Worker        l2 = nn.Linear(2, 3)
859*da0073e9SAndroid Build Coastguard Worker        l3 = nn.Linear(3, 4)
860*da0073e9SAndroid Build Coastguard Worker        l4 = nn.Linear(4, 5)
861*da0073e9SAndroid Build Coastguard Worker        n1 = nn.Sequential(l1, l2, l3, l4)
862*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(l4, n1.pop(3))
863*da0073e9SAndroid Build Coastguard Worker        n2 = nn.Sequential(l1, l2, l3)
864*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(n1, n2)
865*da0073e9SAndroid Build Coastguard Worker        # check that positional indexing matches iteration order after the pop
866*da0073e9SAndroid Build Coastguard Worker        for k, mod in zip(range(len(n1)), n1):
867*da0073e9SAndroid Build Coastguard Worker            self.assertIs(n1[k], mod)
868*da0073e9SAndroid Build Coastguard Worker
869*da0073e9SAndroid Build Coastguard Worker    def test_Sequential_insert(self):
870*da0073e9SAndroid Build Coastguard Worker        l1 = nn.Linear(1, 2)
871*da0073e9SAndroid Build Coastguard Worker        l2 = nn.Linear(2, 3)
872*da0073e9SAndroid Build Coastguard Worker        l3 = nn.Linear(3, 4)
873*da0073e9SAndroid Build Coastguard Worker
874*da0073e9SAndroid Build Coastguard Worker        n1 = nn.Sequential(l1, l2, l3)
875*da0073e9SAndroid Build Coastguard Worker        module_1 = nn.Linear(4, 5)
876*da0073e9SAndroid Build Coastguard Worker        n2 = nn.Sequential(l1, module_1, l2, l3)
877*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(n1.insert(1, module_1), n2)
878*da0073e9SAndroid Build Coastguard Worker
879*da0073e9SAndroid Build Coastguard Worker        # test support for negative indices
880*da0073e9SAndroid Build Coastguard Worker        n3 = nn.Sequential(l1, l2, l3)
881*da0073e9SAndroid Build Coastguard Worker        module_2 = nn.Linear(5, 6)
882*da0073e9SAndroid Build Coastguard Worker        n4 = nn.Sequential(l1, module_2, l2, l3)
883*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(n3.insert(-2, module_2), n4)
884*da0073e9SAndroid Build Coastguard Worker
885*da0073e9SAndroid Build Coastguard Worker    def test_Sequential_insert_fail_case(self):
886*da0073e9SAndroid Build Coastguard Worker        l1 = nn.Linear(1, 2)
887*da0073e9SAndroid Build Coastguard Worker        l2 = nn.Linear(2, 3)
888*da0073e9SAndroid Build Coastguard Worker        l3 = nn.Linear(3, 4)
889*da0073e9SAndroid Build Coastguard Worker
890*da0073e9SAndroid Build Coastguard Worker        module = nn.Linear(5, 6)
891*da0073e9SAndroid Build Coastguard Worker
892*da0073e9SAndroid Build Coastguard Worker        # test error cases
893*da0073e9SAndroid Build Coastguard Worker        n1 = nn.Sequential(l1, l2, l3)
894*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(IndexError):
895*da0073e9SAndroid Build Coastguard Worker            n1.insert(-5, module)
896*da0073e9SAndroid Build Coastguard Worker
897*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(AssertionError):
898*da0073e9SAndroid Build Coastguard Worker            n1.insert(1, [nn.Linear(6, 7)])
899*da0073e9SAndroid Build Coastguard Worker
900*da0073e9SAndroid Build Coastguard Worker    def test_Sequential_extend(self):
901*da0073e9SAndroid Build Coastguard Worker        l1 = nn.Linear(10, 20)
902*da0073e9SAndroid Build Coastguard Worker        l2 = nn.Linear(20, 30)
903*da0073e9SAndroid Build Coastguard Worker        l3 = nn.Linear(30, 40)
904*da0073e9SAndroid Build Coastguard Worker        l4 = nn.Linear(40, 50)
905*da0073e9SAndroid Build Coastguard Worker        n1 = nn.Sequential(l1, l2)
906*da0073e9SAndroid Build Coastguard Worker        n2 = nn.Sequential(l3, l4)
907*da0073e9SAndroid Build Coastguard Worker        n3 = nn.Sequential(l1, l2)
908*da0073e9SAndroid Build Coastguard Worker        for layer in n2:
909*da0073e9SAndroid Build Coastguard Worker            n1.append(layer)
910*da0073e9SAndroid Build Coastguard Worker        n3.extend(n2)
911*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(n3, n1)
912*da0073e9SAndroid Build Coastguard Worker
913*da0073e9SAndroid Build Coastguard Worker    def test_ModuleList(self):
914*da0073e9SAndroid Build Coastguard Worker        modules = [nn.ReLU(), nn.Linear(5, 5)]
915*da0073e9SAndroid Build Coastguard Worker        module_list = nn.ModuleList(modules)
916*da0073e9SAndroid Build Coastguard Worker
917*da0073e9SAndroid Build Coastguard Worker        def check():
918*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(module_list), len(modules))
919*da0073e9SAndroid Build Coastguard Worker            for m1, m2 in zip(modules, module_list):
920*da0073e9SAndroid Build Coastguard Worker                self.assertIs(m1, m2)
921*da0073e9SAndroid Build Coastguard Worker            for m1, m2 in zip(modules, module_list.children()):
922*da0073e9SAndroid Build Coastguard Worker                self.assertIs(m1, m2)
923*da0073e9SAndroid Build Coastguard Worker            for i in range(len(modules)):
924*da0073e9SAndroid Build Coastguard Worker                self.assertIs(module_list[i], modules[i])
925*da0073e9SAndroid Build Coastguard Worker
926*da0073e9SAndroid Build Coastguard Worker        check()
927*da0073e9SAndroid Build Coastguard Worker        modules += [nn.Conv2d(3, 4, 3)]
928*da0073e9SAndroid Build Coastguard Worker        module_list += [modules[-1]]
929*da0073e9SAndroid Build Coastguard Worker        check()
930*da0073e9SAndroid Build Coastguard Worker        modules = modules + [nn.Conv2d(3, 4, 3, bias=False), nn.GELU()]
931*da0073e9SAndroid Build Coastguard Worker        module_list = module_list + nn.ModuleList(modules[-2:])
932*da0073e9SAndroid Build Coastguard Worker        check()
933*da0073e9SAndroid Build Coastguard Worker        modules.insert(1, nn.Linear(3, 2))
934*da0073e9SAndroid Build Coastguard Worker        module_list.insert(1, modules[1])
935*da0073e9SAndroid Build Coastguard Worker        check()
936*da0073e9SAndroid Build Coastguard Worker        modules.append(nn.Tanh())
937*da0073e9SAndroid Build Coastguard Worker        module_list.append(modules[-1])
938*da0073e9SAndroid Build Coastguard Worker        check()
939*da0073e9SAndroid Build Coastguard Worker        next_modules = [nn.Linear(5, 5), nn.Sigmoid()]
940*da0073e9SAndroid Build Coastguard Worker        modules.extend(next_modules)
941*da0073e9SAndroid Build Coastguard Worker        module_list.extend(next_modules)
942*da0073e9SAndroid Build Coastguard Worker        check()
943*da0073e9SAndroid Build Coastguard Worker        modules[2] = nn.Conv2d(5, 3, 2)
944*da0073e9SAndroid Build Coastguard Worker        module_list[2] = modules[2]
945*da0073e9SAndroid Build Coastguard Worker        check()
946*da0073e9SAndroid Build Coastguard Worker        modules[-1] = nn.Conv2d(5, 2, 1)
947*da0073e9SAndroid Build Coastguard Worker        module_list[-1] = modules[-1]
948*da0073e9SAndroid Build Coastguard Worker        check()
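        # 0-dim integer tensors also work as indices for __getitem__ / __setitem__.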
949*da0073e9SAndroid Build Coastguard Worker        idx = torch.tensor(2, dtype=torch.int32)
950*da0073e9SAndroid Build Coastguard Worker        modules[2] = nn.Conv2d(5, 3, 2)
951*da0073e9SAndroid Build Coastguard Worker        module_list[idx] = modules[2]
952*da0073e9SAndroid Build Coastguard Worker        self.assertIs(module_list[idx], modules[2])
953*da0073e9SAndroid Build Coastguard Worker        check()
954*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(module_list[1:], nn.ModuleList(modules[1:]))
955*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(module_list[3:], nn.ModuleList(modules[3:]))
956*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(module_list[:-1], nn.ModuleList(modules[:-1]))
957*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(module_list[:-3], nn.ModuleList(modules[:-3]))
958*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(module_list[::-1], nn.ModuleList(modules[::-1]))
959*da0073e9SAndroid Build Coastguard Worker        del module_list[-1]
960*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(module_list, nn.ModuleList(modules[:-1]))
961*da0073e9SAndroid Build Coastguard Worker        del module_list[1::2]
962*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(module_list, nn.ModuleList(modules[:-1][0::2]))
963*da0073e9SAndroid Build Coastguard Worker
964*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
965*da0073e9SAndroid Build Coastguard Worker            module_list += nn.ReLU()
966*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
967*da0073e9SAndroid Build Coastguard Worker            module_list.extend(nn.ReLU())
968*da0073e9SAndroid Build Coastguard Worker
969*da0073e9SAndroid Build Coastguard Worker        l1 = nn.Linear(1, 2)
970*da0073e9SAndroid Build Coastguard Worker        l2 = nn.Linear(2, 3)
971*da0073e9SAndroid Build Coastguard Worker        l3 = nn.Linear(3, 2)
972*da0073e9SAndroid Build Coastguard Worker        l4 = nn.Linear(2, 3)
973*da0073e9SAndroid Build Coastguard Worker        subnet = nn.Sequential(l3, l4)
974*da0073e9SAndroid Build Coastguard Worker        s = nn.Sequential(
975*da0073e9SAndroid Build Coastguard Worker            OrderedDict([
976*da0073e9SAndroid Build Coastguard Worker                ("layer1", l1),
977*da0073e9SAndroid Build Coastguard Worker                ("layer2", l2),
978*da0073e9SAndroid Build Coastguard Worker                ("layer3", l3),
979*da0073e9SAndroid Build Coastguard Worker                ("layer4", l4),
980*da0073e9SAndroid Build Coastguard Worker                ("subnet_layer", subnet)
981*da0073e9SAndroid Build Coastguard Worker            ])
982*da0073e9SAndroid Build Coastguard Worker        )
983*da0073e9SAndroid Build Coastguard Worker        modules = list(s.modules())
984*da0073e9SAndroid Build Coastguard Worker        module_list = nn.ModuleList()
985*da0073e9SAndroid Build Coastguard Worker        module_list.extend(s.modules())
986*da0073e9SAndroid Build Coastguard Worker        check()
987*da0073e9SAndroid Build Coastguard Worker
988*da0073e9SAndroid Build Coastguard Worker        modules = [nn.ReLU(), nn.Linear(5, 5), nn.Conv2d(3, 4, 3)]
989*da0073e9SAndroid Build Coastguard Worker        module_list = nn.ModuleList(modules)
990*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(modules.pop(1), module_list.pop(1))
991*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(modules, module_list)
992*da0073e9SAndroid Build Coastguard Worker        # check that positional indexing matches iteration order after the pop
993*da0073e9SAndroid Build Coastguard Worker        for k, mod in zip(range(len(module_list)), module_list):
994*da0073e9SAndroid Build Coastguard Worker            self.assertIs(module_list[k], mod)
995*da0073e9SAndroid Build Coastguard Worker
996*da0073e9SAndroid Build Coastguard Worker        # verify the right exception is thrown when trying to "forward" through a ModuleList
997*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(NotImplementedError, module_list)
998*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(NotImplementedError, module_list, torch.rand(1, 3))
999*da0073e9SAndroid Build Coastguard Worker
1000*da0073e9SAndroid Build Coastguard Worker    def test_ModuleDict(self):
1001*da0073e9SAndroid Build Coastguard Worker        modules = OrderedDict([
1002*da0073e9SAndroid Build Coastguard Worker            ('act', nn.ReLU()),
1003*da0073e9SAndroid Build Coastguard Worker            ('conv', nn.Conv2d(10, 10, 5)),
1004*da0073e9SAndroid Build Coastguard Worker            ('fc', nn.Linear(5, 5)),
1005*da0073e9SAndroid Build Coastguard Worker        ])
1006*da0073e9SAndroid Build Coastguard Worker
1007*da0073e9SAndroid Build Coastguard Worker        module_dict = nn.ModuleDict(modules)
1008*da0073e9SAndroid Build Coastguard Worker
1009*da0073e9SAndroid Build Coastguard Worker        def check():
1010*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(module_dict), len(modules))
1011*da0073e9SAndroid Build Coastguard Worker            for k1, m2 in zip(modules, module_dict.children()):
1012*da0073e9SAndroid Build Coastguard Worker                self.assertIs(modules[k1], m2)
1013*da0073e9SAndroid Build Coastguard Worker            for k1, k2 in zip(modules, module_dict):
1014*da0073e9SAndroid Build Coastguard Worker                self.assertIs(modules[k1], module_dict[k2])
1015*da0073e9SAndroid Build Coastguard Worker            for k in module_dict:
1016*da0073e9SAndroid Build Coastguard Worker                self.assertIs(module_dict[k], modules[k])
1017*da0073e9SAndroid Build Coastguard Worker            for k in module_dict.keys():
1018*da0073e9SAndroid Build Coastguard Worker                self.assertIs(module_dict[k], modules[k])
1019*da0073e9SAndroid Build Coastguard Worker            for k, v in module_dict.items():
1020*da0073e9SAndroid Build Coastguard Worker                self.assertIs(modules[k], v)
1021*da0073e9SAndroid Build Coastguard Worker            for k1, m2 in zip(modules, module_dict.values()):
1022*da0073e9SAndroid Build Coastguard Worker                self.assertIs(modules[k1], m2)
1023*da0073e9SAndroid Build Coastguard Worker            for k in modules.keys():
1024*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(k in module_dict)
1025*da0073e9SAndroid Build Coastguard Worker        check()
1026*da0073e9SAndroid Build Coastguard Worker
1027*da0073e9SAndroid Build Coastguard Worker        modules['conv'] = nn.Conv2d(3, 4, 3)
1028*da0073e9SAndroid Build Coastguard Worker        module_dict['conv'] = modules['conv']
1029*da0073e9SAndroid Build Coastguard Worker        check()
1030*da0073e9SAndroid Build Coastguard Worker
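        # update() accepts an iterable of (name, module) pairs, a mapping, or another ModuleDict.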
1031*da0073e9SAndroid Build Coastguard Worker        next_modules = [
1032*da0073e9SAndroid Build Coastguard Worker            ('fc2', nn.Linear(5, 5)),
1033*da0073e9SAndroid Build Coastguard Worker            ('act', nn.Sigmoid()),
1034*da0073e9SAndroid Build Coastguard Worker        ]
1035*da0073e9SAndroid Build Coastguard Worker        modules.update(next_modules)
1036*da0073e9SAndroid Build Coastguard Worker        module_dict.update(next_modules)
1037*da0073e9SAndroid Build Coastguard Worker        check()
1038*da0073e9SAndroid Build Coastguard Worker
1039*da0073e9SAndroid Build Coastguard Worker        next_modules = OrderedDict([
1040*da0073e9SAndroid Build Coastguard Worker            ('fc3', nn.Linear(5, 5)),
1041*da0073e9SAndroid Build Coastguard Worker            ('act2', nn.Sigmoid()),
1042*da0073e9SAndroid Build Coastguard Worker        ])
1043*da0073e9SAndroid Build Coastguard Worker        modules.update(next_modules)
1044*da0073e9SAndroid Build Coastguard Worker        module_dict.update(next_modules)
1045*da0073e9SAndroid Build Coastguard Worker        check()
1046*da0073e9SAndroid Build Coastguard Worker
1047*da0073e9SAndroid Build Coastguard Worker        next_modules = {
1048*da0073e9SAndroid Build Coastguard Worker            'fc4': nn.Linear(5, 5),
1049*da0073e9SAndroid Build Coastguard Worker            'act3': nn.Sigmoid()
1050*da0073e9SAndroid Build Coastguard Worker        }
1051*da0073e9SAndroid Build Coastguard Worker        modules.update(next_modules.items())
1052*da0073e9SAndroid Build Coastguard Worker        module_dict.update(next_modules)
1053*da0073e9SAndroid Build Coastguard Worker        check()
1054*da0073e9SAndroid Build Coastguard Worker
1055*da0073e9SAndroid Build Coastguard Worker        next_modules = nn.ModuleDict([
1056*da0073e9SAndroid Build Coastguard Worker            ('fc5', nn.Linear(5, 5)),
1057*da0073e9SAndroid Build Coastguard Worker            ('act4', nn.Sigmoid()),
1058*da0073e9SAndroid Build Coastguard Worker        ])
1059*da0073e9SAndroid Build Coastguard Worker        modules.update(next_modules)
1060*da0073e9SAndroid Build Coastguard Worker        module_dict.update(next_modules)
1061*da0073e9SAndroid Build Coastguard Worker        check()
1062*da0073e9SAndroid Build Coastguard Worker
1063*da0073e9SAndroid Build Coastguard Worker        del module_dict['fc']
1064*da0073e9SAndroid Build Coastguard Worker        del modules['fc']
1065*da0073e9SAndroid Build Coastguard Worker        check()
1066*da0073e9SAndroid Build Coastguard Worker
1067*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
1068*da0073e9SAndroid Build Coastguard Worker            module_dict.update(nn.ReLU())
1069*da0073e9SAndroid Build Coastguard Worker
1070*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
1071*da0073e9SAndroid Build Coastguard Worker            module_dict.update([nn.ReLU()])
1072*da0073e9SAndroid Build Coastguard Worker
1073*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
1074*da0073e9SAndroid Build Coastguard Worker            module_dict.update([[nn.ReLU()]])
1075*da0073e9SAndroid Build Coastguard Worker
1076*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
1077*da0073e9SAndroid Build Coastguard Worker            module_dict[1] = nn.ReLU()
1078*da0073e9SAndroid Build Coastguard Worker
1079*da0073e9SAndroid Build Coastguard Worker        s = nn.Sequential(modules)
1080*da0073e9SAndroid Build Coastguard Worker        module_dict = nn.ModuleDict(s.named_children())
1081*da0073e9SAndroid Build Coastguard Worker        check()
1082*da0073e9SAndroid Build Coastguard Worker
1083*da0073e9SAndroid Build Coastguard Worker        c = module_dict.pop('conv')
1084*da0073e9SAndroid Build Coastguard Worker        self.assertIs(c, modules['conv'])
1085*da0073e9SAndroid Build Coastguard Worker        modules.pop('conv')
1086*da0073e9SAndroid Build Coastguard Worker        check()
1087*da0073e9SAndroid Build Coastguard Worker
1088*da0073e9SAndroid Build Coastguard Worker        module_dict.clear()
1089*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(module_dict), 0)
1090*da0073e9SAndroid Build Coastguard Worker        modules.clear()
1091*da0073e9SAndroid Build Coastguard Worker        check()
1092*da0073e9SAndroid Build Coastguard Worker
1093*da0073e9SAndroid Build Coastguard Worker        # verify the right exception is thrown when trying to "forward" through a ModuleDict
1094*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(NotImplementedError, module_dict)
1095*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(NotImplementedError, module_dict, torch.rand(1, 3))
1096*da0073e9SAndroid Build Coastguard Worker
1097*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo()
1098*da0073e9SAndroid Build Coastguard Worker    def test_ParameterList(self):
1099*da0073e9SAndroid Build Coastguard Worker        def make_param():
1100*da0073e9SAndroid Build Coastguard Worker            return Parameter(torch.randn(2, 2))
1101*da0073e9SAndroid Build Coastguard Worker        parameters = [make_param(), make_param()]
1102*da0073e9SAndroid Build Coastguard Worker        param_list = nn.ParameterList(parameters)
1103*da0073e9SAndroid Build Coastguard Worker
1104*da0073e9SAndroid Build Coastguard Worker        def check():
1105*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(parameters), len(param_list))
1106*da0073e9SAndroid Build Coastguard Worker            for p1, p2 in zip(parameters, param_list):
1107*da0073e9SAndroid Build Coastguard Worker                self.assertIs(p1, p2)
1108*da0073e9SAndroid Build Coastguard Worker            for p1, p2 in zip(filter(lambda x: isinstance(x, Parameter), parameters), param_list.parameters()):
1109*da0073e9SAndroid Build Coastguard Worker                self.assertIs(p1, p2)
1110*da0073e9SAndroid Build Coastguard Worker            for i in range(len(parameters)):
1111*da0073e9SAndroid Build Coastguard Worker                self.assertIs(parameters[i], param_list[i])
1112*da0073e9SAndroid Build Coastguard Worker
1113*da0073e9SAndroid Build Coastguard Worker        check()
1114*da0073e9SAndroid Build Coastguard Worker        parameters += [make_param()]
1115*da0073e9SAndroid Build Coastguard Worker        param_list += [parameters[-1]]
1116*da0073e9SAndroid Build Coastguard Worker        check()
1117*da0073e9SAndroid Build Coastguard Worker        parameters.append(make_param())
1118*da0073e9SAndroid Build Coastguard Worker        param_list.append(parameters[-1])
1119*da0073e9SAndroid Build Coastguard Worker        check()
1120*da0073e9SAndroid Build Coastguard Worker        next_params = [make_param(), make_param()]
1121*da0073e9SAndroid Build Coastguard Worker        parameters.extend(next_params)
1122*da0073e9SAndroid Build Coastguard Worker        param_list.extend(next_params)
1123*da0073e9SAndroid Build Coastguard Worker        check()
1124*da0073e9SAndroid Build Coastguard Worker        parameters[2] = make_param()
1125*da0073e9SAndroid Build Coastguard Worker        param_list[2] = parameters[2]
1126*da0073e9SAndroid Build Coastguard Worker        check()
1127*da0073e9SAndroid Build Coastguard Worker        parameters[-1] = make_param()
1128*da0073e9SAndroid Build Coastguard Worker        param_list[-1] = parameters[-1]
1129*da0073e9SAndroid Build Coastguard Worker        check()
1130*da0073e9SAndroid Build Coastguard Worker        idx = torch.tensor(2, dtype=torch.int32)
1131*da0073e9SAndroid Build Coastguard Worker        parameters[2] = make_param()
1132*da0073e9SAndroid Build Coastguard Worker        param_list[idx] = parameters[2]
1133*da0073e9SAndroid Build Coastguard Worker        self.assertIs(param_list[idx], parameters[2])
1134*da0073e9SAndroid Build Coastguard Worker        check()
1135*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(param_list[1:], nn.ParameterList(parameters[1:]))
1136*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(param_list[3:], nn.ParameterList(parameters[3:]))
1137*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(param_list[:-1], nn.ParameterList(parameters[:-1]))
1138*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(param_list[:-3], nn.ParameterList(parameters[:-3]))
1139*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(param_list[::-1], nn.ParameterList(parameters[::-1]))
1140*da0073e9SAndroid Build Coastguard Worker
1141*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
1142*da0073e9SAndroid Build Coastguard Worker            param_list += make_param()
1143*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
1144*da0073e9SAndroid Build Coastguard Worker            param_list.extend(make_param())
1145*da0073e9SAndroid Build Coastguard Worker
1146*da0073e9SAndroid Build Coastguard Worker        l1 = nn.Linear(1, 2)
1147*da0073e9SAndroid Build Coastguard Worker        l2 = nn.Linear(2, 3)
1148*da0073e9SAndroid Build Coastguard Worker        l3 = nn.Linear(3, 2)
1149*da0073e9SAndroid Build Coastguard Worker        l4 = nn.Linear(2, 3)
1150*da0073e9SAndroid Build Coastguard Worker        subnet = nn.Sequential(l3, l4)
1151*da0073e9SAndroid Build Coastguard Worker        s = nn.Sequential(
1152*da0073e9SAndroid Build Coastguard Worker            OrderedDict([
1153*da0073e9SAndroid Build Coastguard Worker                ("layer1", l1),
1154*da0073e9SAndroid Build Coastguard Worker                ("layer2", l2),
1155*da0073e9SAndroid Build Coastguard Worker                ("layer3", l3),
1156*da0073e9SAndroid Build Coastguard Worker                ("layer4", l4),
1157*da0073e9SAndroid Build Coastguard Worker                ("subnet_layer", subnet)
1158*da0073e9SAndroid Build Coastguard Worker            ])
1159*da0073e9SAndroid Build Coastguard Worker        )
1160*da0073e9SAndroid Build Coastguard Worker        parameters = list(s.parameters())
1161*da0073e9SAndroid Build Coastguard Worker        param_list = nn.ParameterList()
1162*da0073e9SAndroid Build Coastguard Worker        param_list.extend(s.parameters())
1163*da0073e9SAndroid Build Coastguard Worker        check()
1164*da0073e9SAndroid Build Coastguard Worker
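        # Plain tensors appended to a ParameterList are wrapped in Parameter; non-tensor objects (e.g. strings) are stored as-is.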
1165*da0073e9SAndroid Build Coastguard Worker        param_list.append(torch.rand(2, 2))
1166*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(param_list[-1], Parameter)
1167*da0073e9SAndroid Build Coastguard Worker        parameters.append(param_list[-1])
1168*da0073e9SAndroid Build Coastguard Worker
1169*da0073e9SAndroid Build Coastguard Worker        param_list.extend([torch.rand(2, 2), "foo"])
1170*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(param_list[-2], Parameter)
1171*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(param_list[-1], str)
1172*da0073e9SAndroid Build Coastguard Worker        parameters.extend(param_list[-2:])
1173*da0073e9SAndroid Build Coastguard Worker
1174*da0073e9SAndroid Build Coastguard Worker        param_list += ["bar", torch.rand(2, 2)]
1175*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(param_list[-2], str)
1176*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(param_list[-1], Parameter)
1177*da0073e9SAndroid Build Coastguard Worker        parameters += param_list[-2:]
1178*da0073e9SAndroid Build Coastguard Worker        check()
1179*da0073e9SAndroid Build Coastguard Worker
1180*da0073e9SAndroid Build Coastguard Worker    def test_ParameterList_meta(self):
1181*da0073e9SAndroid Build Coastguard Worker        p = torch.nn.Parameter(torch.empty(1, device='meta'))
1182*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(p), """\
1183*da0073e9SAndroid Build Coastguard WorkerParameter containing:
1184*da0073e9SAndroid Build Coastguard Workertensor(..., device='meta', size=(1,), requires_grad=True)""")
1185*da0073e9SAndroid Build Coastguard Worker        pl = torch.nn.ParameterList([p])
1186*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(pl), """ParameterList(  (0): Parameter containing: [torch.float32 of size 1])""")
1187*da0073e9SAndroid Build Coastguard Worker
1188*da0073e9SAndroid Build Coastguard Worker    def test_ParameterList_replication(self):
1189*da0073e9SAndroid Build Coastguard Worker        # The actual replication code from DataParallel cannot be used on CPU, so replicate manually here
1190*da0073e9SAndroid Build Coastguard Worker        def make_param():
1191*da0073e9SAndroid Build Coastguard Worker            return Parameter(torch.randn(2, 2))
1192*da0073e9SAndroid Build Coastguard Worker        parameters = [make_param(), make_param()]
1193*da0073e9SAndroid Build Coastguard Worker        param_list = nn.ParameterList(parameters)
1194*da0073e9SAndroid Build Coastguard Worker
1195*da0073e9SAndroid Build Coastguard Worker        new_param_list = param_list._replicate_for_data_parallel()
1196*da0073e9SAndroid Build Coastguard Worker
1197*da0073e9SAndroid Build Coastguard Worker        for n, p in param_list.named_parameters():
1198*da0073e9SAndroid Build Coastguard Worker            # Do a view here so that we can check the base later
1199*da0073e9SAndroid Build Coastguard Worker            setattr(new_param_list, n, p.view_as(p))
1200*da0073e9SAndroid Build Coastguard Worker
1201*da0073e9SAndroid Build Coastguard Worker        for p, p2 in zip(param_list, new_param_list):
1202*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(p, p2)
1203*da0073e9SAndroid Build Coastguard Worker            self.assertIsNotNone(p2.grad_fn)
1204*da0073e9SAndroid Build Coastguard Worker            self.assertIs(p2._base, p)
1205*da0073e9SAndroid Build Coastguard Worker
1206*da0073e9SAndroid Build Coastguard Worker    def test_ParameterDict(self):
1207*da0073e9SAndroid Build Coastguard Worker        parameters = OrderedDict([
1208*da0073e9SAndroid Build Coastguard Worker            ('p1', Parameter(torch.randn(10, 10))),
1209*da0073e9SAndroid Build Coastguard Worker            ('p2', Parameter(torch.randn(10, 10))),
1210*da0073e9SAndroid Build Coastguard Worker            ('p3', Parameter(torch.randn(10, 10))),
1211*da0073e9SAndroid Build Coastguard Worker        ])
1212*da0073e9SAndroid Build Coastguard Worker
1213*da0073e9SAndroid Build Coastguard Worker        parameter_dict = nn.ParameterDict(parameters)
1214*da0073e9SAndroid Build Coastguard Worker
1215*da0073e9SAndroid Build Coastguard Worker        def check():
1216*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(parameter_dict), len(parameters))
1217*da0073e9SAndroid Build Coastguard Worker            for i, (k1, (k2, m2)) in enumerate(zip(parameters, parameter_dict.named_parameters())):
1218*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(k1, k2)
1219*da0073e9SAndroid Build Coastguard Worker                self.assertIs(parameters[k1], m2)
1220*da0073e9SAndroid Build Coastguard Worker            for k1, k2 in zip(parameters, parameter_dict):
1221*da0073e9SAndroid Build Coastguard Worker                self.assertIs(parameters[k1], parameter_dict[k2])
1222*da0073e9SAndroid Build Coastguard Worker            for k in parameter_dict:
1223*da0073e9SAndroid Build Coastguard Worker                self.assertIs(parameter_dict[k], parameters[k])
1224*da0073e9SAndroid Build Coastguard Worker            for k in parameter_dict.keys():
1225*da0073e9SAndroid Build Coastguard Worker                self.assertIs(parameter_dict[k], parameters[k])
1226*da0073e9SAndroid Build Coastguard Worker            for k, v in parameter_dict.items():
1227*da0073e9SAndroid Build Coastguard Worker                self.assertIs(v, parameters[k])
1228*da0073e9SAndroid Build Coastguard Worker            for k1, m2 in zip(parameters, parameter_dict.values()):
1229*da0073e9SAndroid Build Coastguard Worker                self.assertIs(parameters[k1], m2)
1230*da0073e9SAndroid Build Coastguard Worker            for k in parameters.keys():
1231*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(k in parameter_dict)
1232*da0073e9SAndroid Build Coastguard Worker
1233*da0073e9SAndroid Build Coastguard Worker        check()
1234*da0073e9SAndroid Build Coastguard Worker
1235*da0073e9SAndroid Build Coastguard Worker        parameters['p4'] = Parameter(torch.randn(10, 10))
1236*da0073e9SAndroid Build Coastguard Worker        parameter_dict['p4'] = parameters['p4']
1237*da0073e9SAndroid Build Coastguard Worker        check()
1238*da0073e9SAndroid Build Coastguard Worker
1239*da0073e9SAndroid Build Coastguard Worker        next_parameters = [
1240*da0073e9SAndroid Build Coastguard Worker            ('p5', Parameter(torch.randn(10, 10))),
1241*da0073e9SAndroid Build Coastguard Worker            ('p2', Parameter(torch.randn(10, 10))),
1242*da0073e9SAndroid Build Coastguard Worker        ]
1243*da0073e9SAndroid Build Coastguard Worker        parameters.update(next_parameters)
1244*da0073e9SAndroid Build Coastguard Worker        parameter_dict.update(next_parameters)
1245*da0073e9SAndroid Build Coastguard Worker        check()
1246*da0073e9SAndroid Build Coastguard Worker
1247*da0073e9SAndroid Build Coastguard Worker        next_parameters = OrderedDict([
1248*da0073e9SAndroid Build Coastguard Worker            ('p6', Parameter(torch.randn(10, 10))),
1249*da0073e9SAndroid Build Coastguard Worker            ('p5', Parameter(torch.randn(10, 10))),
1250*da0073e9SAndroid Build Coastguard Worker        ])
1251*da0073e9SAndroid Build Coastguard Worker        parameters.update(next_parameters)
1252*da0073e9SAndroid Build Coastguard Worker        parameter_dict.update(next_parameters)
1253*da0073e9SAndroid Build Coastguard Worker        check()
1254*da0073e9SAndroid Build Coastguard Worker
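        # When update() is given a plain dict, entries are inserted in sorted key order
        # (hence the sorted() call for the reference dict below).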
1255*da0073e9SAndroid Build Coastguard Worker        next_parameters = {
1256*da0073e9SAndroid Build Coastguard Worker            'p8': Parameter(torch.randn(10, 10)),
1257*da0073e9SAndroid Build Coastguard Worker            'p7': Parameter(torch.randn(10, 10))
1258*da0073e9SAndroid Build Coastguard Worker        }
1259*da0073e9SAndroid Build Coastguard Worker        parameters.update(sorted(next_parameters.items()))
1260*da0073e9SAndroid Build Coastguard Worker        parameter_dict.update(next_parameters)
1261*da0073e9SAndroid Build Coastguard Worker        check()
1262*da0073e9SAndroid Build Coastguard Worker
1263*da0073e9SAndroid Build Coastguard Worker        next_parameters = nn.ParameterDict([
1264*da0073e9SAndroid Build Coastguard Worker            ('p10', Parameter(torch.randn(10, 10))),
1265*da0073e9SAndroid Build Coastguard Worker            ('p9', Parameter(torch.randn(10, 10))),
1266*da0073e9SAndroid Build Coastguard Worker        ])
1267*da0073e9SAndroid Build Coastguard Worker        parameters.update(next_parameters)
1268*da0073e9SAndroid Build Coastguard Worker        parameter_dict.update(next_parameters)
1269*da0073e9SAndroid Build Coastguard Worker        check()
1270*da0073e9SAndroid Build Coastguard Worker
1271*da0073e9SAndroid Build Coastguard Worker        del parameter_dict['p3']
1272*da0073e9SAndroid Build Coastguard Worker        del parameters['p3']
1273*da0073e9SAndroid Build Coastguard Worker        check()
1274*da0073e9SAndroid Build Coastguard Worker
1275*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
1276*da0073e9SAndroid Build Coastguard Worker            parameter_dict.update(1)
1277*da0073e9SAndroid Build Coastguard Worker
1278*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
1279*da0073e9SAndroid Build Coastguard Worker            parameter_dict.update([1])
1280*da0073e9SAndroid Build Coastguard Worker
1281*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
1282*da0073e9SAndroid Build Coastguard Worker            parameter_dict.update(Parameter(torch.randn(10, 10)))
1283*da0073e9SAndroid Build Coastguard Worker
1284*da0073e9SAndroid Build Coastguard Worker        p_pop = parameter_dict.pop('p4')
1285*da0073e9SAndroid Build Coastguard Worker        self.assertIs(p_pop, parameters['p4'])
1286*da0073e9SAndroid Build Coastguard Worker        parameters.pop('p4')
1287*da0073e9SAndroid Build Coastguard Worker        check()
1288*da0073e9SAndroid Build Coastguard Worker
1289*da0073e9SAndroid Build Coastguard Worker        # Check that reversed() works
1290*da0073e9SAndroid Build Coastguard Worker        forward = list(iter(parameter_dict))
1291*da0073e9SAndroid Build Coastguard Worker        backward = list(reversed(parameter_dict))
1292*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(forward), len(backward))
1293*da0073e9SAndroid Build Coastguard Worker        n = len(forward)
1294*da0073e9SAndroid Build Coastguard Worker        for i in range(n):
1295*da0073e9SAndroid Build Coastguard Worker            self.assertIs(forward[i], backward[n - i - 1])
1296*da0073e9SAndroid Build Coastguard Worker        check()
1297*da0073e9SAndroid Build Coastguard Worker
1298*da0073e9SAndroid Build Coastguard Worker        # Check that copy() works
1299*da0073e9SAndroid Build Coastguard Worker        copy = parameter_dict.copy()
1300*da0073e9SAndroid Build Coastguard Worker
1301*da0073e9SAndroid Build Coastguard Worker        # Check that all keys are present and that values are shallow copies (same objects)
1302*da0073e9SAndroid Build Coastguard Worker        for key in parameter_dict:
1303*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(key in copy)
1304*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(parameter_dict[key], copy[key])
1305*da0073e9SAndroid Build Coastguard Worker            self.assertIs(parameter_dict[key], copy[key])
1306*da0073e9SAndroid Build Coastguard Worker        check()
1307*da0073e9SAndroid Build Coastguard Worker
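        # The copy is shallow but independent: inserting into one dict does not affect the other.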
1308*da0073e9SAndroid Build Coastguard Worker        parameter_dict["p20"] = Parameter(torch.randn(10, 10))
1309*da0073e9SAndroid Build Coastguard Worker        copy["p21"] = Parameter(torch.randn(9, 10))
1310*da0073e9SAndroid Build Coastguard Worker
1311*da0073e9SAndroid Build Coastguard Worker        self.assertTrue("p20" in parameter_dict)
1312*da0073e9SAndroid Build Coastguard Worker        self.assertFalse("p20" in copy)
1313*da0073e9SAndroid Build Coastguard Worker        self.assertFalse("p21" in parameter_dict)
1314*da0073e9SAndroid Build Coastguard Worker        self.assertTrue("p21" in copy)
1315*da0073e9SAndroid Build Coastguard Worker        parameter_dict.pop("p20")
1316*da0073e9SAndroid Build Coastguard Worker        check()
1317*da0073e9SAndroid Build Coastguard Worker
1318*da0073e9SAndroid Build Coastguard Worker        p = Parameter(torch.randn(10, 10))
1319*da0073e9SAndroid Build Coastguard Worker        parameter_dict['p12'] = p
1320*da0073e9SAndroid Build Coastguard Worker        p_popitem = parameter_dict.popitem()
1321*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(p_popitem[0], 'p12')
1322*da0073e9SAndroid Build Coastguard Worker        self.assertIs(p_popitem[1], p)
1323*da0073e9SAndroid Build Coastguard Worker        check()
1324*da0073e9SAndroid Build Coastguard Worker
1325*da0073e9SAndroid Build Coastguard Worker        # Unit test for setdefault
1326*da0073e9SAndroid Build Coastguard Worker        # 1. Ensure parameter is correctly inserted when
1327*da0073e9SAndroid Build Coastguard Worker        #    the key is not present in `ParameterDict`
1328*da0073e9SAndroid Build Coastguard Worker        assert 'p11' not in parameter_dict
1329*da0073e9SAndroid Build Coastguard Worker        assert 'p11' not in parameters
1330*da0073e9SAndroid Build Coastguard Worker        parameters['p11'] = Parameter(torch.randn(10, 10))
1331*da0073e9SAndroid Build Coastguard Worker        p_setdefault = parameter_dict.setdefault('p11', parameters['p11'])
1332*da0073e9SAndroid Build Coastguard Worker        self.assertIs(p_setdefault, parameters['p11'])
1333*da0073e9SAndroid Build Coastguard Worker        self.assertIs(p_setdefault, parameter_dict['p11'])
1334*da0073e9SAndroid Build Coastguard Worker        check()
1335*da0073e9SAndroid Build Coastguard Worker        # 2. Ensure parameter is NOT inserted when the
1336*da0073e9SAndroid Build Coastguard Worker        #    key is already present in `ParameterDict`
1337*da0073e9SAndroid Build Coastguard Worker        p = Parameter(torch.randn(10, 10))
1338*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(parameter_dict.setdefault('p11', p) is p)
1339*da0073e9SAndroid Build Coastguard Worker        check()
1340*da0073e9SAndroid Build Coastguard Worker        # 3. Ensure `None` is inserted when the key is not
1341*da0073e9SAndroid Build Coastguard Worker        #    present in `ParameterDict` and no default value is specified
1342*da0073e9SAndroid Build Coastguard Worker        self.assertIs(parameter_dict.setdefault('p26'), None)
1343*da0073e9SAndroid Build Coastguard Worker        del parameter_dict['p26']
1344*da0073e9SAndroid Build Coastguard Worker        check()
1345*da0073e9SAndroid Build Coastguard Worker
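        # ParameterDict supports dict-style union: |= merges in place; | (checked further below) returns a new ParameterDict.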
1346*da0073e9SAndroid Build Coastguard Worker        parameters2 = OrderedDict([
1347*da0073e9SAndroid Build Coastguard Worker            ('p13', Parameter(torch.randn(10, 10))),
1348*da0073e9SAndroid Build Coastguard Worker            ('p2', Parameter(torch.randn(10, 10))),
1349*da0073e9SAndroid Build Coastguard Worker            ('p3', Parameter(torch.randn(10, 10))),
1350*da0073e9SAndroid Build Coastguard Worker        ])
1351*da0073e9SAndroid Build Coastguard Worker        parameter_dict2 = nn.ParameterDict(parameters2)
1352*da0073e9SAndroid Build Coastguard Worker        parameters.update(parameters2)
1353*da0073e9SAndroid Build Coastguard Worker        parameter_dict |= parameter_dict2
1354*da0073e9SAndroid Build Coastguard Worker        check()
1355*da0073e9SAndroid Build Coastguard Worker
1356*da0073e9SAndroid Build Coastguard Worker        parameters2 = OrderedDict()
1357*da0073e9SAndroid Build Coastguard Worker        parameter_dict2 = nn.ParameterDict(parameters2)
1358*da0073e9SAndroid Build Coastguard Worker        parameters.update(parameters2)
1359*da0073e9SAndroid Build Coastguard Worker        parameter_dict |= parameter_dict2
1360*da0073e9SAndroid Build Coastguard Worker        check()
1361*da0073e9SAndroid Build Coastguard Worker
1362*da0073e9SAndroid Build Coastguard Worker        parameters2 = OrderedDict([
1363*da0073e9SAndroid Build Coastguard Worker            ('p14', Parameter(torch.randn(10, 10))),
1364*da0073e9SAndroid Build Coastguard Worker            ('p15', Parameter(torch.randn(10, 10))),
1365*da0073e9SAndroid Build Coastguard Worker            ('p13', Parameter(torch.randn(10, 10))),
1366*da0073e9SAndroid Build Coastguard Worker        ])
1367*da0073e9SAndroid Build Coastguard Worker        parameter_dict2 = nn.ParameterDict(parameters2)
1368*da0073e9SAndroid Build Coastguard Worker        parameters.update(parameters2)
1369*da0073e9SAndroid Build Coastguard Worker        parameter_dict |= parameter_dict2
1370*da0073e9SAndroid Build Coastguard Worker        check()
1371*da0073e9SAndroid Build Coastguard Worker
1372*da0073e9SAndroid Build Coastguard Worker        # Check that __or__ and __ror__ work
1373*da0073e9SAndroid Build Coastguard Worker        parameters2 = OrderedDict([
1374*da0073e9SAndroid Build Coastguard Worker            ('p20', Parameter(torch.randn(10, 10))),
1375*da0073e9SAndroid Build Coastguard Worker            ('p21', Parameter(torch.randn(10, 10))),
1376*da0073e9SAndroid Build Coastguard Worker            ('p22', Parameter(torch.randn(10, 10))),
1377*da0073e9SAndroid Build Coastguard Worker        ])
1378*da0073e9SAndroid Build Coastguard Worker        parameter_dict2 = nn.ParameterDict(parameters2)
1379*da0073e9SAndroid Build Coastguard Worker        parameters.update(parameters2)
1380*da0073e9SAndroid Build Coastguard Worker        parameter_dict = parameter_dict | parameter_dict2
1381*da0073e9SAndroid Build Coastguard Worker        check()
1382*da0073e9SAndroid Build Coastguard Worker
1383*da0073e9SAndroid Build Coastguard Worker        parameters2 = OrderedDict([
1384*da0073e9SAndroid Build Coastguard Worker            ('p23', Parameter(torch.randn(10, 10))),
1385*da0073e9SAndroid Build Coastguard Worker            ('p24', Parameter(torch.randn(10, 10))),
1386*da0073e9SAndroid Build Coastguard Worker            ('p25', Parameter(torch.randn(10, 10))),
1387*da0073e9SAndroid Build Coastguard Worker        ])
1388*da0073e9SAndroid Build Coastguard Worker        parameter_dict2 = nn.ParameterDict(parameters2)
1389*da0073e9SAndroid Build Coastguard Worker        parameters2.update(parameters)
1390*da0073e9SAndroid Build Coastguard Worker        parameters = parameters2
1391*da0073e9SAndroid Build Coastguard Worker        parameter_dict = parameter_dict2 | parameter_dict
1392*da0073e9SAndroid Build Coastguard Worker        check()
1393*da0073e9SAndroid Build Coastguard Worker
1394*da0073e9SAndroid Build Coastguard Worker        parameters['p17'] = Parameter(torch.randn(10, 10))
1395*da0073e9SAndroid Build Coastguard Worker        parameter_dict['p17'] = parameters['p17']
1396*da0073e9SAndroid Build Coastguard Worker        self.assertIs(parameters['p17'], parameter_dict.get('p17'))
1397*da0073e9SAndroid Build Coastguard Worker        temp_param = Parameter(torch.randn(10, 10))
1398*da0073e9SAndroid Build Coastguard Worker        self.assertIs(parameters['p17'], parameter_dict.get('p17', temp_param))
1399*da0073e9SAndroid Build Coastguard Worker        self.assertIs(None, parameter_dict.get('p18'))
1400*da0073e9SAndroid Build Coastguard Worker        self.assertIs(temp_param, parameter_dict.get('p18', temp_param))
1401*da0073e9SAndroid Build Coastguard Worker        check()
1402*da0073e9SAndroid Build Coastguard Worker
1403*da0073e9SAndroid Build Coastguard Worker        parameter_dict.clear()
1404*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(parameter_dict), 0)
1405*da0073e9SAndroid Build Coastguard Worker        parameters.clear()
1406*da0073e9SAndroid Build Coastguard Worker        check()
1407*da0073e9SAndroid Build Coastguard Worker
1408*da0073e9SAndroid Build Coastguard Worker        parameter_dict2 = parameter_dict.fromkeys(['p19', 'p20'])
1409*da0073e9SAndroid Build Coastguard Worker        self.assertEqual({'p19': None, 'p20': None}, parameter_dict2)
1410*da0073e9SAndroid Build Coastguard Worker        check()
1411*da0073e9SAndroid Build Coastguard Worker
1412*da0073e9SAndroid Build Coastguard Worker        parameter_dict2 = parameter_dict.fromkeys(['p19', 'p20'], temp_param)
1413*da0073e9SAndroid Build Coastguard Worker        self.assertEqual({'p19': temp_param, 'p20': temp_param}, parameter_dict2)
1414*da0073e9SAndroid Build Coastguard Worker        check()
1415*da0073e9SAndroid Build Coastguard Worker
1416*da0073e9SAndroid Build Coastguard Worker        parameter_dict['p21'] = torch.rand(2, 2)
1417*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(parameter_dict['p21'], Parameter)
1418*da0073e9SAndroid Build Coastguard Worker        parameters['p21'] = parameter_dict['p21']
1419*da0073e9SAndroid Build Coastguard Worker
1420*da0073e9SAndroid Build Coastguard Worker        parameter_dict.update({'p22': torch.rand(2, 2), 'foo': 'bar'})
1421*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(parameter_dict['p22'], Parameter)
1422*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(parameter_dict['foo'], str)
1423*da0073e9SAndroid Build Coastguard Worker        parameters['p22'] = parameter_dict['p22']
1424*da0073e9SAndroid Build Coastguard Worker        parameters['foo'] = parameter_dict['foo']
1425*da0073e9SAndroid Build Coastguard Worker
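    # Editor-added sketch (not part of the original suite): a minimal, self-contained
    # illustration of the dict-style merge semantics exercised above. The method name
    # and the tiny dicts are made up for illustration; the operators themselves are the
    # nn.ParameterDict `|` / `|=` operators tested in this file.
    def test_ParameterDict_merge_sketch(self):
        left = nn.ParameterDict({'a': Parameter(torch.zeros(1))})
        right = nn.ParameterDict({'a': Parameter(torch.ones(1)), 'b': Parameter(torch.ones(1))})
        # `|` builds a new ParameterDict; entries from the right-hand side win on key clashes
        merged = left | right
        self.assertIsNot(merged, left)
        self.assertIs(merged['a'], right['a'])
        # `|=` updates the left-hand dict in place
        left |= right
        self.assertEqual(set(left.keys()), {'a', 'b'})
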
1426*da0073e9SAndroid Build Coastguard Worker    def test_ParameterDict_replication(self):
1427*da0073e9SAndroid Build Coastguard Worker        # The actual replication code from DataParallel cannot be used on CPU, so we do it manually here
1428*da0073e9SAndroid Build Coastguard Worker        def make_param():
1429*da0073e9SAndroid Build Coastguard Worker            return Parameter(torch.randn(2, 2))
1430*da0073e9SAndroid Build Coastguard Worker        parameters = {"foo": make_param(), "bar": make_param()}
1431*da0073e9SAndroid Build Coastguard Worker        param_dict = nn.ParameterDict(parameters)
1432*da0073e9SAndroid Build Coastguard Worker
1433*da0073e9SAndroid Build Coastguard Worker        new_param_dict = param_dict._replicate_for_data_parallel()
1434*da0073e9SAndroid Build Coastguard Worker
1435*da0073e9SAndroid Build Coastguard Worker        for n, p in param_dict.named_parameters():
1436*da0073e9SAndroid Build Coastguard Worker            # Do a view here so that we can check the base later
1437*da0073e9SAndroid Build Coastguard Worker            setattr(new_param_dict, n, p.view_as(p))
1438*da0073e9SAndroid Build Coastguard Worker
1439*da0073e9SAndroid Build Coastguard Worker        for (k, p), (k2, p2) in zip(param_dict.items(), new_param_dict.items()):
1440*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(k, k2)
1441*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(p, p2)
1442*da0073e9SAndroid Build Coastguard Worker            self.assertIsNotNone(p2.grad_fn)
1443*da0073e9SAndroid Build Coastguard Worker            self.assertIs(p2._base, p)
1444*da0073e9SAndroid Build Coastguard Worker
1445*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(param_dict["foo"], new_param_dict["foo"])
1446*da0073e9SAndroid Build Coastguard Worker
1447*da0073e9SAndroid Build Coastguard Worker    def test_add_module(self):
1448*da0073e9SAndroid Build Coastguard Worker        methods_to_test = ['add_module', 'register_module']
1449*da0073e9SAndroid Build Coastguard Worker        for fn in methods_to_test:
1450*da0073e9SAndroid Build Coastguard Worker            l = nn.Linear(10, 20)
1451*da0073e9SAndroid Build Coastguard Worker            net = nn.Module()
1452*da0073e9SAndroid Build Coastguard Worker            net.l = l
1453*da0073e9SAndroid Build Coastguard Worker            net.l2 = l
1454*da0073e9SAndroid Build Coastguard Worker            getattr(net, fn)('empty', None)
1455*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(net.l, l)
1456*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(net.l2, l)
1457*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(net.empty, None)
1458*da0073e9SAndroid Build Coastguard Worker            getattr(net, fn)('l3', l)
1459*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(net.l3, l)
1460*da0073e9SAndroid Build Coastguard Worker            l3 = nn.Linear(20, 10)
1461*da0073e9SAndroid Build Coastguard Worker            getattr(net, fn)('l', l3)
1462*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(net.l, l3)
1463*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(TypeError, lambda: getattr(net, fn)('x', 'non-module'))
1464*da0073e9SAndroid Build Coastguard Worker            self.assertRaisesRegex(TypeError, 'module name should be a string. Got int',
1465*da0073e9SAndroid Build Coastguard Worker                                   lambda: getattr(net, fn)(1, l))
1466*da0073e9SAndroid Build Coastguard Worker            self.assertRaisesRegex(TypeError, 'module name should be a string. Got NoneType',
1467*da0073e9SAndroid Build Coastguard Worker                                   lambda: getattr(net, fn)(None, l))
1468*da0073e9SAndroid Build Coastguard Worker
1469*da0073e9SAndroid Build Coastguard Worker    def test_set_submodule(self):
1470*da0073e9SAndroid Build Coastguard Worker        net = nn.Module()
1471*da0073e9SAndroid Build Coastguard Worker        net.t = nn.Module()
1472*da0073e9SAndroid Build Coastguard Worker        l = nn.Linear(1, 2)
1473*da0073e9SAndroid Build Coastguard Worker        target = "t.l"
1474*da0073e9SAndroid Build Coastguard Worker        net.set_submodule(target, l)
1475*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(net.get_submodule(target), l)
1476*da0073e9SAndroid Build Coastguard Worker        l2 = nn.Linear(2, 1)
1477*da0073e9SAndroid Build Coastguard Worker        net.set_submodule(target, l2)
1478*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(net.get_submodule(target), l2)
1479*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, net.set_submodule, "", l)
1480*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(AttributeError, net.set_submodule, "a.l", l)
1481*da0073e9SAndroid Build Coastguard Worker
1482*da0073e9SAndroid Build Coastguard Worker    def test_module_to_argparse(self):
1483*da0073e9SAndroid Build Coastguard Worker        net = nn.Sequential(nn.Linear(3, 3))
1484*da0073e9SAndroid Build Coastguard Worker        cpu = torch.device('cpu')
1485*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
1486*da0073e9SAndroid Build Coastguard Worker            net.to(cpu, True)
1487*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
1488*da0073e9SAndroid Build Coastguard Worker            net.to(torch.long)
1489*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
1490*da0073e9SAndroid Build Coastguard Worker            net.to(None, True)
1491*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
1492*da0073e9SAndroid Build Coastguard Worker            net.to(cpu, torch.long, True)
1493*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
1494*da0073e9SAndroid Build Coastguard Worker            net.to(cpu, dtype=torch.long, non_blocking=True)
1495*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
1496*da0073e9SAndroid Build Coastguard Worker            net.to([])
1497*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
1498*da0073e9SAndroid Build Coastguard Worker            net.to({}, non_blocking=True)
1499*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
1500*da0073e9SAndroid Build Coastguard Worker            net.to(torch.tensor(3, dtype=torch.long), non_blocking=True)
1501*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
1502*da0073e9SAndroid Build Coastguard Worker            net.to(cpu, torch.tensor(3, dtype=torch.long), non_blocking=True)
1503*da0073e9SAndroid Build Coastguard Worker
1504*da0073e9SAndroid Build Coastguard Worker    def test_RNN_nonlinearity(self):
1505*da0073e9SAndroid Build Coastguard Worker        rnn = torch.nn.RNN(1, 10)
1506*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(rnn.nonlinearity, 'tanh')
1507*da0073e9SAndroid Build Coastguard Worker
1508*da0073e9SAndroid Build Coastguard Worker        rnn = torch.nn.RNN(1, 10, nonlinearity='relu')
1509*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(rnn.nonlinearity, 'relu')
1510*da0073e9SAndroid Build Coastguard Worker
1511*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, 'Unknown nonlinearity'):
1512*da0073e9SAndroid Build Coastguard Worker            rnn = torch.nn.RNN(1, 10, nonlinearity='garbage')
1513*da0073e9SAndroid Build Coastguard Worker
1514*da0073e9SAndroid Build Coastguard Worker    def test_RNN_nonlinearity_passed_as_arg(self):
1515*da0073e9SAndroid Build Coastguard Worker        rnn = torch.nn.RNN(2, 3, 1, 'relu')
1516*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(rnn.nonlinearity, 'relu')
1517*da0073e9SAndroid Build Coastguard Worker
1518*da0073e9SAndroid Build Coastguard Worker    def test_module_apply_inplace_op(self):
1519*da0073e9SAndroid Build Coastguard Worker        def add_one_inplace(t):
1520*da0073e9SAndroid Build Coastguard Worker            return t.add_(1.0)
1521*da0073e9SAndroid Build Coastguard Worker
1522*da0073e9SAndroid Build Coastguard Worker        # Test that applying an in-place operation to a module bumps
1523*da0073e9SAndroid Build Coastguard Worker        # the module's parameters' version counter (a small standalone sketch of this follows this test).
1524*da0073e9SAndroid Build Coastguard Worker        m = nn.Linear(20, 10)
1525*da0073e9SAndroid Build Coastguard Worker        pvm = m.weight.mul(m.weight)
1526*da0073e9SAndroid Build Coastguard Worker        m_weight_version_saved = m.weight._version
1527*da0073e9SAndroid Build Coastguard Worker        m = m._apply(add_one_inplace)
1528*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(m.weight._version, m_weight_version_saved)
1529*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
1530*da0073e9SAndroid Build Coastguard Worker            pvm.backward(torch.randn(10, 20))
1531*da0073e9SAndroid Build Coastguard Worker
1532*da0073e9SAndroid Build Coastguard Worker        # Test that applying an in-place operation to a module would bump
1533*da0073e9SAndroid Build Coastguard Worker        # the module's parameters' gradients' version counter.
1534*da0073e9SAndroid Build Coastguard Worker        m = nn.Linear(20, 10)
1535*da0073e9SAndroid Build Coastguard Worker        m.weight.grad = torch.randn(10, 20).requires_grad_()
1536*da0073e9SAndroid Build Coastguard Worker        pgm = m.weight.grad.mul(m.weight.grad)
1537*da0073e9SAndroid Build Coastguard Worker        m_weight_grad_version_saved = m.weight.grad._version
1538*da0073e9SAndroid Build Coastguard Worker        m = m._apply(add_one_inplace)
1539*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(m.weight.grad._version, m_weight_grad_version_saved)
1540*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
1541*da0073e9SAndroid Build Coastguard Worker            pgm.backward(torch.randn(10, 20))
1542*da0073e9SAndroid Build Coastguard Worker
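    # Editor-added sketch (not part of the original suite): the version-counter
    # mechanism the test above relies on, shown on a bare tensor instead of a module.
    # The method name is made up; the behaviour (in-place writes bump `_version` and
    # poison tensors saved for backward) is what the test above asserts.
    def test_version_counter_bump_sketch(self):
        w = torch.randn(3, requires_grad=True)
        out = w.mul(w)              # autograd saves `w` for the backward of mul
        saved_version = w._version
        with torch.no_grad():
            w.add_(1.0)             # in-place update bumps the version counter
        self.assertGreater(w._version, saved_version)
        with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
            out.backward(torch.ones_like(out))
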
1543*da0073e9SAndroid Build Coastguard Worker    def test_overwrite_module_params_on_conversion(self):
1544*da0073e9SAndroid Build Coastguard Worker        # Test that if the conversion function passed to `module._apply()`
1545*da0073e9SAndroid Build Coastguard Worker        # changes the TensorImpl type of `module`'s parameters, the `module`'s
1546*da0073e9SAndroid Build Coastguard Worker        # parameters are always overwritten, regardless of the value of
1547*da0073e9SAndroid Build Coastguard Worker        # `torch.__future__.get_overwrite_module_params_on_conversion()`.
1548*da0073e9SAndroid Build Coastguard Worker        m = nn.Linear(20, 10)
1549*da0073e9SAndroid Build Coastguard Worker        m.weight.grad = torch.randn(10, 20)
1550*da0073e9SAndroid Build Coastguard Worker        weight_ref = m.weight
1551*da0073e9SAndroid Build Coastguard Worker        weight_grad_ref = m.weight.grad
1552*da0073e9SAndroid Build Coastguard Worker        m = m._apply(lambda t: torch.sparse_coo_tensor(torch.zeros([2, 1]), torch.ones([1]), torch.Size([10, 20])))
1553*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(weight_ref.layout, m.weight.layout)
1554*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(weight_grad_ref.layout, m.weight.grad.layout)
1555*da0073e9SAndroid Build Coastguard Worker
1556*da0073e9SAndroid Build Coastguard Worker        # Test that under the current default settings
1557*da0073e9SAndroid Build Coastguard Worker        # (`torch.__future__.get_overwrite_module_params_on_conversion() == False`),
1558*da0073e9SAndroid Build Coastguard Worker        # a view of a module's parameters no longer points to the same storage as
1559*da0073e9SAndroid Build Coastguard Worker        # its base variable after converting the module to a different dtype.
1560*da0073e9SAndroid Build Coastguard Worker        m = nn.Linear(20, 10).float()
1561*da0073e9SAndroid Build Coastguard Worker        mw = m.weight[:]
1562*da0073e9SAndroid Build Coastguard Worker        m.double()
1563*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
1564*da0073e9SAndroid Build Coastguard Worker            mw[0][0] = 5
1565*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(mw[0][0].dtype == torch.float)
1566*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(mw._base[0][0].dtype == torch.double)
1567*da0073e9SAndroid Build Coastguard Worker
1568*da0073e9SAndroid Build Coastguard Worker        try:
1569*da0073e9SAndroid Build Coastguard Worker            torch.__future__.set_overwrite_module_params_on_conversion(True)
1570*da0073e9SAndroid Build Coastguard Worker
1571*da0073e9SAndroid Build Coastguard Worker            # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
1572*da0073e9SAndroid Build Coastguard Worker            # a view of a module's parameters still points to the same storage as
1573*da0073e9SAndroid Build Coastguard Worker            # its base variable after converting the module to a different dtype.
1574*da0073e9SAndroid Build Coastguard Worker            m = nn.Linear(20, 10).float()
1575*da0073e9SAndroid Build Coastguard Worker            mw = m.weight[:]
1576*da0073e9SAndroid Build Coastguard Worker            m.double()
1577*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
1578*da0073e9SAndroid Build Coastguard Worker                mw[0][0] = 5
1579*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(mw[0][0] == mw._base[0][0])
1580*da0073e9SAndroid Build Coastguard Worker
1581*da0073e9SAndroid Build Coastguard Worker            # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
1582*da0073e9SAndroid Build Coastguard Worker            # `float_module.double()` doesn't preserve previous references to
1583*da0073e9SAndroid Build Coastguard Worker            # `float_module`'s parameters or gradients.
1584*da0073e9SAndroid Build Coastguard Worker            m = nn.Linear(20, 10).float()
1585*da0073e9SAndroid Build Coastguard Worker            m.weight.grad = torch.randn(10, 20).float()
1586*da0073e9SAndroid Build Coastguard Worker            weight_ref = m.weight
1587*da0073e9SAndroid Build Coastguard Worker            weight_grad_ref = m.weight.grad
1588*da0073e9SAndroid Build Coastguard Worker            m.double()
1589*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(weight_ref.dtype, m.weight.dtype)
1590*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(weight_grad_ref.dtype, m.weight.grad.dtype)
1591*da0073e9SAndroid Build Coastguard Worker
1592*da0073e9SAndroid Build Coastguard Worker            def add_one_inplace(t):
1593*da0073e9SAndroid Build Coastguard Worker                return t.add_(1.0)
1594*da0073e9SAndroid Build Coastguard Worker
1595*da0073e9SAndroid Build Coastguard Worker            # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
1596*da0073e9SAndroid Build Coastguard Worker            # applying an in-place operation to a module would bump the module's
1597*da0073e9SAndroid Build Coastguard Worker            # original parameters' version counter.
1598*da0073e9SAndroid Build Coastguard Worker            m = nn.Linear(20, 10)
1599*da0073e9SAndroid Build Coastguard Worker            pvm = m.weight.mul(m.weight)
1600*da0073e9SAndroid Build Coastguard Worker            weight_ref = m.weight
1601*da0073e9SAndroid Build Coastguard Worker            m_weight_version_saved = weight_ref._version
1602*da0073e9SAndroid Build Coastguard Worker            m = m._apply(add_one_inplace)
1603*da0073e9SAndroid Build Coastguard Worker            # Test that the in-place operation bumps the original parameter's version counter
1604*da0073e9SAndroid Build Coastguard Worker            self.assertGreater(weight_ref._version, m_weight_version_saved)
1605*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
1606*da0073e9SAndroid Build Coastguard Worker                pvm.backward(torch.randn(10, 20))
1607*da0073e9SAndroid Build Coastguard Worker
1608*da0073e9SAndroid Build Coastguard Worker            # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
1609*da0073e9SAndroid Build Coastguard Worker            # applying an in-place operation to a module would bump the module's
1610*da0073e9SAndroid Build Coastguard Worker            # original parameters' gradients' version counter.
1611*da0073e9SAndroid Build Coastguard Worker            m = nn.Linear(20, 10)
1612*da0073e9SAndroid Build Coastguard Worker            m.weight.grad = torch.randn(10, 20).requires_grad_()
1613*da0073e9SAndroid Build Coastguard Worker            pgm = m.weight.grad.mul(m.weight.grad)
1614*da0073e9SAndroid Build Coastguard Worker            weight_grad_ref = m.weight.grad
1615*da0073e9SAndroid Build Coastguard Worker            m_weight_grad_version_saved = weight_grad_ref._version
1616*da0073e9SAndroid Build Coastguard Worker            m = m._apply(add_one_inplace)
1617*da0073e9SAndroid Build Coastguard Worker            self.assertGreater(weight_grad_ref._version, m_weight_grad_version_saved)
1618*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
1619*da0073e9SAndroid Build Coastguard Worker                pgm.backward(torch.randn(10, 20))
1620*da0073e9SAndroid Build Coastguard Worker
1621*da0073e9SAndroid Build Coastguard Worker            # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
1622*da0073e9SAndroid Build Coastguard Worker            # applying an out-of-place operation to a module doesn't bump
1623*da0073e9SAndroid Build Coastguard Worker            # the module's original parameters' version counter.
1624*da0073e9SAndroid Build Coastguard Worker            m = nn.Linear(20, 10)
1625*da0073e9SAndroid Build Coastguard Worker            weight_ref = m.weight
1626*da0073e9SAndroid Build Coastguard Worker            m_weight_version_saved = weight_ref._version
1627*da0073e9SAndroid Build Coastguard Worker            m = m._apply(lambda t: torch.randn(t.shape))
1628*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(weight_ref._version, m_weight_version_saved)
1629*da0073e9SAndroid Build Coastguard Worker
1630*da0073e9SAndroid Build Coastguard Worker            # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
1631*da0073e9SAndroid Build Coastguard Worker            # applying an out-of-place operation to a module doesn't bump
1632*da0073e9SAndroid Build Coastguard Worker            # the module's original parameters' gradients' version counter.
1633*da0073e9SAndroid Build Coastguard Worker            m = nn.Linear(20, 10)
1634*da0073e9SAndroid Build Coastguard Worker            m.weight.grad = torch.randn(10, 20).requires_grad_()
1635*da0073e9SAndroid Build Coastguard Worker            weight_grad_ref = m.weight.grad
1636*da0073e9SAndroid Build Coastguard Worker            m_weight_grad_version_saved = weight_grad_ref._version
1637*da0073e9SAndroid Build Coastguard Worker            m = m._apply(lambda t: torch.randn(t.shape))
1638*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(weight_grad_ref._version, m_weight_grad_version_saved)
1639*da0073e9SAndroid Build Coastguard Worker        finally:
1640*da0073e9SAndroid Build Coastguard Worker            torch.__future__.set_overwrite_module_params_on_conversion(False)
1641*da0073e9SAndroid Build Coastguard Worker
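    # Editor-added sketch (not part of the original suite): with the default future
    # flags (overwrite and swap both False), dtype conversion replaces each Parameter's
    # data in place, so the Parameter objects themselves survive the conversion and
    # references captured beforehand observe the new dtype. The method name is made up.
    def test_conversion_keeps_param_identity_sketch(self):
        m = nn.Linear(4, 2).float()
        weight_ref = m.weight
        m.double()
        self.assertIs(weight_ref, m.weight)
        self.assertEqual(weight_ref.dtype, torch.double)
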
1642*da0073e9SAndroid Build Coastguard Worker    def test_swap_module_params_poisons_acc_grad(self):
1643*da0073e9SAndroid Build Coastguard Worker        try:
1644*da0073e9SAndroid Build Coastguard Worker            torch.__future__.set_swap_module_params_on_conversion(True)
1645*da0073e9SAndroid Build Coastguard Worker            # (1) backward cannot be run after _apply
1646*da0073e9SAndroid Build Coastguard Worker            # forward will init AccumulateGrad nodes, which bumps use_count of parameters' at::Tensors
1647*da0073e9SAndroid Build Coastguard Worker            # additionally, if any Tensors are saved for backward, their use_count will be bumped
1648*da0073e9SAndroid Build Coastguard Worker            m = torch.nn.Linear(2, 3)
1649*da0073e9SAndroid Build Coastguard Worker            inp = torch.randn(2, 2)
1650*da0073e9SAndroid Build Coastguard Worker            out = m(inp)
1651*da0073e9SAndroid Build Coastguard Worker            m.half()
1652*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(all(p.dtype == torch.float16 for p in m.parameters()))
1653*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Trying to execute AccumulateGrad node that was poisoned by swap_tensors"):
1654*da0073e9SAndroid Build Coastguard Worker                out.sum().backward()
1655*da0073e9SAndroid Build Coastguard Worker            # (2) _apply can be run after backward()
1656*da0073e9SAndroid Build Coastguard Worker            # After running backward, all the references generated by "save for backward" will be cleared
1657*da0073e9SAndroid Build Coastguard Worker            # So the use_count will be 2 (1 from Tensor itself, and 1 from AccumulateGrad node), swap_tensors
1658*da0073e9SAndroid Build Coastguard Worker            # should allow this.
1659*da0073e9SAndroid Build Coastguard Worker            inp2 = torch.randn(2, 2, dtype=torch.half)
1660*da0073e9SAndroid Build Coastguard Worker            out2 = m(inp2)
1661*da0073e9SAndroid Build Coastguard Worker            out2.sum().backward()
1662*da0073e9SAndroid Build Coastguard Worker            m.float()
1663*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(all(p.dtype == torch.float32 for p in m.parameters()))
1664*da0073e9SAndroid Build Coastguard Worker            out3 = m(inp)
1665*da0073e9SAndroid Build Coastguard Worker        finally:
1666*da0073e9SAndroid Build Coastguard Worker            torch.__future__.set_swap_module_params_on_conversion(False)
1667*da0073e9SAndroid Build Coastguard Worker
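    # Editor-added sketch (not part of the original suite): the error message above
    # mentions swap_tensors; torch.utils.swap_tensors exchanges the payloads of two
    # tensors while keeping the Python objects (and therefore all outstanding
    # references) intact. The method name is made up.
    def test_swap_tensors_sketch(self):
        a = torch.ones(2)
        b = torch.zeros(3)
        a_ref = a
        torch.utils.swap_tensors(a, b)
        # the old reference now sees b's payload
        self.assertEqual(a_ref.shape, torch.Size([3]))
        self.assertEqual(a_ref.sum().item(), 0.0)
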
1668*da0073e9SAndroid Build Coastguard Worker    def test_type(self):
1669*da0073e9SAndroid Build Coastguard Worker        l = nn.Linear(10, 20)
1670*da0073e9SAndroid Build Coastguard Worker        net = nn.Module()
1671*da0073e9SAndroid Build Coastguard Worker        net.l = l
1672*da0073e9SAndroid Build Coastguard Worker        net.l2 = l
1673*da0073e9SAndroid Build Coastguard Worker        net.add_module('empty', None)
1674*da0073e9SAndroid Build Coastguard Worker        net.indices = Buffer(torch.LongTensor(1))
1675*da0073e9SAndroid Build Coastguard Worker        net.float()
1676*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(l.weight.data, torch.FloatTensor)
1677*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(l.bias.data, torch.FloatTensor)
1678*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(net.indices, torch.LongTensor)
1679*da0073e9SAndroid Build Coastguard Worker        net.double()
1680*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(l.weight.data, torch.DoubleTensor)
1681*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(l.bias.data, torch.DoubleTensor)
1682*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(net.indices, torch.LongTensor)
1683*da0073e9SAndroid Build Coastguard Worker        net.to(torch.half)
1684*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(l.weight.data, torch.HalfTensor)
1685*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(l.bias.data, torch.HalfTensor)
1686*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(net.indices, torch.LongTensor)
1687*da0073e9SAndroid Build Coastguard Worker        if TEST_CUDA:
1688*da0073e9SAndroid Build Coastguard Worker            net.float().cuda()
1689*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(l.weight.data, torch.cuda.FloatTensor)
1690*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(l.bias.data, torch.cuda.FloatTensor)
1691*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(net.indices, torch.cuda.LongTensor)
1692*da0073e9SAndroid Build Coastguard Worker            net.cpu()
1693*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(l.weight.data, torch.FloatTensor)
1694*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(l.bias.data, torch.FloatTensor)
1695*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(net.indices, torch.LongTensor)
1696*da0073e9SAndroid Build Coastguard Worker            net.to("cuda", torch.double, True)
1697*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(l.weight.data, torch.cuda.DoubleTensor)
1698*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(l.bias.data, torch.cuda.DoubleTensor)
1699*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(net.indices, torch.cuda.LongTensor)
1700*da0073e9SAndroid Build Coastguard Worker            net.to(torch.empty(1, device="cuda:0", dtype=torch.half))
1701*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(l.weight.data, torch.cuda.HalfTensor)
1702*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(l.bias.data, torch.cuda.HalfTensor)
1703*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(net.indices, torch.cuda.LongTensor)
1704*da0073e9SAndroid Build Coastguard Worker        net.to(torch.device("cpu"), non_blocking=True)
1705*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(l.weight.data, torch.HalfTensor)
1706*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(l.bias.data, torch.HalfTensor)
1707*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(net.indices, torch.LongTensor)
1708*da0073e9SAndroid Build Coastguard Worker        net.to(torch.float)
1709*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(l.weight.data, torch.FloatTensor)
1710*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(l.bias.data, torch.FloatTensor)
1711*da0073e9SAndroid Build Coastguard Worker        net.to(torch.DoubleTensor(1))
1712*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(l.weight.data, torch.DoubleTensor)
1713*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(l.bias.data, torch.DoubleTensor)
1714*da0073e9SAndroid Build Coastguard Worker        if TEST_CUDA:
1715*da0073e9SAndroid Build Coastguard Worker            net.to(device='cuda', dtype=torch.float)
1716*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(l.weight.data, torch.cuda.FloatTensor)
1717*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(l.bias.data, torch.cuda.FloatTensor)
1718*da0073e9SAndroid Build Coastguard Worker
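    # Editor-added sketch (not part of the original suite): the dtype conversions
    # exercised above only touch floating-point tensors, so an integer buffer such as
    # an index tensor keeps its dtype. The method name is made up.
    def test_float_conversion_skips_integer_buffers_sketch(self):
        net = nn.Module()
        net.weight = Parameter(torch.randn(2, 2))
        net.idx = Buffer(torch.zeros(2, dtype=torch.long))
        net.double()
        self.assertEqual(net.weight.dtype, torch.double)
        self.assertEqual(net.idx.dtype, torch.long)
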
1719*da0073e9SAndroid Build Coastguard Worker    def test_non_leaf_parameters(self):
1720*da0073e9SAndroid Build Coastguard Worker        l1 = nn.Linear(10, 10)
1721*da0073e9SAndroid Build Coastguard Worker        l2 = nn.Linear(10, 10)
1722*da0073e9SAndroid Build Coastguard Worker
1723*da0073e9SAndroid Build Coastguard Worker        def assign_weight():
1724*da0073e9SAndroid Build Coastguard Worker            l2.weight = l1.weight + 2
1725*da0073e9SAndroid Build Coastguard Worker
1726*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(TypeError, assign_weight)
1727*da0073e9SAndroid Build Coastguard Worker        # This should work though
1728*da0073e9SAndroid Build Coastguard Worker        l2.weight = Parameter(torch.randn(10, 10))
1729*da0073e9SAndroid Build Coastguard Worker
1730*da0073e9SAndroid Build Coastguard Worker    def test_parameters_to_vector(self):
1731*da0073e9SAndroid Build Coastguard Worker        conv1 = nn.Conv2d(3, 10, 5)
1732*da0073e9SAndroid Build Coastguard Worker        fc1 = nn.Linear(10, 20)
1733*da0073e9SAndroid Build Coastguard Worker        model = nn.Sequential(conv1, fc1)
1734*da0073e9SAndroid Build Coastguard Worker
1735*da0073e9SAndroid Build Coastguard Worker        vec = parameters_to_vector(model.parameters())
1736*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(vec.size(0), 980)
1737*da0073e9SAndroid Build Coastguard Worker
1738*da0073e9SAndroid Build Coastguard Worker    def test_vector_to_parameters(self):
1739*da0073e9SAndroid Build Coastguard Worker        conv1 = nn.Conv2d(3, 10, 5)
1740*da0073e9SAndroid Build Coastguard Worker        fc1 = nn.Linear(10, 20)
1741*da0073e9SAndroid Build Coastguard Worker        model = nn.Sequential(conv1, fc1)
1742*da0073e9SAndroid Build Coastguard Worker
1743*da0073e9SAndroid Build Coastguard Worker        vec = torch.arange(0., 980)
1744*da0073e9SAndroid Build Coastguard Worker        vector_to_parameters(vec, model.parameters())
1745*da0073e9SAndroid Build Coastguard Worker
1746*da0073e9SAndroid Build Coastguard Worker        sample = next(model.parameters())[0, 0, 0]
1747*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.equal(sample.data, vec.data[:5]))
1748*da0073e9SAndroid Build Coastguard Worker
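    # Editor-added sketch (not part of the original suite): parameters_to_vector and
    # vector_to_parameters are inverses, so flattening the parameters and writing the
    # vector back leaves the model unchanged. The method name is made up.
    def test_parameters_vector_roundtrip_sketch(self):
        model = nn.Sequential(nn.Conv2d(3, 10, 5), nn.Linear(10, 20))
        vec = parameters_to_vector(model.parameters())
        self.assertEqual(vec.numel(), sum(p.numel() for p in model.parameters()))
        vector_to_parameters(vec.clone(), model.parameters())
        self.assertEqual(parameters_to_vector(model.parameters()), vec)
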
1749*da0073e9SAndroid Build Coastguard Worker    def test_rnn_weight_norm(self):
1750*da0073e9SAndroid Build Coastguard Worker        def check_weight_norm(l, name, num_params):
1751*da0073e9SAndroid Build Coastguard Worker            # This Module has 4 or 5 parameters, named:
1752*da0073e9SAndroid Build Coastguard Worker            # 'weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0', and (with proj_size) 'weight_hr_l0'
1753*da0073e9SAndroid Build Coastguard Worker
1754*da0073e9SAndroid Build Coastguard Worker            # Applying weight norm to one of them causes it to become a plain tensor (no longer a Parameter)
1755*da0073e9SAndroid Build Coastguard Worker            l = torch.nn.utils.weight_norm(l, name=name)
1756*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1757*da0073e9SAndroid Build Coastguard Worker                sum(isinstance(p, torch.nn.Parameter) for p in l._flat_weights),
1758*da0073e9SAndroid Build Coastguard Worker                num_params - 1,
1759*da0073e9SAndroid Build Coastguard Worker            )
1760*da0073e9SAndroid Build Coastguard Worker
1761*da0073e9SAndroid Build Coastguard Worker            # Removing the weight norm reparametrization restores the Parameter
1762*da0073e9SAndroid Build Coastguard Worker            l = torch.nn.utils.remove_weight_norm(l, name=name)
1763*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1764*da0073e9SAndroid Build Coastguard Worker                sum(isinstance(p, torch.nn.Parameter) for p in l._flat_weights),
1765*da0073e9SAndroid Build Coastguard Worker                num_params,
1766*da0073e9SAndroid Build Coastguard Worker            )
1767*da0073e9SAndroid Build Coastguard Worker
1768*da0073e9SAndroid Build Coastguard Worker            # Make sure that, upon removal of the reparametrization, the
1769*da0073e9SAndroid Build Coastguard Worker            # `._parameters` and `.named_parameters` contain the right params.
1770*da0073e9SAndroid Build Coastguard Worker            # Specifically, the original weight ('weight_ih_l0') should be placed
1771*da0073e9SAndroid Build Coastguard Worker            # back in the parameters, while the reparametrization components
1772*da0073e9SAndroid Build Coastguard Worker            # ('weight_ih_l0_v' and 'weight_ih_l0_g') should be removed.
1773*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(name in l._parameters)
1774*da0073e9SAndroid Build Coastguard Worker            self.assertIsNotNone(l._parameters[name])
1775*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(name + '_v' not in l._parameters)
1776*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(name + '_g' not in l._parameters)
1777*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(name in dict(l.named_parameters()))
1778*da0073e9SAndroid Build Coastguard Worker            self.assertIsNotNone(dict(l.named_parameters())[name])
1779*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(name + '_v' not in dict(l.named_parameters()))
1780*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(name + '_g' not in dict(l.named_parameters()))
1781*da0073e9SAndroid Build Coastguard Worker
1782*da0073e9SAndroid Build Coastguard Worker        check_weight_norm(torch.nn.LSTM(32, 32), 'weight_ih_l0', 4)
1783*da0073e9SAndroid Build Coastguard Worker        check_weight_norm(torch.nn.LSTM(32, 32, proj_size=16), 'weight_hr_l0', 5)
1784*da0073e9SAndroid Build Coastguard Worker
1785*da0073e9SAndroid Build Coastguard Worker
1786*da0073e9SAndroid Build Coastguard Worker    def test_weight_norm(self):
1787*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.float, torch.bfloat16]:
1788*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(3, 4, dtype=dtype)
1789*da0073e9SAndroid Build Coastguard Worker            m = nn.Linear(4, 5).to(dtype=dtype)
1790*da0073e9SAndroid Build Coastguard Worker            expected_output = m(input)
1791*da0073e9SAndroid Build Coastguard Worker
1792*da0073e9SAndroid Build Coastguard Worker            # add weight normalization
1793*da0073e9SAndroid Build Coastguard Worker            m = torch.nn.utils.weight_norm(m)
1794*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m.weight_v.size(), m.weight.size())
1795*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m.weight_g.size(), (5, 1))
1796*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m(input), expected_output, atol=dtype2prec_DONTUSE[dtype], rtol=0)
1797*da0073e9SAndroid Build Coastguard Worker
1798*da0073e9SAndroid Build Coastguard Worker            # remove weight norm
1799*da0073e9SAndroid Build Coastguard Worker            m = torch.nn.utils.remove_weight_norm(m)
1800*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(hasattr(m, 'weight_g'))
1801*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(hasattr(m, 'weight_v'))
1802*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m(input), expected_output, atol=dtype2prec_DONTUSE[dtype], rtol=0)
1803*da0073e9SAndroid Build Coastguard Worker
1804*da0073e9SAndroid Build Coastguard Worker            # test with dim=1
1805*da0073e9SAndroid Build Coastguard Worker            m = torch.nn.utils.weight_norm(m, dim=1)
1806*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m.weight_v.size(), m.weight.size())
1807*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m.weight_g.size(), (1, 4))
1808*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m(input), expected_output, atol=dtype2prec_DONTUSE[dtype], rtol=0)
1809*da0073e9SAndroid Build Coastguard Worker
1810*da0073e9SAndroid Build Coastguard Worker            # test with dim=None
1811*da0073e9SAndroid Build Coastguard Worker            m = nn.Linear(4, 5).to(dtype=dtype)
1812*da0073e9SAndroid Build Coastguard Worker            expected_output = m(input)
1813*da0073e9SAndroid Build Coastguard Worker            m = torch.nn.utils.weight_norm(m, dim=None)
1814*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m(input), expected_output)
1815*da0073e9SAndroid Build Coastguard Worker
1816*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, 'register two weight_norm hooks'):
1817*da0073e9SAndroid Build Coastguard Worker                m = torch.nn.utils.weight_norm(m)
1818*da0073e9SAndroid Build Coastguard Worker                m = torch.nn.utils.weight_norm(m)
1819*da0073e9SAndroid Build Coastguard Worker
1820*da0073e9SAndroid Build Coastguard Worker        # For float16, the forward of the Module doesn't work, but we must still be able
1821*da0073e9SAndroid Build Coastguard Worker        # to register weight norm, as this is often done before sending the Module to
1822*da0073e9SAndroid Build Coastguard Worker        # CUDA.
1823*da0073e9SAndroid Build Coastguard Worker        m = nn.Linear(4, 5, dtype=torch.float16)
1824*da0073e9SAndroid Build Coastguard Worker        m = torch.nn.utils.weight_norm(m)
1825*da0073e9SAndroid Build Coastguard Worker
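    # Editor-added sketch (not part of the original suite): the reparametrization that
    # weight_norm installs recomputes the weight as g * v / ||v||, with the norm taken
    # over every dim except `dim` (the default dim=0 gives one norm per output row).
    # The method name is made up.
    def test_weight_norm_decomposition_sketch(self):
        m = torch.nn.utils.weight_norm(nn.Linear(4, 5))
        m(torch.randn(3, 4))  # the forward pre-hook recomputes `weight` from weight_g / weight_v
        row_norms = torch.linalg.vector_norm(m.weight_v, dim=1, keepdim=True)
        self.assertEqual(m.weight, m.weight_g * m.weight_v / row_norms)
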
1826*da0073e9SAndroid Build Coastguard Worker    def test_parameterlistdict_setting_attributes(self):
1827*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
1828*da0073e9SAndroid Build Coastguard Worker            mod = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)]))
1829*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(w) == 0)
1830*da0073e9SAndroid Build Coastguard Worker
1831*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
1832*da0073e9SAndroid Build Coastguard Worker            mod.train()
1833*da0073e9SAndroid Build Coastguard Worker            mod.eval()
1834*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(w) == 0)
1835*da0073e9SAndroid Build Coastguard Worker
1836*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
1837*da0073e9SAndroid Build Coastguard Worker            mod = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))})
1838*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(w) == 0)
1839*da0073e9SAndroid Build Coastguard Worker
1840*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
1841*da0073e9SAndroid Build Coastguard Worker            mod.train()
1842*da0073e9SAndroid Build Coastguard Worker            mod.eval()
1843*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(w) == 0)
1844*da0073e9SAndroid Build Coastguard Worker
1845*da0073e9SAndroid Build Coastguard Worker    def test_parameterlistdict_pickle(self):
1846*da0073e9SAndroid Build Coastguard Worker        m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)]))
1847*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
1848*da0073e9SAndroid Build Coastguard Worker            m = pickle.loads(pickle.dumps(m))
1849*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(w) == 0)
1850*da0073e9SAndroid Build Coastguard Worker
1851*da0073e9SAndroid Build Coastguard Worker        # Test whether loading from older checkpoints works without triggering warnings
1852*da0073e9SAndroid Build Coastguard Worker        m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)]))
1853*da0073e9SAndroid Build Coastguard Worker        del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set
1854*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
1855*da0073e9SAndroid Build Coastguard Worker            m = pickle.loads(pickle.dumps(m))
1856*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(w) == 0)
1857*da0073e9SAndroid Build Coastguard Worker
1858*da0073e9SAndroid Build Coastguard Worker        m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))})
1859*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
1860*da0073e9SAndroid Build Coastguard Worker            m = pickle.loads(pickle.dumps(m))
1861*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(w) == 0)
1862*da0073e9SAndroid Build Coastguard Worker
1863*da0073e9SAndroid Build Coastguard Worker        # Test whether loading from older checkpoints works without triggering warnings
1864*da0073e9SAndroid Build Coastguard Worker        m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))})
1865*da0073e9SAndroid Build Coastguard Worker        del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set
1866*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
1867*da0073e9SAndroid Build Coastguard Worker            m = pickle.loads(pickle.dumps(m))
1868*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(w) == 0)
1869*da0073e9SAndroid Build Coastguard Worker
1870*da0073e9SAndroid Build Coastguard Worker    def test_weight_norm_pickle(self):
1871*da0073e9SAndroid Build Coastguard Worker        m = torch.nn.utils.weight_norm(nn.Linear(5, 7))
1872*da0073e9SAndroid Build Coastguard Worker        m = pickle.loads(pickle.dumps(m))
1873*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(m, nn.Linear)
1874*da0073e9SAndroid Build Coastguard Worker
1875*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")
1876*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
1877*da0073e9SAndroid Build Coastguard Worker    def test_spectral_norm(self):
1878*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3, 5)
1879*da0073e9SAndroid Build Coastguard Worker        m = nn.Linear(5, 7)
1880*da0073e9SAndroid Build Coastguard Worker        m = torch.nn.utils.spectral_norm(m)
1881*da0073e9SAndroid Build Coastguard Worker
1882*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m.weight_u.size(), torch.Size([m.weight.size(0)]))
1883*da0073e9SAndroid Build Coastguard Worker        # weight_orig should be trainable
1884*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(hasattr(m, 'weight_orig'))
1885*da0073e9SAndroid Build Coastguard Worker        self.assertTrue('weight_orig' in m._parameters)
1886*da0073e9SAndroid Build Coastguard Worker        # weight_u should be just a reused buffer
1887*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(hasattr(m, 'weight_u'))
1888*da0073e9SAndroid Build Coastguard Worker        self.assertTrue('weight_u' in m._buffers)
1889*da0073e9SAndroid Build Coastguard Worker        self.assertTrue('weight_v' in m._buffers)
1890*da0073e9SAndroid Build Coastguard Worker        # weight should be a plain attribute, not counted as a buffer or a param
1891*da0073e9SAndroid Build Coastguard Worker        self.assertFalse('weight' in m._buffers)
1892*da0073e9SAndroid Build Coastguard Worker        self.assertFalse('weight' in m._parameters)
1893*da0073e9SAndroid Build Coastguard Worker        # it should also share storage with `weight_orig`
1894*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m.weight_orig.storage(), m.weight.storage())
1895*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m.weight_orig.size(), m.weight.size())
1896*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m.weight_orig.stride(), m.weight.stride())
1897*da0073e9SAndroid Build Coastguard Worker
1898*da0073e9SAndroid Build Coastguard Worker        m = torch.nn.utils.remove_spectral_norm(m)
1899*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(hasattr(m, 'weight_orig'))
1900*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(hasattr(m, 'weight_u'))
1901*da0073e9SAndroid Build Coastguard Worker        # weight should be converted back as a parameter
1902*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(hasattr(m, 'weight'))
1903*da0073e9SAndroid Build Coastguard Worker        self.assertTrue('weight' in m._parameters)
1904*da0073e9SAndroid Build Coastguard Worker
1905*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'register two spectral_norm hooks'):
1906*da0073e9SAndroid Build Coastguard Worker            m = torch.nn.utils.spectral_norm(m)
1907*da0073e9SAndroid Build Coastguard Worker            m = torch.nn.utils.spectral_norm(m)
1908*da0073e9SAndroid Build Coastguard Worker
1909*da0073e9SAndroid Build Coastguard Worker        # test correctness in training/eval modes and cpu/multi-gpu settings
1910*da0073e9SAndroid Build Coastguard Worker        for apply_dp in (True, False):
1911*da0073e9SAndroid Build Coastguard Worker            if apply_dp:
1912*da0073e9SAndroid Build Coastguard Worker                if not TEST_MULTIGPU:
1913*da0073e9SAndroid Build Coastguard Worker                    continue
1914*da0073e9SAndroid Build Coastguard Worker                device = torch.device('cuda:0')
1915*da0073e9SAndroid Build Coastguard Worker
1916*da0073e9SAndroid Build Coastguard Worker                def maybe_wrap(m):
1917*da0073e9SAndroid Build Coastguard Worker                    return torch.nn.DataParallel(m, [0, 1])
1918*da0073e9SAndroid Build Coastguard Worker            else:
1919*da0073e9SAndroid Build Coastguard Worker                device = torch.device('cpu')
1920*da0073e9SAndroid Build Coastguard Worker
1921*da0073e9SAndroid Build Coastguard Worker                def maybe_wrap(m):
1922*da0073e9SAndroid Build Coastguard Worker                    return m
1923*da0073e9SAndroid Build Coastguard Worker
1924*da0073e9SAndroid Build Coastguard Worker            for requires_grad in (True, False):
1925*da0073e9SAndroid Build Coastguard Worker                m = nn.Linear(3, 4).to(device)
1926*da0073e9SAndroid Build Coastguard Worker                m.weight.requires_grad_(requires_grad)
1927*da0073e9SAndroid Build Coastguard Worker                m = torch.nn.utils.spectral_norm(m)
1928*da0073e9SAndroid Build Coastguard Worker                wrapped_m = maybe_wrap(m)
1929*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(hasattr(m, 'weight_u'))
1930*da0073e9SAndroid Build Coastguard Worker                u0 = m.weight_u.clone()
1931*da0073e9SAndroid Build Coastguard Worker                v0 = m.weight_v.clone()
1932*da0073e9SAndroid Build Coastguard Worker
1933*da0073e9SAndroid Build Coastguard Worker                # TEST TRAINING BEHAVIOR
1934*da0073e9SAndroid Build Coastguard Worker
1935*da0073e9SAndroid Build Coastguard Worker                # assert that u and v are updated
1936*da0073e9SAndroid Build Coastguard Worker                input = torch.randn(2, 3, device=device)
1937*da0073e9SAndroid Build Coastguard Worker                out = wrapped_m(input)
1938*da0073e9SAndroid Build Coastguard Worker                self.assertNotEqual(u0, m.weight_u)
1939*da0073e9SAndroid Build Coastguard Worker                self.assertNotEqual(v0, m.weight_v)
1940*da0073e9SAndroid Build Coastguard Worker
1941*da0073e9SAndroid Build Coastguard Worker                # assert that backprop reaches weight_orig
1942*da0073e9SAndroid Build Coastguard Worker                # can't use gradcheck because the function changes as we
1943*da0073e9SAndroid Build Coastguard Worker                # forward through it in training mode (u and v are updated each pass)
1944*da0073e9SAndroid Build Coastguard Worker                if requires_grad:
1945*da0073e9SAndroid Build Coastguard Worker                    torch.autograd.grad(out.sum(), m.weight_orig)
1946*da0073e9SAndroid Build Coastguard Worker
1947*da0073e9SAndroid Build Coastguard Worker                # test backward works with multiple forwards
1948*da0073e9SAndroid Build Coastguard Worker                # it uses training mode, so we need to reset the `u` and `v` vectors
1949*da0073e9SAndroid Build Coastguard Worker                # to the same values at the beginning for the finite-difference test to pass
1950*da0073e9SAndroid Build Coastguard Worker                saved_u = m.weight_u.clone()
1951*da0073e9SAndroid Build Coastguard Worker                saved_v = m.weight_v.clone()
1952*da0073e9SAndroid Build Coastguard Worker
1953*da0073e9SAndroid Build Coastguard Worker                def fn(input):
1954*da0073e9SAndroid Build Coastguard Worker                    m.weight_u.data.copy_(saved_u)
1955*da0073e9SAndroid Build Coastguard Worker                    m.weight_v.data.copy_(saved_v)
1956*da0073e9SAndroid Build Coastguard Worker                    out0 = wrapped_m(input)
1957*da0073e9SAndroid Build Coastguard Worker                    out1 = wrapped_m(input)
1958*da0073e9SAndroid Build Coastguard Worker                    return out0 + out1
1959*da0073e9SAndroid Build Coastguard Worker
1960*da0073e9SAndroid Build Coastguard Worker                gradcheck(fn, (input.clone().requires_grad_(),), check_batched_grad=False)
1961*da0073e9SAndroid Build Coastguard Worker
1962*da0073e9SAndroid Build Coastguard Worker                # test removing
1963*da0073e9SAndroid Build Coastguard Worker                pre_remove_out = wrapped_m(input)
1964*da0073e9SAndroid Build Coastguard Worker                m = torch.nn.utils.remove_spectral_norm(m)
1965*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(wrapped_m(input), pre_remove_out)
1966*da0073e9SAndroid Build Coastguard Worker
1967*da0073e9SAndroid Build Coastguard Worker                m = torch.nn.utils.spectral_norm(m)
1968*da0073e9SAndroid Build Coastguard Worker                for _ in range(3):
1969*da0073e9SAndroid Build Coastguard Worker                    pre_remove_out = wrapped_m(input)
1970*da0073e9SAndroid Build Coastguard Worker                m = torch.nn.utils.remove_spectral_norm(m)
1971*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(wrapped_m(input), pre_remove_out)
1972*da0073e9SAndroid Build Coastguard Worker
1973*da0073e9SAndroid Build Coastguard Worker                # TEST EVAL BEHAVIOR
1974*da0073e9SAndroid Build Coastguard Worker
1975*da0073e9SAndroid Build Coastguard Worker                m = torch.nn.utils.spectral_norm(m)
1976*da0073e9SAndroid Build Coastguard Worker                wrapped_m(input)
1977*da0073e9SAndroid Build Coastguard Worker                last_train_out = wrapped_m(input)
1978*da0073e9SAndroid Build Coastguard Worker                last_train_u = m.weight_u.clone()
1979*da0073e9SAndroid Build Coastguard Worker                last_train_v = m.weight_v.clone()
1980*da0073e9SAndroid Build Coastguard Worker                wrapped_m.zero_grad()
1981*da0073e9SAndroid Build Coastguard Worker                wrapped_m.eval()
1982*da0073e9SAndroid Build Coastguard Worker
1983*da0073e9SAndroid Build Coastguard Worker                eval_out0 = wrapped_m(input)
1984*da0073e9SAndroid Build Coastguard Worker                # assert eval gives same result as last training iteration
1985*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(eval_out0, last_train_out)
1986*da0073e9SAndroid Build Coastguard Worker                # assert that doing more iterations in eval doesn't change things
1987*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(eval_out0, wrapped_m(input))
1988*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(last_train_u, m.weight_u)
1989*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(last_train_v, m.weight_v)
1990*da0073e9SAndroid Build Coastguard Worker
1991*da0073e9SAndroid Build Coastguard Worker                # FIXME: the code below is flaky when executed with DataParallel
1992*da0073e9SAndroid Build Coastguard Worker                # see https://github.com/pytorch/pytorch/issues/13818
1993*da0073e9SAndroid Build Coastguard Worker                if apply_dp:
1994*da0073e9SAndroid Build Coastguard Worker                    continue
1995*da0073e9SAndroid Build Coastguard Worker
1996*da0073e9SAndroid Build Coastguard Worker                # test backward works with multiple forwards in mixed training
1997*da0073e9SAndroid Build Coastguard Worker                # and eval modes
1998*da0073e9SAndroid Build Coastguard Worker                # it uses training mode, so we need to reset the `u` and `v` vectors
1999*da0073e9SAndroid Build Coastguard Worker                # to the same values at the beginning for the finite-difference test to pass
2000*da0073e9SAndroid Build Coastguard Worker                saved_u = m.weight_u.clone()
2001*da0073e9SAndroid Build Coastguard Worker                saved_v = m.weight_v.clone()
2002*da0073e9SAndroid Build Coastguard Worker
2003*da0073e9SAndroid Build Coastguard Worker                def fn(input):
2004*da0073e9SAndroid Build Coastguard Worker                    m.weight_u.data.copy_(saved_u)
2005*da0073e9SAndroid Build Coastguard Worker                    m.weight_v.data.copy_(saved_v)
2006*da0073e9SAndroid Build Coastguard Worker                    wrapped_m.train()
2007*da0073e9SAndroid Build Coastguard Worker                    out0 = wrapped_m(input)
2008*da0073e9SAndroid Build Coastguard Worker                    wrapped_m.eval()
2009*da0073e9SAndroid Build Coastguard Worker                    out1 = wrapped_m(input)
2010*da0073e9SAndroid Build Coastguard Worker                    wrapped_m.train()
2011*da0073e9SAndroid Build Coastguard Worker                    out2 = wrapped_m(input)
2012*da0073e9SAndroid Build Coastguard Worker                    wrapped_m.eval()
2013*da0073e9SAndroid Build Coastguard Worker                    out3 = wrapped_m(input)
2014*da0073e9SAndroid Build Coastguard Worker                    return out0 + out1 + out2 + out3
2015*da0073e9SAndroid Build Coastguard Worker
2016*da0073e9SAndroid Build Coastguard Worker                gradcheck(fn, (input.clone().requires_grad_(),))
2017*da0073e9SAndroid Build Coastguard Worker
2018*da0073e9SAndroid Build Coastguard Worker                # assert that backprop reaches weight_orig in eval
2019*da0073e9SAndroid Build Coastguard Worker                if requires_grad:
2020*da0073e9SAndroid Build Coastguard Worker                    def fn(weight):
2021*da0073e9SAndroid Build Coastguard Worker                        return wrapped_m(input)
2022*da0073e9SAndroid Build Coastguard Worker
2023*da0073e9SAndroid Build Coastguard Worker                    gradcheck(fn, (m.weight_orig,))
2024*da0073e9SAndroid Build Coastguard Worker
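    # Editor-added sketch (not part of the original suite): after a forward pass in
    # training mode, `weight` is `weight_orig` divided by the sigma estimated from the
    # power-iteration buffers u and v, so recomputing that estimate reproduces the
    # effective weight. The method name is made up.
    def test_spectral_norm_sigma_sketch(self):
        m = torch.nn.utils.spectral_norm(nn.Linear(5, 7))
        m(torch.randn(2, 5))  # training-mode forward updates u, v and recomputes `weight`
        sigma = torch.dot(m.weight_u, torch.mv(m.weight_orig, m.weight_v))
        self.assertEqual(m.weight, m.weight_orig / sigma)
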
2025*da0073e9SAndroid Build Coastguard Worker    @skipIfNoLapack
2026*da0073e9SAndroid Build Coastguard Worker    def test_spectral_norm_load_state_dict(self):
2027*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn(2, 3)
2028*da0073e9SAndroid Build Coastguard Worker        for activate_times in (0, 3):
2029*da0073e9SAndroid Build Coastguard Worker            # Test backward compatibility
2030*da0073e9SAndroid Build Coastguard Worker            # At version None -> 1: weight is no longer a buffer and the v vector becomes a buffer
2031*da0073e9SAndroid Build Coastguard Worker            m = nn.Linear(3, 5)
2032*da0073e9SAndroid Build Coastguard Worker            snm = torch.nn.utils.spectral_norm(m)
2033*da0073e9SAndroid Build Coastguard Worker            snm.train()
2034*da0073e9SAndroid Build Coastguard Worker            for _ in range(activate_times):
2035*da0073e9SAndroid Build Coastguard Worker                snm(inp)
2036*da0073e9SAndroid Build Coastguard Worker
2037*da0073e9SAndroid Build Coastguard Worker            version_latest_ref_state_dict = deepcopy(snm.state_dict())
2038*da0073e9SAndroid Build Coastguard Worker            self.assertEqual({'weight_orig', 'bias', 'weight_u', 'weight_v'}, set(version_latest_ref_state_dict.keys()))
2039*da0073e9SAndroid Build Coastguard Worker
2040*da0073e9SAndroid Build Coastguard Worker            # test that strict loading rejects unexpected keys and that non-strict loading works
2041*da0073e9SAndroid Build Coastguard Worker            non_strict_state_dict = deepcopy(version_latest_ref_state_dict)
2042*da0073e9SAndroid Build Coastguard Worker            non_strict_state_dict['nonsense'] = 'nonsense'
2043*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, r'Unexpected key\(s\) in state_dict: "nonsense"'):
2044*da0073e9SAndroid Build Coastguard Worker                snm.load_state_dict(non_strict_state_dict, strict=True)
2045*da0073e9SAndroid Build Coastguard Worker            snm.load_state_dict(non_strict_state_dict, strict=False)
2046*da0073e9SAndroid Build Coastguard Worker            del non_strict_state_dict['weight_orig']
2047*da0073e9SAndroid Build Coastguard Worker            snm.load_state_dict(non_strict_state_dict, strict=False)
2048*da0073e9SAndroid Build Coastguard Worker            del non_strict_state_dict['weight_u']
2049*da0073e9SAndroid Build Coastguard Worker            snm.load_state_dict(non_strict_state_dict, strict=False)
2050*da0073e9SAndroid Build Coastguard Worker            del non_strict_state_dict['weight_v']
2051*da0073e9SAndroid Build Coastguard Worker            snm.load_state_dict(non_strict_state_dict, strict=False)
2052*da0073e9SAndroid Build Coastguard Worker            non_strict_state_dict['weight'] = snm.weight.detach().clone()  # set W as a buffer
2053*da0073e9SAndroid Build Coastguard Worker            snm.load_state_dict(non_strict_state_dict, strict=False)
2054*da0073e9SAndroid Build Coastguard Worker            del non_strict_state_dict._metadata['']['spectral_norm']       # remove metadata info
2055*da0073e9SAndroid Build Coastguard Worker            snm.load_state_dict(non_strict_state_dict, strict=False)
2056*da0073e9SAndroid Build Coastguard Worker            del non_strict_state_dict['weight']                            # remove W buffer
2057*da0073e9SAndroid Build Coastguard Worker            snm.load_state_dict(non_strict_state_dict, strict=False)
2058*da0073e9SAndroid Build Coastguard Worker            del non_strict_state_dict['bias']
2059*da0073e9SAndroid Build Coastguard Worker            snm.load_state_dict(non_strict_state_dict, strict=False)
2060*da0073e9SAndroid Build Coastguard Worker
2061*da0073e9SAndroid Build Coastguard Worker            # craft a version None state_dict
2062*da0073e9SAndroid Build Coastguard Worker            version_none_state_dict = deepcopy(version_latest_ref_state_dict)
2063*da0073e9SAndroid Build Coastguard Worker            self.assertIn('spectral_norm', version_none_state_dict._metadata[''])
2064*da0073e9SAndroid Build Coastguard Worker            del version_none_state_dict._metadata['']['spectral_norm']       # remove metadata info
2065*da0073e9SAndroid Build Coastguard Worker            del version_none_state_dict['weight_v']                          # remove v vector
2066*da0073e9SAndroid Build Coastguard Worker            version_none_state_dict['weight'] = snm.weight.detach().clone()  # set W as a buffer
2067*da0073e9SAndroid Build Coastguard Worker
2068*da0073e9SAndroid Build Coastguard Worker            # normal state_dict
2069*da0073e9SAndroid Build Coastguard Worker            for version_latest_with_metadata in [True, False]:
2070*da0073e9SAndroid Build Coastguard Worker                version_latest_state_dict = deepcopy(version_latest_ref_state_dict)
2071*da0073e9SAndroid Build Coastguard Worker
2072*da0073e9SAndroid Build Coastguard Worker                if not version_latest_with_metadata:
2073*da0073e9SAndroid Build Coastguard Worker                    # We want to still load a user-crafted state_dict, one without metadata
2074*da0073e9SAndroid Build Coastguard Worker                    del version_latest_state_dict._metadata['']['spectral_norm']
2075*da0073e9SAndroid Build Coastguard Worker
2076*da0073e9SAndroid Build Coastguard Worker                # test that re-wrapping does not matter
2077*da0073e9SAndroid Build Coastguard Worker                m = torch.nn.utils.remove_spectral_norm(snm)
2078*da0073e9SAndroid Build Coastguard Worker                snm = torch.nn.utils.spectral_norm(m)
2079*da0073e9SAndroid Build Coastguard Worker
2080*da0073e9SAndroid Build Coastguard Worker                snm.load_state_dict(version_latest_ref_state_dict)
2081*da0073e9SAndroid Build Coastguard Worker                with torch.no_grad():
2082*da0073e9SAndroid Build Coastguard Worker                    snm.eval()
2083*da0073e9SAndroid Build Coastguard Worker                    out0_eval = snm(inp)
2084*da0073e9SAndroid Build Coastguard Worker                    snm.train()
2085*da0073e9SAndroid Build Coastguard Worker                    out1_train = snm(inp)
2086*da0073e9SAndroid Build Coastguard Worker                    out2_train = snm(inp)
2087*da0073e9SAndroid Build Coastguard Worker                    snm.eval()
2088*da0073e9SAndroid Build Coastguard Worker                    out3_eval = snm(inp)
2089*da0073e9SAndroid Build Coastguard Worker
2090*da0073e9SAndroid Build Coastguard Worker                # test that re-wrapping does not matter
2091*da0073e9SAndroid Build Coastguard Worker                m = torch.nn.utils.remove_spectral_norm(snm)
2092*da0073e9SAndroid Build Coastguard Worker                snm = torch.nn.utils.spectral_norm(m)
2093*da0073e9SAndroid Build Coastguard Worker
2094*da0073e9SAndroid Build Coastguard Worker                snm.load_state_dict(version_none_state_dict)
2095*da0073e9SAndroid Build Coastguard Worker                if activate_times > 0:
2096*da0073e9SAndroid Build Coastguard Worker                    # since loading a version None state dict assumes that the
2097*da0073e9SAndroid Build Coastguard Worker                    # values in the state dict have gone through at least one
2098*da0073e9SAndroid Build Coastguard Worker                    # forward pass, we only test for equivalence when activate_times > 0.
2099*da0073e9SAndroid Build Coastguard Worker                    with torch.no_grad():
2100*da0073e9SAndroid Build Coastguard Worker                        snm.eval()
2101*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(out0_eval, snm(inp))
2102*da0073e9SAndroid Build Coastguard Worker                        snm.train()
2103*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(out1_train, snm(inp))
2104*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(out2_train, snm(inp))
2105*da0073e9SAndroid Build Coastguard Worker                        snm.eval()
2106*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(out3_eval, snm(inp))
2107*da0073e9SAndroid Build Coastguard Worker
2108*da0073e9SAndroid Build Coastguard Worker                # test that re-wrapping does not matter
2109*da0073e9SAndroid Build Coastguard Worker                m = torch.nn.utils.remove_spectral_norm(snm)
2110*da0073e9SAndroid Build Coastguard Worker                snm = torch.nn.utils.spectral_norm(m)
2111*da0073e9SAndroid Build Coastguard Worker
2112*da0073e9SAndroid Build Coastguard Worker                # Test normal loading
2113*da0073e9SAndroid Build Coastguard Worker                snm.load_state_dict(version_latest_state_dict)
2114*da0073e9SAndroid Build Coastguard Worker                with torch.no_grad():
2115*da0073e9SAndroid Build Coastguard Worker                    snm.eval()
2116*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(out0_eval, snm(inp))
2117*da0073e9SAndroid Build Coastguard Worker                    snm.train()
2118*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(out1_train, snm(inp))
2119*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(out2_train, snm(inp))
2120*da0073e9SAndroid Build Coastguard Worker                    snm.eval()
2121*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(out3_eval, snm(inp))
2122*da0073e9SAndroid Build Coastguard Worker
2123*da0073e9SAndroid Build Coastguard Worker    def test_spectral_norm_dim(self):
2124*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn(2, 3, 10, 12)
2125*da0073e9SAndroid Build Coastguard Worker        m = nn.ConvTranspose2d(3, 4, (5, 6))
2126*da0073e9SAndroid Build Coastguard Worker        m = torch.nn.utils.spectral_norm(m)
2127*da0073e9SAndroid Build Coastguard Worker        # this should not run into incompatible shapes
2128*da0073e9SAndroid Build Coastguard Worker        x = m(inp)
2129*da0073e9SAndroid Build Coastguard Worker        # check that u refers to the same dimension
2130*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m.weight_u.shape, m.weight_orig[0, :, 0, 0].shape)
2131*da0073e9SAndroid Build Coastguard Worker
2132*da0073e9SAndroid Build Coastguard Worker    def test_spectral_norm_forward(self):
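        # The wrapped module's output should match a manual forward that applies one
        # power-iteration step and divides the weight by the estimated spectral norm.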
2133*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3, 5)
2134*da0073e9SAndroid Build Coastguard Worker        m = nn.Linear(5, 7)
2135*da0073e9SAndroid Build Coastguard Worker        m = torch.nn.utils.spectral_norm(m)
2136*da0073e9SAndroid Build Coastguard Worker        # naive forward
2137*da0073e9SAndroid Build Coastguard Worker        _weight, _bias, _u = m.weight_orig, m.bias, m.weight_u
2138*da0073e9SAndroid Build Coastguard Worker        _weight_mat = _weight.view(_weight.size(0), -1)
2139*da0073e9SAndroid Build Coastguard Worker        _v = torch.mv(_weight_mat.t(), _u)
2140*da0073e9SAndroid Build Coastguard Worker        _v = F.normalize(_v, dim=0, eps=1e-12)
2141*da0073e9SAndroid Build Coastguard Worker        _u = torch.mv(_weight_mat, _v)
2142*da0073e9SAndroid Build Coastguard Worker        _u = F.normalize(_u, dim=0, eps=1e-12)
2143*da0073e9SAndroid Build Coastguard Worker        _weight.data /= torch.dot(_u, torch.matmul(_weight_mat, _v))
2144*da0073e9SAndroid Build Coastguard Worker        out_hat = torch.nn.functional.linear(input, _weight, _bias)
2145*da0073e9SAndroid Build Coastguard Worker        expect_out = m(input)
2146*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expect_out, out_hat)
2147*da0073e9SAndroid Build Coastguard Worker
2148*da0073e9SAndroid Build Coastguard Worker    def test_spectral_norm_pickle(self):
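        # A spectral-norm-wrapped module should survive a pickle round trip and still be an nn.Linear.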
2149*da0073e9SAndroid Build Coastguard Worker        m = torch.nn.utils.spectral_norm(nn.Linear(5, 7))
2150*da0073e9SAndroid Build Coastguard Worker        m = pickle.loads(pickle.dumps(m))
2151*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(m, nn.Linear)
2152*da0073e9SAndroid Build Coastguard Worker
2153*da0073e9SAndroid Build Coastguard Worker    def test_threshold_int(self):
2154*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([-3, -2, -1, 0, 1, 2, 3])
2155*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor([99, 99, 99, 99, 1, 2, 3])
2156*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(F.threshold(x, 0, 99), expected)
2157*da0073e9SAndroid Build Coastguard Worker
2158*da0073e9SAndroid Build Coastguard Worker    def test_threshold_bfloat16_half(self):
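        # Thresholding in bfloat16/half should match thresholding in float32 followed by a cast to that dtype.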
2159*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(100)
2160*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.bfloat16, torch.half]:
2161*da0073e9SAndroid Build Coastguard Worker            for threshold in [0, -0.5, 0.5, float('inf'), float('-inf'), float('nan')]:
2162*da0073e9SAndroid Build Coastguard Worker                expected = F.threshold(x, threshold, 0).to(dtype=dtype).float()
2163*da0073e9SAndroid Build Coastguard Worker                res = F.threshold(x.to(dtype=dtype), threshold, 0).float()
2164*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res, expected)
2165*da0073e9SAndroid Build Coastguard Worker
2166*da0073e9SAndroid Build Coastguard Worker    @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
2167*da0073e9SAndroid Build Coastguard Worker                         'Linear_FP16_weight requires FBGEMM. FBGEMM is only optimized for CPUs'
2168*da0073e9SAndroid Build Coastguard Worker                         ' with instruction set support avx2 or newer.')
2169*da0073e9SAndroid Build Coastguard Worker    def test_fb_fc_packed(self):
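        # Compare the FBGEMM fp16 packed linear op against a NumPy matmul reference.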
2170*da0073e9SAndroid Build Coastguard Worker        X = np.random.rand(16, 16).astype(np.float32) - 0.5
2171*da0073e9SAndroid Build Coastguard Worker        W = np.random.rand(16, 16).astype(np.float32) - 0.5
2172*da0073e9SAndroid Build Coastguard Worker        b = np.random.rand(16).astype(np.float32) - 0.5
2173*da0073e9SAndroid Build Coastguard Worker
2174*da0073e9SAndroid Build Coastguard Worker        def fc_op(X, W, b):
2175*da0073e9SAndroid Build Coastguard Worker            return np.dot(X, W.T) + b
2176*da0073e9SAndroid Build Coastguard Worker
2177*da0073e9SAndroid Build Coastguard Worker        x_tensor = torch.tensor(X)
2178*da0073e9SAndroid Build Coastguard Worker        w_tensor = torch.tensor(W)
2179*da0073e9SAndroid Build Coastguard Worker        b_tensor = torch.tensor(b)
2180*da0073e9SAndroid Build Coastguard Worker        packed_w_tensor = torch.fbgemm_pack_gemm_matrix_fp16(w_tensor)
2181*da0073e9SAndroid Build Coastguard Worker        actual_output = torch.fbgemm_linear_fp16_weight(x_tensor, packed_w_tensor, b_tensor)
2182*da0073e9SAndroid Build Coastguard Worker        expected_output = fc_op(X, W, b)
2183*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(torch.from_numpy(expected_output), actual_output.cpu(), atol=1e-3, rtol=1e-3)
2184*da0073e9SAndroid Build Coastguard Worker
2185*da0073e9SAndroid Build Coastguard Worker    def test_pad_scalar_error(self):
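        # Padding a 0-dim tensor should raise a RuntimeError.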
2186*da0073e9SAndroid Build Coastguard Worker        inputs = torch.tensor(0., requires_grad=True)
2187*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: F.pad(inputs, (1, 1)))
2188*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: F.pad(inputs, (1,)))
2189*da0073e9SAndroid Build Coastguard Worker
2190*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_from_mask(self):
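        # Converting a padded tensor to a nested tensor via a mask and back should reproduce the masked input.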
2191*da0073e9SAndroid Build Coastguard Worker        N, L, D = 10, 12, 14
2192*da0073e9SAndroid Build Coastguard Worker
2193*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(N, L, D)
2194*da0073e9SAndroid Build Coastguard Worker        mask = torch.ones(N, L, dtype=torch.bool)
2195*da0073e9SAndroid Build Coastguard Worker        # Leave the first row all True to keep the nested tensor's size unchanged
2196*da0073e9SAndroid Build Coastguard Worker        for i in range(1, N):
2197*da0073e9SAndroid Build Coastguard Worker            end = torch.randint(1, L, size=()).item()
2198*da0073e9SAndroid Build Coastguard Worker            mask[i, end:] = False
2199*da0073e9SAndroid Build Coastguard Worker
2200*da0073e9SAndroid Build Coastguard Worker        nt = torch._nested_tensor_from_mask(input, mask)
2201*da0073e9SAndroid Build Coastguard Worker        input_convert = nt.to_padded_tensor(0.)
2202*da0073e9SAndroid Build Coastguard Worker        input.masked_fill_(mask.reshape(N, L, 1).logical_not(), 0.)
2203*da0073e9SAndroid Build Coastguard Worker
2204*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(input, input_convert)
2205*da0073e9SAndroid Build Coastguard Worker
2206*da0073e9SAndroid Build Coastguard Worker    def test_nested_tensor_from_mask_error(self):
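        # _nested_tensor_from_mask should reject malformed masks and inputs.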
2207*da0073e9SAndroid Build Coastguard Worker        N, L, D = 10, 12, 14
2208*da0073e9SAndroid Build Coastguard Worker
2209*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(N, L, D)
2210*da0073e9SAndroid Build Coastguard Worker        # Mask is not bool
2211*da0073e9SAndroid Build Coastguard Worker        mask = torch.zeros(N, L, dtype=torch.float)
2212*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch._nested_tensor_from_mask(input, mask))
2213*da0073e9SAndroid Build Coastguard Worker
2214*da0073e9SAndroid Build Coastguard Worker        # Mask is not 2-dimensional
2215*da0073e9SAndroid Build Coastguard Worker        mask = torch.zeros(N, L, D, dtype=torch.bool)
2216*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch._nested_tensor_from_mask(input, mask))
2217*da0073e9SAndroid Build Coastguard Worker
2218*da0073e9SAndroid Build Coastguard Worker        # Input is not 3-dimensional
2219*da0073e9SAndroid Build Coastguard Worker        mask = torch.zeros(N, L, dtype=torch.bool)
2220*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(N, L)
2221*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch._nested_tensor_from_mask(input, mask))
2222*da0073e9SAndroid Build Coastguard Worker
2223*da0073e9SAndroid Build Coastguard Worker        # Mask size does not match input
2224*da0073e9SAndroid Build Coastguard Worker        mask = torch.zeros(N + 1, L + 1, dtype=torch.bool)
2225*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(N, L, D)
2226*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch._nested_tensor_from_mask(input, mask))
2227*da0073e9SAndroid Build Coastguard Worker
2228*da0073e9SAndroid Build Coastguard Worker        # Mask is not in padding format (the True values must form a prefix of each row)
2229*da0073e9SAndroid Build Coastguard Worker        mask = torch.ones(N, L, dtype=torch.bool)
2230*da0073e9SAndroid Build Coastguard Worker        mask[0, 0] = False
2231*da0073e9SAndroid Build Coastguard Worker        mask[0, 2] = False
2232*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: torch._nested_tensor_from_mask(input, mask))
2233*da0073e9SAndroid Build Coastguard Worker
2234*da0073e9SAndroid Build Coastguard Worker    def test_normalize(self):
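        # gradcheck F.normalize for different norms and dims, including a 0-dim input.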
2235*da0073e9SAndroid Build Coastguard Worker        inputs = torch.randn(1, 3, 4, 4, requires_grad=True, dtype=torch.double)
2236*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(gradcheck(lambda x: F.normalize(x, p=1, dim=-1), (inputs,)))
2237*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(gradcheck(lambda x: F.normalize(x, p=2, dim=-2), (inputs,)))
2238*da0073e9SAndroid Build Coastguard Worker
2239*da0073e9SAndroid Build Coastguard Worker        inputs = torch.randn((), requires_grad=True)
2240*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(gradcheck(lambda x: F.normalize(x, p=1, dim=-1), (inputs,)))
2241*da0073e9SAndroid Build Coastguard Worker
2242*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
2243*da0073e9SAndroid Build Coastguard Worker    # Skip the test for ROCm as per https://github.com/pytorch/pytorch/issues/53190
2244*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm
2245*da0073e9SAndroid Build Coastguard Worker    def test_broadcast_double_backwards_gpu(self):
2246*da0073e9SAndroid Build Coastguard Worker        tensors = (torch.randn(4, 4, device='cuda', requires_grad=True, dtype=torch.double),
2247*da0073e9SAndroid Build Coastguard Worker                   torch.randn(4, 4, device='cuda', requires_grad=True, dtype=torch.double),
2248*da0073e9SAndroid Build Coastguard Worker                   torch.randn(4, 4, device='cuda', requires_grad=True, dtype=torch.double))
2249*da0073e9SAndroid Build Coastguard Worker        # TODO(#50743): the following segfaults with check_batched_grad=True
2250*da0073e9SAndroid Build Coastguard Worker        _assertGradAndGradgradChecks(self, lambda *i: Broadcast.apply((0, 1), *i), tensors,
2251*da0073e9SAndroid Build Coastguard Worker                                     check_batched_grad=False)
2252*da0073e9SAndroid Build Coastguard Worker
2253*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
2254*da0073e9SAndroid Build Coastguard Worker    def test_broadcast_not_requiring_grad(self):
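        # Broadcast outputs should preserve the requires_grad flag of the corresponding inputs.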
2255*da0073e9SAndroid Build Coastguard Worker        variables = [
2256*da0073e9SAndroid Build Coastguard Worker            torch.randn(1, 2, device='cuda', requires_grad=True),
2257*da0073e9SAndroid Build Coastguard Worker            torch.randn(1, 2, device='cuda', requires_grad=False),
2258*da0073e9SAndroid Build Coastguard Worker            torch.randn(1, 2, device='cuda', requires_grad=False),
2259*da0073e9SAndroid Build Coastguard Worker            torch.randn(1, 2, device='cuda', requires_grad=True),
2260*da0073e9SAndroid Build Coastguard Worker            torch.randn(1, 2, device='cuda', requires_grad=True),
2261*da0073e9SAndroid Build Coastguard Worker        ]
2262*da0073e9SAndroid Build Coastguard Worker        broadcasted_variables = Broadcast.apply((0, 1), *variables)
2263*da0073e9SAndroid Build Coastguard Worker        for output_idx, broadcasted_var in enumerate(broadcasted_variables):
2264*da0073e9SAndroid Build Coastguard Worker            input_var = variables[output_idx % len(variables)]
2265*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input_var.requires_grad, broadcasted_var.requires_grad)
2266*da0073e9SAndroid Build Coastguard Worker
2267*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
2268*da0073e9SAndroid Build Coastguard Worker    def test_broadcast_no_grad(self):
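        # Under torch.no_grad(), broadcast outputs should not require grad even if the input does.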
2269*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 2, dtype=torch.float32, requires_grad=True, device='cuda')
2270*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
2271*da0073e9SAndroid Build Coastguard Worker            broadcasted = Broadcast.apply((0, 1), x)
2272*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x.requires_grad)
2273*da0073e9SAndroid Build Coastguard Worker        for output in broadcasted:
2274*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(output.requires_grad)
2275*da0073e9SAndroid Build Coastguard Worker
2276*da0073e9SAndroid Build Coastguard Worker    def test_state_dict(self):
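        # state_dict should contain parameters and persistent buffers of all submodules
        # (shared modules included), carry per-module metadata, and alias the live tensors.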
2277*da0073e9SAndroid Build Coastguard Worker        l = nn.Linear(5, 5)
2278*da0073e9SAndroid Build Coastguard Worker        block = nn.Module()
2279*da0073e9SAndroid Build Coastguard Worker        block.conv = nn.Conv2d(3, 3, 3, bias=False)
2280*da0073e9SAndroid Build Coastguard Worker        net = nn.Module()
2281*da0073e9SAndroid Build Coastguard Worker        net.linear1 = l
2282*da0073e9SAndroid Build Coastguard Worker        net.linear2 = l
2283*da0073e9SAndroid Build Coastguard Worker        net.bn = nn.BatchNorm2d(2)
2284*da0073e9SAndroid Build Coastguard Worker        net.block = block
2285*da0073e9SAndroid Build Coastguard Worker        net.add_module('empty', None)
2286*da0073e9SAndroid Build Coastguard Worker
2287*da0073e9SAndroid Build Coastguard Worker        state_dict = net.state_dict()
2288*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(state_dict), 10)
2289*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(state_dict._metadata), 6)
2290*da0073e9SAndroid Build Coastguard Worker        self.assertIn('', state_dict._metadata)
2291*da0073e9SAndroid Build Coastguard Worker        self.assertIn('linear1', state_dict._metadata)
2292*da0073e9SAndroid Build Coastguard Worker        self.assertIn('linear1.weight', state_dict)
2293*da0073e9SAndroid Build Coastguard Worker        self.assertIn('linear1.bias', state_dict)
2294*da0073e9SAndroid Build Coastguard Worker        self.assertIn('linear2', state_dict._metadata)
2295*da0073e9SAndroid Build Coastguard Worker        self.assertIn('linear2.weight', state_dict)
2296*da0073e9SAndroid Build Coastguard Worker        self.assertIn('linear2.bias', state_dict)
2297*da0073e9SAndroid Build Coastguard Worker        self.assertIn('block', state_dict._metadata)
2298*da0073e9SAndroid Build Coastguard Worker        self.assertIn('block.conv', state_dict._metadata)
2299*da0073e9SAndroid Build Coastguard Worker        self.assertIn('block.conv.weight', state_dict)
2301*da0073e9SAndroid Build Coastguard Worker        self.assertNotIn('block.conv.bias', state_dict)
2302*da0073e9SAndroid Build Coastguard Worker        self.assertIn('bn', state_dict._metadata)
2303*da0073e9SAndroid Build Coastguard Worker        self.assertIn('bn.weight', state_dict)
2304*da0073e9SAndroid Build Coastguard Worker        self.assertIn('bn.bias', state_dict)
2305*da0073e9SAndroid Build Coastguard Worker        self.assertIn('bn.running_var', state_dict)
2306*da0073e9SAndroid Build Coastguard Worker        self.assertIn('bn.running_mean', state_dict)
2307*da0073e9SAndroid Build Coastguard Worker        self.assertIn('bn.num_batches_tracked', state_dict)
2308*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(any(k.startswith('empty') for k in state_dict.keys()))
2309*da0073e9SAndroid Build Coastguard Worker        for k, v in state_dict.items():
2310*da0073e9SAndroid Build Coastguard Worker            param = net
2311*da0073e9SAndroid Build Coastguard Worker            for component in k.split('.'):
2312*da0073e9SAndroid Build Coastguard Worker                param = getattr(param, component)
2313*da0073e9SAndroid Build Coastguard Worker                if isinstance(param, Parameter):
2314*da0073e9SAndroid Build Coastguard Worker                    param = param.data
2315*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(v.data_ptr(), param.data_ptr())
2316*da0073e9SAndroid Build Coastguard Worker
2317*da0073e9SAndroid Build Coastguard Worker        l = nn.Linear(5, 5)
2318*da0073e9SAndroid Build Coastguard Worker        state_dict = l.state_dict()
2319*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(state_dict), 2)
2320*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(state_dict._metadata), 1)
2321*da0073e9SAndroid Build Coastguard Worker        self.assertIn('', state_dict._metadata)
2322*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(state_dict._metadata['']['version'] >= 0)
2323*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(state_dict['weight'].data_ptr(), l.weight.data_ptr())
2324*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(state_dict['bias'].data_ptr(), l.bias.data_ptr())
2325*da0073e9SAndroid Build Coastguard Worker
2326*da0073e9SAndroid Build Coastguard Worker        # Reference https://github.com/pytorch/pytorch/pull/75507#issuecomment-1110291545
2327*da0073e9SAndroid Build Coastguard Worker        self.assertNotWarn(lambda: l.state_dict(destination={}), "Should not warn kwarg destination w/o _metadata")
2328*da0073e9SAndroid Build Coastguard Worker
2329*da0073e9SAndroid Build Coastguard Worker    def test_extra_state(self):
2330*da0073e9SAndroid Build Coastguard Worker
2331*da0073e9SAndroid Build Coastguard Worker        class SubModule(torch.nn.Module):
2332*da0073e9SAndroid Build Coastguard Worker            def __init__(self, foo):
2333*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2334*da0073e9SAndroid Build Coastguard Worker                self.foo = foo
2335*da0073e9SAndroid Build Coastguard Worker
2336*da0073e9SAndroid Build Coastguard Worker            def get_extra_state(self):
2337*da0073e9SAndroid Build Coastguard Worker                return {
2338*da0073e9SAndroid Build Coastguard Worker                    'foo': self.foo
2339*da0073e9SAndroid Build Coastguard Worker                }
2340*da0073e9SAndroid Build Coastguard Worker
2341*da0073e9SAndroid Build Coastguard Worker            def set_extra_state(self, state):
2342*da0073e9SAndroid Build Coastguard Worker                self.foo = state['foo']
2343*da0073e9SAndroid Build Coastguard Worker
2344*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
2345*da0073e9SAndroid Build Coastguard Worker            def __init__(self, foo, bar):
2346*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2347*da0073e9SAndroid Build Coastguard Worker                self.sub = SubModule(foo)
2348*da0073e9SAndroid Build Coastguard Worker                self.bar = bar
2349*da0073e9SAndroid Build Coastguard Worker
2350*da0073e9SAndroid Build Coastguard Worker            def get_extra_state(self):
2351*da0073e9SAndroid Build Coastguard Worker                return {
2352*da0073e9SAndroid Build Coastguard Worker                    'bar': self.bar
2353*da0073e9SAndroid Build Coastguard Worker                }
2354*da0073e9SAndroid Build Coastguard Worker
2355*da0073e9SAndroid Build Coastguard Worker            def set_extra_state(self, state):
2356*da0073e9SAndroid Build Coastguard Worker                self.bar = state['bar']
2357*da0073e9SAndroid Build Coastguard Worker
2358*da0073e9SAndroid Build Coastguard Worker        # Ensure state_dict contains the extra state by loading it into another module.
2359*da0073e9SAndroid Build Coastguard Worker        m = MyModule(3, 'something')
2360*da0073e9SAndroid Build Coastguard Worker        m2 = MyModule(5, 'something else')
2361*da0073e9SAndroid Build Coastguard Worker        m2.load_state_dict(m.state_dict())
2362*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m.state_dict(), m2.state_dict())
2363*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m2.bar, m.bar)
2364*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m2.sub.foo, m.sub.foo)
2365*da0073e9SAndroid Build Coastguard Worker
2366*da0073e9SAndroid Build Coastguard Worker    def test_extra_state_non_dict(self):
2367*da0073e9SAndroid Build Coastguard Worker
2368*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
2369*da0073e9SAndroid Build Coastguard Worker            def __init__(self, foo):
2370*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2371*da0073e9SAndroid Build Coastguard Worker                self.foo = foo
2372*da0073e9SAndroid Build Coastguard Worker
2373*da0073e9SAndroid Build Coastguard Worker            def get_extra_state(self):
2374*da0073e9SAndroid Build Coastguard Worker                return self.foo
2375*da0073e9SAndroid Build Coastguard Worker
2376*da0073e9SAndroid Build Coastguard Worker            def set_extra_state(self, state):
2377*da0073e9SAndroid Build Coastguard Worker                self.foo = state
2378*da0073e9SAndroid Build Coastguard Worker
2379*da0073e9SAndroid Build Coastguard Worker        # Test various types of extra state.
2380*da0073e9SAndroid Build Coastguard Worker        for state in ('something', 5, MyModule(3)):
2381*da0073e9SAndroid Build Coastguard Worker            m = MyModule(state)
2382*da0073e9SAndroid Build Coastguard Worker            m2 = MyModule('something else')
2383*da0073e9SAndroid Build Coastguard Worker            m2.load_state_dict(m.state_dict())
2384*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m.state_dict(), m2.state_dict())
2385*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m.foo, m2.foo)
2386*da0073e9SAndroid Build Coastguard Worker
2387*da0073e9SAndroid Build Coastguard Worker    def test_extra_state_missing_set_extra_state(self):
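        # A module defining get_extra_state but not set_extra_state cannot load its own
        # state_dict: the extra state key is reported as unexpected.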
2388*da0073e9SAndroid Build Coastguard Worker
2389*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
2390*da0073e9SAndroid Build Coastguard Worker            def get_extra_state(self):
2391*da0073e9SAndroid Build Coastguard Worker                return {
2392*da0073e9SAndroid Build Coastguard Worker                    'foo': 5
2393*da0073e9SAndroid Build Coastguard Worker                }
2394*da0073e9SAndroid Build Coastguard Worker
2395*da0073e9SAndroid Build Coastguard Worker        m = MyModule()
2396*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'Unexpected key'):
2397*da0073e9SAndroid Build Coastguard Worker            m.load_state_dict(m.state_dict())
2398*da0073e9SAndroid Build Coastguard Worker
2399*da0073e9SAndroid Build Coastguard Worker    def test_extra_state_missing_get_extra_state(self):
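        # A module defining set_extra_state but not get_extra_state produces a state_dict
        # without the extra state key, so loading it reports a missing key.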
2400*da0073e9SAndroid Build Coastguard Worker
2401*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
2402*da0073e9SAndroid Build Coastguard Worker            def set_extra_state(self):
2403*da0073e9SAndroid Build Coastguard Worker                pass
2404*da0073e9SAndroid Build Coastguard Worker
2405*da0073e9SAndroid Build Coastguard Worker        m = MyModule()
2406*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'Missing key'):
2407*da0073e9SAndroid Build Coastguard Worker            m.load_state_dict(m.state_dict())
2408*da0073e9SAndroid Build Coastguard Worker
2409*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")
2410*da0073e9SAndroid Build Coastguard Worker    def test_parameter_assignment(self):
2411*da0073e9SAndroid Build Coastguard Worker        l = nn.Linear(5, 5)
2412*da0073e9SAndroid Build Coastguard Worker
2413*da0073e9SAndroid Build Coastguard Worker        def num_params():
2414*da0073e9SAndroid Build Coastguard Worker            return len(list(l.parameters()))
2415*da0073e9SAndroid Build Coastguard Worker
2416*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(num_params(), 2)
2417*da0073e9SAndroid Build Coastguard Worker
2418*da0073e9SAndroid Build Coastguard Worker        new_param = Parameter(torch.randn(5, 5))
2419*da0073e9SAndroid Build Coastguard Worker        l.param_name = new_param
2420*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(num_params(), 3)
2421*da0073e9SAndroid Build Coastguard Worker        self.assertObjectIn(new_param, l.parameters())
2422*da0073e9SAndroid Build Coastguard Worker
2423*da0073e9SAndroid Build Coastguard Worker        var = torch.randn(5, 5)
2424*da0073e9SAndroid Build Coastguard Worker        l.var_name = var
2425*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(num_params(), 3)
2426*da0073e9SAndroid Build Coastguard Worker        self.assertNotIn(id(var), map(id, l.parameters()))
2427*da0073e9SAndroid Build Coastguard Worker
2428*da0073e9SAndroid Build Coastguard Worker        # Make sure plain tensors are not saved as parameters
2429*da0073e9SAndroid Build Coastguard Worker        l.variable_attr = torch.empty(5, 5)
2430*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(num_params(), 3)
2431*da0073e9SAndroid Build Coastguard Worker        l.param_attr = Parameter(torch.empty(5, 5))
2432*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(num_params(), 4)
2433*da0073e9SAndroid Build Coastguard Worker
2434*da0073e9SAndroid Build Coastguard Worker        # It shouldn't be possible to replace a parameter with a plain tensor
2435*da0073e9SAndroid Build Coastguard Worker        def assign_var():
2436*da0073e9SAndroid Build Coastguard Worker            l.param_attr = torch.empty(5, 5)
2437*da0073e9SAndroid Build Coastguard Worker
2438*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(TypeError, assign_var)
2439*da0073e9SAndroid Build Coastguard Worker        # But replacing it with None should be fine
2440*da0073e9SAndroid Build Coastguard Worker        l.param_attr = None
2441*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(num_params(), 3)
2442*da0073e9SAndroid Build Coastguard Worker
2443*da0073e9SAndroid Build Coastguard Worker    def test_assignment(self):
2444*da0073e9SAndroid Build Coastguard Worker        l = nn.Module()
2445*da0073e9SAndroid Build Coastguard Worker        a = nn.Parameter(torch.randn(2))
2446*da0073e9SAndroid Build Coastguard Worker        b = nn.Parameter(torch.randn(3))
2447*da0073e9SAndroid Build Coastguard Worker        c = nn.Parameter(torch.randn(4))
2448*da0073e9SAndroid Build Coastguard Worker        q = nn.Linear(4, 4)
2449*da0073e9SAndroid Build Coastguard Worker        r = nn.Linear(5, 5)
2450*da0073e9SAndroid Build Coastguard Worker        w = nn.Linear(6, 6)
2451*da0073e9SAndroid Build Coastguard Worker
2452*da0073e9SAndroid Build Coastguard Worker        def test_assignments(get_list, a, b, c):
2453*da0073e9SAndroid Build Coastguard Worker            # Check that None can be shadowed
2454*da0073e9SAndroid Build Coastguard Worker            l.a = None
2455*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(l.a)
2456*da0073e9SAndroid Build Coastguard Worker            self.assertIn('a', l.__dict__)
2457*da0073e9SAndroid Build Coastguard Worker            l.a = a
2458*da0073e9SAndroid Build Coastguard Worker            self.assertIs(l.a, a)
2459*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(get_list(), [a])
2460*da0073e9SAndroid Build Coastguard Worker            self.assertNotIn('a', l.__dict__)
2461*da0073e9SAndroid Build Coastguard Worker
2462*da0073e9SAndroid Build Coastguard Worker            # Assign second object
2463*da0073e9SAndroid Build Coastguard Worker            l.b = None
2464*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(l.b)
2465*da0073e9SAndroid Build Coastguard Worker            self.assertIn('b', l.__dict__)
2466*da0073e9SAndroid Build Coastguard Worker            l.b = b
2467*da0073e9SAndroid Build Coastguard Worker            self.assertIs(l.b, b)
2468*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(get_list(), [a, b])
2469*da0073e9SAndroid Build Coastguard Worker            self.assertNotIn('b', l.__dict__)
2470*da0073e9SAndroid Build Coastguard Worker
2471*da0073e9SAndroid Build Coastguard Worker            # Remove and add the object back. Order should be unchanged.
2472*da0073e9SAndroid Build Coastguard Worker            l.a = None
2473*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(l.a)
2474*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(get_list(), [b])
2475*da0073e9SAndroid Build Coastguard Worker            l.a = a
2476*da0073e9SAndroid Build Coastguard Worker            self.assertIs(l.a, a)
2477*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(get_list(), [a, b])
2478*da0073e9SAndroid Build Coastguard Worker
2479*da0073e9SAndroid Build Coastguard Worker            # Replace object with another one. Order should be unchanged.
2480*da0073e9SAndroid Build Coastguard Worker            l.a = c
2481*da0073e9SAndroid Build Coastguard Worker            self.assertIs(l.a, c)
2482*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(get_list(), [c, b])
2483*da0073e9SAndroid Build Coastguard Worker
2484*da0073e9SAndroid Build Coastguard Worker            # Remove and reassign an attribute. It should appear at the end of the list now.
2485*da0073e9SAndroid Build Coastguard Worker            del l.a
2486*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(hasattr(l, 'a'))
2487*da0073e9SAndroid Build Coastguard Worker            l.a = a
2488*da0073e9SAndroid Build Coastguard Worker            self.assertIs(l.a, a)
2489*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(get_list(), [b, a])
2490*da0073e9SAndroid Build Coastguard Worker
2491*da0073e9SAndroid Build Coastguard Worker        test_assignments(lambda: list(l.parameters()), a, b, c)
2492*da0073e9SAndroid Build Coastguard Worker        del l.a, l.b
2493*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(l.parameters()), [])
2494*da0073e9SAndroid Build Coastguard Worker
2495*da0073e9SAndroid Build Coastguard Worker        test_assignments(lambda: list(l.children()), q, r, w)
2496*da0073e9SAndroid Build Coastguard Worker        del l.a, l.b
2497*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(l.children()), [])
2498*da0073e9SAndroid Build Coastguard Worker
2499*da0073e9SAndroid Build Coastguard Worker        buf = Buffer(torch.randn(10))
2500*da0073e9SAndroid Build Coastguard Worker        l.buf = buf
2501*da0073e9SAndroid Build Coastguard Worker        self.assertIs(l.buf, buf)
2502*da0073e9SAndroid Build Coastguard Worker        l.buf = None
2503*da0073e9SAndroid Build Coastguard Worker        self.assertIs(l.buf, None)
2504*da0073e9SAndroid Build Coastguard Worker        self.assertNotIn('buf', l.__dict__)  # should be stored in l._buffers
2505*da0073e9SAndroid Build Coastguard Worker        l.buf = buf
2506*da0073e9SAndroid Build Coastguard Worker        self.assertIn('buf', l.state_dict())
2507*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(l.state_dict()['buf'], buf)
2508*da0073e9SAndroid Build Coastguard Worker
2509*da0073e9SAndroid Build Coastguard Worker    def test_container_copy(self):
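        # deepcopy of a module should produce an independent copy: outputs match initially
        # and diverge once the copy's weights are modified.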
2510*da0073e9SAndroid Build Coastguard Worker        class Model(nn.Module):
2511*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2512*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2513*da0073e9SAndroid Build Coastguard Worker                self.linear = nn.Linear(4, 5)
2514*da0073e9SAndroid Build Coastguard Worker
2515*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
2516*da0073e9SAndroid Build Coastguard Worker                return self.linear(input)
2517*da0073e9SAndroid Build Coastguard Worker
2518*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(2, 4)
2519*da0073e9SAndroid Build Coastguard Worker
2520*da0073e9SAndroid Build Coastguard Worker        model = Model()
2521*da0073e9SAndroid Build Coastguard Worker        model_cp = deepcopy(model)
2522*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(model(input).data, model_cp(input).data)
2523*da0073e9SAndroid Build Coastguard Worker
2524*da0073e9SAndroid Build Coastguard Worker        model_cp.linear.weight.data[:] = 2
2525*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(model(input).data, model_cp(input).data)
2526*da0073e9SAndroid Build Coastguard Worker
2527*da0073e9SAndroid Build Coastguard Worker    def test_RNN_cell(self):
2528*da0073e9SAndroid Build Coastguard Worker        # this is just a smoke test; these modules are implemented through
2529*da0073e9SAndroid Build Coastguard Worker        # autograd so no Jacobian test is needed
2530*da0073e9SAndroid Build Coastguard Worker        for module in (nn.RNNCell, nn.GRUCell):
2531*da0073e9SAndroid Build Coastguard Worker            for bias in (True, False):
2532*da0073e9SAndroid Build Coastguard Worker                input = torch.randn(3, 10)
2533*da0073e9SAndroid Build Coastguard Worker                hx = torch.randn(3, 20)
2534*da0073e9SAndroid Build Coastguard Worker                cell = module(10, 20, bias=bias)
2535*da0073e9SAndroid Build Coastguard Worker                for _ in range(6):
2536*da0073e9SAndroid Build Coastguard Worker                    hx = cell(input, hx)
2537*da0073e9SAndroid Build Coastguard Worker
2538*da0073e9SAndroid Build Coastguard Worker                hx.sum().backward()
2539*da0073e9SAndroid Build Coastguard Worker
2540*da0073e9SAndroid Build Coastguard Worker    def test_RNN_cell_forward_zero_hidden_size(self):
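        # Cells constructed with hidden_size=0 should still return correctly shaped (empty) outputs.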
2541*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3, 10)
2542*da0073e9SAndroid Build Coastguard Worker        hx = torch.randn(3, 0)
2543*da0073e9SAndroid Build Coastguard Worker        cell_shared_param = (10, 0)
2544*da0073e9SAndroid Build Coastguard Worker        for cell in (nn.RNNCell(*cell_shared_param, nonlinearity="relu"),
2545*da0073e9SAndroid Build Coastguard Worker                     nn.RNNCell(*cell_shared_param, nonlinearity="tanh"),
2546*da0073e9SAndroid Build Coastguard Worker                     nn.GRUCell(*cell_shared_param)):
2547*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cell(input, hx).shape, torch.Size([3, 0]))
2548*da0073e9SAndroid Build Coastguard Worker
2549*da0073e9SAndroid Build Coastguard Worker    def _test_loss_equal_input_target_shape(self, cast):
2550*da0073e9SAndroid Build Coastguard Worker        # Tests losses whose inputs should have the same size.
2551*da0073e9SAndroid Build Coastguard Worker        losses = {
2552*da0073e9SAndroid Build Coastguard Worker            'mse_loss': lambda x, y: F.mse_loss(x, y),
2553*da0073e9SAndroid Build Coastguard Worker            'l1_loss': lambda x, y: F.l1_loss(x, y),
2554*da0073e9SAndroid Build Coastguard Worker            'smooth_l1_loss': lambda x, y: F.smooth_l1_loss(x, y),
2555*da0073e9SAndroid Build Coastguard Worker            'huber_loss': lambda x, y: F.huber_loss(x, y),
2556*da0073e9SAndroid Build Coastguard Worker            'kl_div': lambda x, y: F.kl_div(x, y),
2557*da0073e9SAndroid Build Coastguard Worker            'poisson_nll_loss': lambda x, y: F.poisson_nll_loss(x, y),
2558*da0073e9SAndroid Build Coastguard Worker        }
2559*da0073e9SAndroid Build Coastguard Worker
2560*da0073e9SAndroid Build Coastguard Worker        input = cast(torch.randn(3, 5))
2561*da0073e9SAndroid Build Coastguard Worker        target = cast(torch.randn(5, 3))
2562*da0073e9SAndroid Build Coastguard Worker        for fn in losses.values():
2563*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(Exception, lambda: fn(input, target))
2564*da0073e9SAndroid Build Coastguard Worker
2565*da0073e9SAndroid Build Coastguard Worker    def test_loss_equal_input_target_shape(self):
2566*da0073e9SAndroid Build Coastguard Worker        self._test_loss_equal_input_target_shape(lambda x: x)
2567*da0073e9SAndroid Build Coastguard Worker
2568*da0073e9SAndroid Build Coastguard Worker    def test_mse_loss_size_warning(self):
2569*da0073e9SAndroid Build Coastguard Worker        i = torch.randn((10, 1), requires_grad=True)
2570*da0073e9SAndroid Build Coastguard Worker        t = torch.randn((10,))
2571*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
2572*da0073e9SAndroid Build Coastguard Worker            # Ensure warnings are being shown
2573*da0073e9SAndroid Build Coastguard Worker            warnings.simplefilter("always")
2574*da0073e9SAndroid Build Coastguard Worker            # Trigger Warning
2575*da0073e9SAndroid Build Coastguard Worker            F.mse_loss(i, t)
2576*da0073e9SAndroid Build Coastguard Worker            # Check warning occurs
2577*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 1)
2578*da0073e9SAndroid Build Coastguard Worker            self.assertIn('Please ensure they have the same size.', str(w[0]))
2579*da0073e9SAndroid Build Coastguard Worker
2580*da0073e9SAndroid Build Coastguard Worker    def test_gaussian_nll_loss_broadcasting(self):
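        # gaussian_nll_loss should broadcast target and var against the input and match
        # the component-wise formula 0.5 * (log(var) + (input - target)**2 / var).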
2581*da0073e9SAndroid Build Coastguard Worker        input = torch.tensor([[0.5, 1.5, 2.5], [2., 4., 6.]])
2582*da0073e9SAndroid Build Coastguard Worker        target_full = torch.tensor([[1., 2., 3.], [1., 2., 3.]])
2583*da0073e9SAndroid Build Coastguard Worker        target_part = torch.tensor([[1., 2., 3.]])
2584*da0073e9SAndroid Build Coastguard Worker        var_full = torch.tensor([[0.5, 0.5, 0.5], [1.5, 1.5, 1.5]])
2585*da0073e9SAndroid Build Coastguard Worker        var_part1 = torch.tensor([[0.5], [1.5]])
2586*da0073e9SAndroid Build Coastguard Worker        var_part2 = torch.tensor([0.5, 1.5])
2587*da0073e9SAndroid Build Coastguard Worker        component_wise_loss = 0.5 * (torch.log(var_full) + (input - target_full)**2 / var_full)
2588*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(component_wise_loss,
2589*da0073e9SAndroid Build Coastguard Worker                         F.gaussian_nll_loss(input, target_part, var_full, reduction='none'))
2590*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(component_wise_loss,
2591*da0073e9SAndroid Build Coastguard Worker                         F.gaussian_nll_loss(input, target_full, var_part1, reduction='none'))
2592*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(component_wise_loss,
2593*da0073e9SAndroid Build Coastguard Worker                         F.gaussian_nll_loss(input, target_full, var_part2, reduction='none'))
2594*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(component_wise_loss,
2595*da0073e9SAndroid Build Coastguard Worker                         F.gaussian_nll_loss(input, target_part, var_part1, reduction='none'))
2596*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(component_wise_loss,
2597*da0073e9SAndroid Build Coastguard Worker                         F.gaussian_nll_loss(input, target_part, var_part2, reduction='none'))
2598*da0073e9SAndroid Build Coastguard Worker
2599*da0073e9SAndroid Build Coastguard Worker    def test_gaussian_nll_loss_args(self):
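        # gaussian_nll_loss should validate the shape and sign of var.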
2600*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3, 5)
2601*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, 'var is of incorrect size'):
2602*da0073e9SAndroid Build Coastguard Worker            target = torch.randn(3, 5)
2603*da0073e9SAndroid Build Coastguard Worker            var = torch.ones(3, 3)
2604*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.gaussian_nll_loss(input, target, var)
2605*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, 'var has negative entry/entries'):
2606*da0073e9SAndroid Build Coastguard Worker            var = -1 * torch.ones(3, 5)
2607*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.gaussian_nll_loss(input, target, var)
2608*da0073e9SAndroid Build Coastguard Worker
2609*da0073e9SAndroid Build Coastguard Worker    def test_KLDivLoss_batch_mean(self):
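        # 'batchmean' reduction should equal the 'sum' reduction divided by the batch size.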
2610*da0073e9SAndroid Build Coastguard Worker        input_shape = (2, 5)
2611*da0073e9SAndroid Build Coastguard Worker        log_prob1 = F.log_softmax(torch.randn(input_shape), 1)
2612*da0073e9SAndroid Build Coastguard Worker        prob2 = F.softmax(torch.randn(input_shape), 1)
2613*da0073e9SAndroid Build Coastguard Worker
2614*da0073e9SAndroid Build Coastguard Worker        loss = nn.KLDivLoss(reduction='batchmean')
2615*da0073e9SAndroid Build Coastguard Worker        l = loss(log_prob1, prob2)
2616*da0073e9SAndroid Build Coastguard Worker
2617*da0073e9SAndroid Build Coastguard Worker        loss_none_reduce = nn.KLDivLoss(reduction='sum')(log_prob1, prob2)
2618*da0073e9SAndroid Build Coastguard Worker        expected = loss_none_reduce / input_shape[0]
2619*da0073e9SAndroid Build Coastguard Worker
2620*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(l, expected)
2621*da0073e9SAndroid Build Coastguard Worker
2622*da0073e9SAndroid Build Coastguard Worker    def test_KLDivLoss_batch_mean_log_target(self):
2623*da0073e9SAndroid Build Coastguard Worker        input_shape = (2, 5)
2624*da0073e9SAndroid Build Coastguard Worker        log_prob1 = F.log_softmax(torch.randn(input_shape), 1)
2625*da0073e9SAndroid Build Coastguard Worker        log_prob2 = F.log_softmax(torch.randn(input_shape), 1)
2626*da0073e9SAndroid Build Coastguard Worker
2627*da0073e9SAndroid Build Coastguard Worker        loss = nn.KLDivLoss(reduction='batchmean', log_target=True)
2628*da0073e9SAndroid Build Coastguard Worker        l = loss(log_prob1, log_prob2)
2629*da0073e9SAndroid Build Coastguard Worker
2630*da0073e9SAndroid Build Coastguard Worker        loss_none_reduce = nn.KLDivLoss(reduction='sum', log_target=True)(log_prob1, log_prob2)
2631*da0073e9SAndroid Build Coastguard Worker        expected = loss_none_reduce / input_shape[0]
2632*da0073e9SAndroid Build Coastguard Worker
2633*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(l, expected)
2634*da0073e9SAndroid Build Coastguard Worker
2635*da0073e9SAndroid Build Coastguard Worker    def test_CTCLoss_typechecks(self):
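        # ctc_loss should reject floating-point input/target length tensors.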
2636*da0073e9SAndroid Build Coastguard Worker        target_lengths = torch.tensor([30, 25, 20])
2637*da0073e9SAndroid Build Coastguard Worker        input_lengths = torch.tensor([50, 50, 50])
2638*da0073e9SAndroid Build Coastguard Worker        targets = torch.randint(1, 15, (sum(target_lengths),), dtype=torch.int)
2639*da0073e9SAndroid Build Coastguard Worker        log_probs = torch.randn(50, 3, 15, dtype=torch.float).log_softmax(2)
2640*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
2641*da0073e9SAndroid Build Coastguard Worker            _input_lengths = input_lengths.to(dtype=torch.float)
2642*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.ctc_loss(log_probs, targets, _input_lengths, target_lengths)
2643*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
2644*da0073e9SAndroid Build Coastguard Worker            target_lengths = target_lengths.to(dtype=torch.float)
2645*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)
2646*da0073e9SAndroid Build Coastguard Worker
2647*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
2648*da0073e9SAndroid Build Coastguard Worker    def test_CTCLoss_lengthchecks_cuda(self):
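        # Negative lengths, or target lengths longer than the provided targets, should raise.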
2649*da0073e9SAndroid Build Coastguard Worker        for target_lengths in [[30, 25, 20], [-1, -1, -1]]:
2650*da0073e9SAndroid Build Coastguard Worker            for input_lengths in [[50, 50, 50], [-1, -1, -1]]:
2651*da0073e9SAndroid Build Coastguard Worker                targets = torch.randint(1, 15, (3, 29), dtype=torch.long, device='cuda')
2652*da0073e9SAndroid Build Coastguard Worker                log_probs = torch.randn(50, 3, 15, dtype=torch.float, device='cuda').log_softmax(2)
2653*da0073e9SAndroid Build Coastguard Worker                with self.assertRaises(RuntimeError):
2654*da0073e9SAndroid Build Coastguard Worker                    torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)
2655*da0073e9SAndroid Build Coastguard Worker
2656*da0073e9SAndroid Build Coastguard Worker    def test_CTCLoss_lengthchecks_cpu(self):
2657*da0073e9SAndroid Build Coastguard Worker        for target_lengths in [[30, 25, 20], [-1, -1, -1]]:
2658*da0073e9SAndroid Build Coastguard Worker            for input_lengths in [[50, 50, 50], [-1, -1, -1]]:
2659*da0073e9SAndroid Build Coastguard Worker                targets = torch.randint(1, 15, (3, 29), dtype=torch.int)
2660*da0073e9SAndroid Build Coastguard Worker                log_probs = torch.randn(50, 3, 15, dtype=torch.float).log_softmax(2)
2661*da0073e9SAndroid Build Coastguard Worker                with self.assertRaises(RuntimeError):
2662*da0073e9SAndroid Build Coastguard Worker                    torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)
2663*da0073e9SAndroid Build Coastguard Worker
2664*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
2665*da0073e9SAndroid Build Coastguard Worker    def test_CTCLoss_long_targets(self):
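        # With cudnn disabled, CUDA ctc_loss and its gradient should match the CPU result for very long targets.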
2666*da0073e9SAndroid Build Coastguard Worker        input_length = 4000
2667*da0073e9SAndroid Build Coastguard Worker        vocab_size = 3
2668*da0073e9SAndroid Build Coastguard Worker        batch_size = 4
2669*da0073e9SAndroid Build Coastguard Worker        target_length = 1200
2670*da0073e9SAndroid Build Coastguard Worker
2671*da0073e9SAndroid Build Coastguard Worker        log_probs = torch.randn(input_length, batch_size, vocab_size, dtype=torch.double).log_softmax(2).requires_grad_()
2672*da0073e9SAndroid Build Coastguard Worker        targets = torch.randint(low=1, high=vocab_size - 1, size=(batch_size, target_length), dtype=torch.long)
2673*da0073e9SAndroid Build Coastguard Worker        input_lengths = batch_size * [input_length]
2674*da0073e9SAndroid Build Coastguard Worker        target_lengths = batch_size * [target_length]
2675*da0073e9SAndroid Build Coastguard Worker
2676*da0073e9SAndroid Build Coastguard Worker        res_cpu = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths,
2677*da0073e9SAndroid Build Coastguard Worker                                               reduction='sum', zero_infinity=True)
2678*da0073e9SAndroid Build Coastguard Worker        grad_out = torch.randn_like(res_cpu)
2679*da0073e9SAndroid Build Coastguard Worker        grad_cpu, = torch.autograd.grad(res_cpu, log_probs, grad_out)
2680*da0073e9SAndroid Build Coastguard Worker
2681*da0073e9SAndroid Build Coastguard Worker        with torch.backends.cudnn.flags(enabled=False):
2682*da0073e9SAndroid Build Coastguard Worker            res_gpu = torch.nn.functional.ctc_loss(log_probs.cuda(), targets.cuda(), input_lengths, target_lengths,
2683*da0073e9SAndroid Build Coastguard Worker                                                   reduction='sum', zero_infinity=True)
2684*da0073e9SAndroid Build Coastguard Worker            grad_gpu, = torch.autograd.grad(res_gpu, log_probs, grad_out.cuda())
2685*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res_cpu, res_gpu, atol=1e-4, rtol=0)
2686*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(grad_cpu, grad_gpu, atol=1e-4, rtol=0)
2687*da0073e9SAndroid Build Coastguard Worker
2688*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
2689*da0073e9SAndroid Build Coastguard Worker    def test_CTCLoss_critical_target_len(self):
2690*da0073e9SAndroid Build Coastguard Worker        # cuDNN has an unexpected problem with target length 256; see issue #53505
2691*da0073e9SAndroid Build Coastguard Worker        N = 1
2692*da0073e9SAndroid Build Coastguard Worker        S = 256
2693*da0073e9SAndroid Build Coastguard Worker        C = 10
2694*da0073e9SAndroid Build Coastguard Worker        T = 500
2695*da0073e9SAndroid Build Coastguard Worker        target = torch.randint(low=1, high=C, size=(S,), dtype=torch.int)
2696*da0073e9SAndroid Build Coastguard Worker        input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.int)
2697*da0073e9SAndroid Build Coastguard Worker        target_lengths = torch.tensor(S, dtype=torch.int)
2698*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn(T, N, C, dtype=torch.float, device='cuda').log_softmax(2).requires_grad_()
2699*da0073e9SAndroid Build Coastguard Worker        with cudnn.flags(enabled=True):
2700*da0073e9SAndroid Build Coastguard Worker            res_gpu = torch.nn.functional.ctc_loss(inp, target, input_lengths, target_lengths, reduction='none')
2701*da0073e9SAndroid Build Coastguard Worker        res_cpu = torch.nn.functional.ctc_loss(inp.cpu(), target, input_lengths, target_lengths, reduction='none')
2702*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res_cpu, res_gpu, atol=1e-3, rtol=0)
2703*da0073e9SAndroid Build Coastguard Worker
2704*da0073e9SAndroid Build Coastguard Worker    def test_CTCLoss_zero_lengths(self):
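        # Zero-length inputs with zero-length targets should give a zero loss and zero
        # gradients; zero-length inputs with non-empty targets should give an infinite
        # loss while the gradients stay zero.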
2705*da0073e9SAndroid Build Coastguard Worker        devices = ['cpu']
2706*da0073e9SAndroid Build Coastguard Worker        devices += ['cuda'] if TEST_CUDA else []
2707*da0073e9SAndroid Build Coastguard Worker        N = 3
2708*da0073e9SAndroid Build Coastguard Worker        S = 2
2709*da0073e9SAndroid Build Coastguard Worker        C = 200
2710*da0073e9SAndroid Build Coastguard Worker        T = 1
2711*da0073e9SAndroid Build Coastguard Worker        target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.int)
2712*da0073e9SAndroid Build Coastguard Worker        input_lengths = torch.full(size=(N,), fill_value=0, dtype=torch.int)
2713*da0073e9SAndroid Build Coastguard Worker        target_lengths = torch.full(size=(N,), fill_value=0, dtype=torch.int)
2714*da0073e9SAndroid Build Coastguard Worker        for device in devices:
2715*da0073e9SAndroid Build Coastguard Worker            inp = torch.randn(T, N, C, dtype=torch.float, device=device).log_softmax(2).requires_grad_()
2716*da0073e9SAndroid Build Coastguard Worker            res = torch.nn.functional.ctc_loss(inp, target, input_lengths, target_lengths, reduction='none')
2717*da0073e9SAndroid Build Coastguard Worker            self.assertTrue((res == 0).all().item())
2718*da0073e9SAndroid Build Coastguard Worker            res.sum().backward()
2719*da0073e9SAndroid Build Coastguard Worker            self.assertTrue((inp.grad == 0).all().item())
2720*da0073e9SAndroid Build Coastguard Worker        target_lengths = torch.full(size=(N,), fill_value=1, dtype=torch.int)
2721*da0073e9SAndroid Build Coastguard Worker        for device in devices:
2722*da0073e9SAndroid Build Coastguard Worker            inp = torch.randn(T, N, C, dtype=torch.float, device=device).log_softmax(2).requires_grad_()
2723*da0073e9SAndroid Build Coastguard Worker            res = torch.nn.functional.ctc_loss(inp, target, input_lengths, target_lengths, reduction='none')
2724*da0073e9SAndroid Build Coastguard Worker            self.assertTrue((res == torch.inf).all().item())
2725*da0073e9SAndroid Build Coastguard Worker            res.sum().backward()
2726*da0073e9SAndroid Build Coastguard Worker            self.assertTrue((inp.grad == 0).all().item())
2727*da0073e9SAndroid Build Coastguard Worker
2728*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
2729*da0073e9SAndroid Build Coastguard Worker    def test_CTCLoss_zero_infinity(self):
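        # The first sample's target (length 60) is longer than its input (length 50),
        # so its loss would be infinite; zero_infinity=True zeroes such losses and their
        # gradients. CUDA and CPU results and gradients should agree and be NaN-free.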
2730*da0073e9SAndroid Build Coastguard Worker        target_lengths = [60, 25, 20]
2731*da0073e9SAndroid Build Coastguard Worker        input_lengths = [50, 50, 50]
2732*da0073e9SAndroid Build Coastguard Worker        targets = torch.randint(1, 15, (sum(target_lengths),), dtype=torch.int, device='cuda')
2733*da0073e9SAndroid Build Coastguard Worker        log_probs = torch.randn(50, 3, 15, dtype=torch.float, device='cuda').log_softmax(2).requires_grad_()
2734*da0073e9SAndroid Build Coastguard Worker        res = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths,
2735*da0073e9SAndroid Build Coastguard Worker                                           reduction='sum', zero_infinity=True)
2736*da0073e9SAndroid Build Coastguard Worker        with torch.backends.cudnn.flags(enabled=False):
2737*da0073e9SAndroid Build Coastguard Worker            res2 = torch.nn.functional.ctc_loss(log_probs, targets.cuda().long(), input_lengths, target_lengths,
2738*da0073e9SAndroid Build Coastguard Worker                                                reduction='sum', zero_infinity=True)
2739*da0073e9SAndroid Build Coastguard Worker        res_cpu = torch.nn.functional.ctc_loss(log_probs.cpu(), targets.cpu(), input_lengths, target_lengths,
2740*da0073e9SAndroid Build Coastguard Worker                                               reduction='sum', zero_infinity=True)
2741*da0073e9SAndroid Build Coastguard Worker
2742*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res2, res, atol=1e-4, rtol=0)
2743*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res_cpu, res.cpu(), atol=1e-4, rtol=0)
2744*da0073e9SAndroid Build Coastguard Worker        g1, = torch.autograd.grad(res, log_probs)
2745*da0073e9SAndroid Build Coastguard Worker        g2, = torch.autograd.grad(res2, log_probs)
2746*da0073e9SAndroid Build Coastguard Worker        g3, = torch.autograd.grad(res_cpu, log_probs)
2747*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(g2, g3, atol=1e-4, rtol=0)
2748*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(g1, g2, atol=1e-4, rtol=0)
2749*da0073e9SAndroid Build Coastguard Worker        self.assertTrue((g1 == g1).all().item())  # check that we don't have NaN
2750*da0073e9SAndroid Build Coastguard Worker
2751*da0073e9SAndroid Build Coastguard Worker    def test_RNN_cell_no_broadcasting(self):
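        # The cells must raise on mismatched input/hidden shapes instead of silently
        # broadcasting them.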
2752*da0073e9SAndroid Build Coastguard Worker        def test(cell_module, input, hx, input_size, hidden_size):
2753*da0073e9SAndroid Build Coastguard Worker            cell = cell_module(input_size, hidden_size)
2754*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: cell(input, hx))
2755*da0073e9SAndroid Build Coastguard Worker
2756*da0073e9SAndroid Build Coastguard Worker        def test_all(hidden_size, bad_hx, good_hx, input_size, input):
2757*da0073e9SAndroid Build Coastguard Worker            test(nn.RNNCell, input, bad_hx, input_size, hidden_size)
2758*da0073e9SAndroid Build Coastguard Worker            test(nn.GRUCell, input, bad_hx, input_size, hidden_size)
2759*da0073e9SAndroid Build Coastguard Worker            test(nn.LSTMCell, input, (bad_hx, good_hx), input_size, hidden_size)
2760*da0073e9SAndroid Build Coastguard Worker            test(nn.LSTMCell, input, (good_hx, bad_hx), input_size, hidden_size)
2761*da0073e9SAndroid Build Coastguard Worker
2762*da0073e9SAndroid Build Coastguard Worker        hidden_size = 20
2763*da0073e9SAndroid Build Coastguard Worker        input_size = 10
2764*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3, input_size)
2765*da0073e9SAndroid Build Coastguard Worker        bad_hx = torch.randn(1, hidden_size)
2766*da0073e9SAndroid Build Coastguard Worker        good_hx = torch.randn(3, hidden_size)
2767*da0073e9SAndroid Build Coastguard Worker
2768*da0073e9SAndroid Build Coastguard Worker        # Test that a hidden/input batch size mismatch raises instead of broadcasting
2769*da0073e9SAndroid Build Coastguard Worker        test_all(hidden_size, bad_hx, good_hx, input_size, input)
2770*da0073e9SAndroid Build Coastguard Worker
2771*da0073e9SAndroid Build Coastguard Worker        # Test that a mismatch between hx's hidden_size and the module's hidden_size raises
2772*da0073e9SAndroid Build Coastguard Worker        bad_hx = torch.randn(3, 1)
2773*da0073e9SAndroid Build Coastguard Worker        test_all(hidden_size, bad_hx, good_hx, input_size, input)
2774*da0073e9SAndroid Build Coastguard Worker
2775*da0073e9SAndroid Build Coastguard Worker        # Test that a mismatch between the input's feature size and the module's input_size raises
2776*da0073e9SAndroid Build Coastguard Worker        bad_input = torch.randn(3, 1)
2777*da0073e9SAndroid Build Coastguard Worker        test_all(hidden_size, good_hx, good_hx, input_size, bad_input)
2778*da0073e9SAndroid Build Coastguard Worker
2779*da0073e9SAndroid Build Coastguard Worker    def test_LSTM_cell(self):
2780*da0073e9SAndroid Build Coastguard Worker        # this is just a smoke test; these modules are implemented through
2781*da0073e9SAndroid Build Coastguard Worker        # autograd so no Jacobian test is needed
2782*da0073e9SAndroid Build Coastguard Worker        for bias in (True, False):
2783*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(3, 10)
2784*da0073e9SAndroid Build Coastguard Worker            hx = torch.randn(3, 20)
2785*da0073e9SAndroid Build Coastguard Worker            cx = torch.randn(3, 20)
2786*da0073e9SAndroid Build Coastguard Worker            lstm = nn.LSTMCell(10, 20, bias=bias)
2787*da0073e9SAndroid Build Coastguard Worker            for _ in range(6):
2788*da0073e9SAndroid Build Coastguard Worker                hx, cx = lstm(input, (hx, cx))
2789*da0073e9SAndroid Build Coastguard Worker
2790*da0073e9SAndroid Build Coastguard Worker            (hx + cx).sum().backward()
2791*da0073e9SAndroid Build Coastguard Worker
2792*da0073e9SAndroid Build Coastguard Worker    def test_LSTM_cell_forward_input_size(self):
2793*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3, 11)
2794*da0073e9SAndroid Build Coastguard Worker        hx = torch.randn(3, 20)
2795*da0073e9SAndroid Build Coastguard Worker        cx = torch.randn(3, 20)
2796*da0073e9SAndroid Build Coastguard Worker        lstm = nn.LSTMCell(10, 20)
2797*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(Exception, lambda: lstm(input, (hx, cx)))
2798*da0073e9SAndroid Build Coastguard Worker
2799*da0073e9SAndroid Build Coastguard Worker    def test_LSTM_cell_forward_hidden_size(self):
2800*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3, 10)
2801*da0073e9SAndroid Build Coastguard Worker        hx = torch.randn(3, 21)
2802*da0073e9SAndroid Build Coastguard Worker        cx = torch.randn(3, 20)
2803*da0073e9SAndroid Build Coastguard Worker        lstm = nn.LSTMCell(10, 20)
2804*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(Exception, lambda: lstm(input, (hx, cx)))
2805*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(Exception, lambda: lstm(input, (cx, hx)))
2806*da0073e9SAndroid Build Coastguard Worker
2807*da0073e9SAndroid Build Coastguard Worker
2808*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
2809*da0073e9SAndroid Build Coastguard Worker    def test_pack_sequence_batch_sizes_throw(self):
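        # batch_sizes of a PackedSequence must live on the CPU; constructing one with a
        # CUDA batch_sizes tensor should raise.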
2810*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, r"batch_sizes should always be on CPU"):
2811*da0073e9SAndroid Build Coastguard Worker            m = nn.LSTM(3, 4, bidirectional=True, num_layers=2).to('cuda')
2812*da0073e9SAndroid Build Coastguard Worker            a = torch.rand(5, 3, device='cuda')
2813*da0073e9SAndroid Build Coastguard Worker            b = torch.tensor([1, 1, 1, 1, 1], device='cuda')
2814*da0073e9SAndroid Build Coastguard Worker            input = nn.utils.rnn.PackedSequence(a, b)
2815*da0073e9SAndroid Build Coastguard Worker
2816*da0073e9SAndroid Build Coastguard Worker    def test_Transformer_cell(self):
2817*da0073e9SAndroid Build Coastguard Worker        # this is just a smoke test; these modules are implemented through
2818*da0073e9SAndroid Build Coastguard Worker        # autograd so no Jacobian test is needed
2819*da0073e9SAndroid Build Coastguard Worker        d_model = 512
2820*da0073e9SAndroid Build Coastguard Worker        nhead = 16
2821*da0073e9SAndroid Build Coastguard Worker        num_encoder_layers = 4
2822*da0073e9SAndroid Build Coastguard Worker        num_decoder_layers = 3
2823*da0073e9SAndroid Build Coastguard Worker        dim_feedforward = 256
2824*da0073e9SAndroid Build Coastguard Worker        dropout = 0.3
2825*da0073e9SAndroid Build Coastguard Worker        bsz = 8
2826*da0073e9SAndroid Build Coastguard Worker        seq_length = 35
2827*da0073e9SAndroid Build Coastguard Worker        tgt_length = 15
2828*da0073e9SAndroid Build Coastguard Worker        for batch_first, src_size, tgt_size in zip((True, False),
2829*da0073e9SAndroid Build Coastguard Worker                                                   [(bsz, seq_length, d_model),
2830*da0073e9SAndroid Build Coastguard Worker                                                    (seq_length, bsz, d_model)],
2831*da0073e9SAndroid Build Coastguard Worker                                                   [(bsz, tgt_length, d_model),
2832*da0073e9SAndroid Build Coastguard Worker                                                    (tgt_length, bsz, d_model)]):
2833*da0073e9SAndroid Build Coastguard Worker            transformer = nn.Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers,
2834*da0073e9SAndroid Build Coastguard Worker                                         dim_feedforward, dropout, batch_first=batch_first,
2835*da0073e9SAndroid Build Coastguard Worker                                         dtype=torch.double)
2836*da0073e9SAndroid Build Coastguard Worker            src = torch.randn(src_size, dtype=torch.double)
2837*da0073e9SAndroid Build Coastguard Worker            src_mask = transformer.generate_square_subsequent_mask(seq_length).double()
2838*da0073e9SAndroid Build Coastguard Worker            tgt = torch.randn(tgt_size, dtype=torch.double)
2839*da0073e9SAndroid Build Coastguard Worker            tgt_mask = transformer.generate_square_subsequent_mask(tgt_length).double()
2840*da0073e9SAndroid Build Coastguard Worker            memory_mask = torch.randn(tgt_length, seq_length).double()
2841*da0073e9SAndroid Build Coastguard Worker            src_key_padding_mask = torch.rand(bsz, seq_length) >= 0.5
2842*da0073e9SAndroid Build Coastguard Worker            tgt_key_padding_mask = torch.rand(bsz, tgt_length) >= 0.5
2843*da0073e9SAndroid Build Coastguard Worker            memory_key_padding_mask = torch.rand(bsz, seq_length) >= 0.5
2844*da0073e9SAndroid Build Coastguard Worker
2845*da0073e9SAndroid Build Coastguard Worker            output = transformer(src, tgt,
2846*da0073e9SAndroid Build Coastguard Worker                                 src_mask=src_mask,
2847*da0073e9SAndroid Build Coastguard Worker                                 tgt_mask=tgt_mask,
2848*da0073e9SAndroid Build Coastguard Worker                                 memory_mask=memory_mask,
2849*da0073e9SAndroid Build Coastguard Worker                                 src_key_padding_mask=src_key_padding_mask,
2850*da0073e9SAndroid Build Coastguard Worker                                 tgt_key_padding_mask=tgt_key_padding_mask,
2851*da0073e9SAndroid Build Coastguard Worker                                 memory_key_padding_mask=memory_key_padding_mask)
2852*da0073e9SAndroid Build Coastguard Worker            output.sum().backward()
2853*da0073e9SAndroid Build Coastguard Worker
2854*da0073e9SAndroid Build Coastguard Worker    def test_transformerdecoderlayer(self):
2855*da0073e9SAndroid Build Coastguard Worker        # this is a deterministic test for TransformerDecoderLayer
2856*da0073e9SAndroid Build Coastguard Worker        d_model = 4
2857*da0073e9SAndroid Build Coastguard Worker        nhead = 2
2858*da0073e9SAndroid Build Coastguard Worker        dim_feedforward = 16
2859*da0073e9SAndroid Build Coastguard Worker        dropout = 0.0
2860*da0073e9SAndroid Build Coastguard Worker        bsz = 2
2861*da0073e9SAndroid Build Coastguard Worker        seq_length = 5
2862*da0073e9SAndroid Build Coastguard Worker        tgt_length = 3
2863*da0073e9SAndroid Build Coastguard Worker
2864*da0073e9SAndroid Build Coastguard Worker        for batch_first in (False, True):
2865*da0073e9SAndroid Build Coastguard Worker            def perm_fn(x):
2866*da0073e9SAndroid Build Coastguard Worker                return x.transpose(1, 0) if batch_first else x
2867*da0073e9SAndroid Build Coastguard Worker
2868*da0073e9SAndroid Build Coastguard Worker            model = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
2869*da0073e9SAndroid Build Coastguard Worker                                               batch_first=batch_first)
2870*da0073e9SAndroid Build Coastguard Worker
2871*da0073e9SAndroid Build Coastguard Worker            # set constant weights of the model
2872*da0073e9SAndroid Build Coastguard Worker            for idx, p in enumerate(model.parameters()):
2873*da0073e9SAndroid Build Coastguard Worker                x = p.data
2874*da0073e9SAndroid Build Coastguard Worker                sz = x.view(-1).size(0)
2875*da0073e9SAndroid Build Coastguard Worker                shape = x.shape
2876*da0073e9SAndroid Build Coastguard Worker                x = torch.cos(torch.arange(0, sz).float().view(shape))
2877*da0073e9SAndroid Build Coastguard Worker                p.data.copy_(x)
2878*da0073e9SAndroid Build Coastguard Worker
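            # with these fixed cosine weights, the hard-coded ref_output tensors below
            # are the expected layer outputs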
2879*da0073e9SAndroid Build Coastguard Worker            # deterministic input
2880*da0073e9SAndroid Build Coastguard Worker            decoder_input = torch.tensor([[[20., 30., 40., 50.]]])
2881*da0073e9SAndroid Build Coastguard Worker            memory_input = torch.tensor([[[60., 70., 80., 90.]]])
2882*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input, memory_input)
2883*da0073e9SAndroid Build Coastguard Worker            ref_output = torch.tensor([[[2.314351, 0.094805, -0.671322, 0.101977]]])
2884*da0073e9SAndroid Build Coastguard Worker            result = result.detach().numpy()
2885*da0073e9SAndroid Build Coastguard Worker            ref_output = ref_output.detach().numpy()
2886*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
2887*da0073e9SAndroid Build Coastguard Worker            np.testing.assert_allclose(result, ref_output, atol=1e-5)
2888*da0073e9SAndroid Build Coastguard Worker
2889*da0073e9SAndroid Build Coastguard Worker            # deterministic input
2890*da0073e9SAndroid Build Coastguard Worker            decoder_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]],
2891*da0073e9SAndroid Build Coastguard Worker                                                  [[11., 12., 13., 14.]]]))
2892*da0073e9SAndroid Build Coastguard Worker            memory_input = torch.tensor([[[1., 2., 3., 4.]]])
2893*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input, memory_input)
2894*da0073e9SAndroid Build Coastguard Worker            result = result.detach().numpy()
2895*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[2.422245, 0.051716, -0.606338, -0.024756]],
2896*da0073e9SAndroid Build Coastguard Worker                                               [[2.422245, 0.051716, -0.606338, -0.024756]]]))
2897*da0073e9SAndroid Build Coastguard Worker            ref_output = ref_output.detach().numpy()
2898*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
2899*da0073e9SAndroid Build Coastguard Worker            np.testing.assert_allclose(result, ref_output, atol=1e-5)
2900*da0073e9SAndroid Build Coastguard Worker
2901*da0073e9SAndroid Build Coastguard Worker            # deterministic input
2902*da0073e9SAndroid Build Coastguard Worker            decoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]],
2903*da0073e9SAndroid Build Coastguard Worker                                                  [[5., 6., 7., 8.]]]))
2904*da0073e9SAndroid Build Coastguard Worker            memory_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]],
2905*da0073e9SAndroid Build Coastguard Worker                                                 [[11., 12., 13., 14.]]]))
2906*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input, memory_input)
2907*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[2.343536, 0.085561, -0.654954, 0.074991]],
2908*da0073e9SAndroid Build Coastguard Worker                                               [[2.343536, 0.085561, -0.654954, 0.074991]]]))
2909*da0073e9SAndroid Build Coastguard Worker            result = result.detach().numpy()
2910*da0073e9SAndroid Build Coastguard Worker            ref_output = ref_output.detach().numpy()
2911*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
2912*da0073e9SAndroid Build Coastguard Worker            np.testing.assert_allclose(result, ref_output, atol=1e-5)
2913*da0073e9SAndroid Build Coastguard Worker
2914*da0073e9SAndroid Build Coastguard Worker            # deterministic input
2915*da0073e9SAndroid Build Coastguard Worker            decoder_input = perm_fn(torch.tensor([[[0.4517, 0.6793, 0.5313, 0.0034],
2916*da0073e9SAndroid Build Coastguard Worker                                                   [0.2678, 0.3677, 0.4459, 0.7166]],
2917*da0073e9SAndroid Build Coastguard Worker                                                  [[0.8100, 0.3716, 0.4096, 0.1976],
2918*da0073e9SAndroid Build Coastguard Worker                                                   [0.6958, 0.8844, 0.6081, 0.8315]],
2919*da0073e9SAndroid Build Coastguard Worker                                                  [[0.0494, 0.9343, 0.5955, 0.3830],
2920*da0073e9SAndroid Build Coastguard Worker                                                   [0.5404, 0.3464, 0.9378, 0.6200]]]))
2921*da0073e9SAndroid Build Coastguard Worker            memory_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
2922*da0073e9SAndroid Build Coastguard Worker                                                  [0.5387, 0.1655, 0.3565, 0.0471]],
2923*da0073e9SAndroid Build Coastguard Worker                                                 [[0.8335, 0.2799, 0.5031, 0.2947],
2924*da0073e9SAndroid Build Coastguard Worker                                                  [0.1402, 0.0318, 0.7636, 0.1346]],
2925*da0073e9SAndroid Build Coastguard Worker                                                 [[0.6333, 0.9344, 0.1376, 0.9938],
2926*da0073e9SAndroid Build Coastguard Worker                                                  [0.8924, 0.2872, 0.6692, 0.2944]],
2927*da0073e9SAndroid Build Coastguard Worker                                                 [[0.9897, 0.6915, 0.3154, 0.1733],
2928*da0073e9SAndroid Build Coastguard Worker                                                  [0.8645, 0.3513, 0.3064, 0.0767]],
2929*da0073e9SAndroid Build Coastguard Worker                                                 [[0.8117, 0.2366, 0.4838, 0.7881],
2930*da0073e9SAndroid Build Coastguard Worker                                                  [0.3718, 0.4945, 0.9511, 0.0864]]]))
2931*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input, memory_input)
2932*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[2.430065, 0.027862, -0.601136, -0.073096],
2933*da0073e9SAndroid Build Coastguard Worker                                                [2.431935, 0.028907, -0.599809, -0.072488]],
2934*da0073e9SAndroid Build Coastguard Worker                                               [[2.428457, 0.027053, -0.602275, -0.073462],
2935*da0073e9SAndroid Build Coastguard Worker                                                [2.431970, 0.029387, -0.599789, -0.071621]],
2936*da0073e9SAndroid Build Coastguard Worker                                               [[2.431934, 0.028196, -0.599802, -0.073809],
2937*da0073e9SAndroid Build Coastguard Worker                                                [2.432306, 0.028858, -0.599542, -0.072846]]]))
2938*da0073e9SAndroid Build Coastguard Worker            result = result.detach().numpy()
2939*da0073e9SAndroid Build Coastguard Worker            ref_output = ref_output.detach().numpy()
2940*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
2941*da0073e9SAndroid Build Coastguard Worker            np.testing.assert_allclose(result, ref_output, atol=1e-5)
2942*da0073e9SAndroid Build Coastguard Worker
2943*da0073e9SAndroid Build Coastguard Worker            # key_padding_mask
2944*da0073e9SAndroid Build Coastguard Worker            key_padding_mask = torch.zeros(2, 3) == 1
2945*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input, memory_input, tgt_key_padding_mask=key_padding_mask)
2946*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[2.430065, 0.027862, -0.601136, -0.073096],
2947*da0073e9SAndroid Build Coastguard Worker                                                [2.431935, 0.028907, -0.599809, -0.072488]],
2948*da0073e9SAndroid Build Coastguard Worker                                               [[2.428457, 0.027053, -0.602275, -0.073462],
2949*da0073e9SAndroid Build Coastguard Worker                                                [2.431970, 0.029387, -0.599789, -0.071621]],
2950*da0073e9SAndroid Build Coastguard Worker                                               [[2.431934, 0.028196, -0.599802, -0.073809],
2951*da0073e9SAndroid Build Coastguard Worker                                                [2.432306, 0.028858, -0.599542, -0.072846]]]))
2952*da0073e9SAndroid Build Coastguard Worker            result = result.detach().numpy()
2953*da0073e9SAndroid Build Coastguard Worker            ref_output = ref_output.detach().numpy()
2954*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
2955*da0073e9SAndroid Build Coastguard Worker            np.testing.assert_allclose(result, ref_output, atol=1e-5)
2956*da0073e9SAndroid Build Coastguard Worker
2957*da0073e9SAndroid Build Coastguard Worker            # key_padding_mask
2958*da0073e9SAndroid Build Coastguard Worker            key_padding_mask[0, 2] = 1
2959*da0073e9SAndroid Build Coastguard Worker            key_padding_mask[1, 1] = 1
2960*da0073e9SAndroid Build Coastguard Worker            key_padding_mask[1, 2] = 1
2961*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input, memory_input, tgt_key_padding_mask=key_padding_mask)
2962*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[2.430025, 0.027643, -0.601164, -0.073476],
2963*da0073e9SAndroid Build Coastguard Worker                                                [2.4323, 0.029375, -0.599553, -0.071881]],
2964*da0073e9SAndroid Build Coastguard Worker                                               [[2.428523, 0.026838, -0.602226, -0.07391],
2965*da0073e9SAndroid Build Coastguard Worker                                                [2.432634, 0.029842, -0.599318, -0.071253]],
2966*da0073e9SAndroid Build Coastguard Worker                                               [[2.432278, 0.028152, -0.599555, -0.074139],
2967*da0073e9SAndroid Build Coastguard Worker                                                [2.432659, 0.029244, -0.599294, -0.072382]]]))
2968*da0073e9SAndroid Build Coastguard Worker            result = result.detach().numpy()
2969*da0073e9SAndroid Build Coastguard Worker            ref_output = ref_output.detach().numpy()
2970*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
2971*da0073e9SAndroid Build Coastguard Worker            np.testing.assert_allclose(result, ref_output, atol=1e-5)
2972*da0073e9SAndroid Build Coastguard Worker
2973*da0073e9SAndroid Build Coastguard Worker            # memory_key_padding_mask
2974*da0073e9SAndroid Build Coastguard Worker            key_padding_mask = torch.zeros(2, 5) == 1
2975*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input, memory_input, memory_key_padding_mask=key_padding_mask)
2976*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[2.430065, 0.027862, -0.601136, -0.073096],
2977*da0073e9SAndroid Build Coastguard Worker                                                [2.431935, 0.028907, -0.599809, -0.072488]],
2978*da0073e9SAndroid Build Coastguard Worker                                               [[2.428457, 0.027053, -0.602275, -0.073462],
2979*da0073e9SAndroid Build Coastguard Worker                                                [2.431970, 0.029387, -0.599789, -0.071621]],
2980*da0073e9SAndroid Build Coastguard Worker                                               [[2.431934, 0.028196, -0.599802, -0.073809],
2981*da0073e9SAndroid Build Coastguard Worker                                                [2.432306, 0.028858, -0.599542, -0.072846]]]))
2982*da0073e9SAndroid Build Coastguard Worker            result = result.detach().numpy()
2983*da0073e9SAndroid Build Coastguard Worker            ref_output = ref_output.detach().numpy()
2984*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
2985*da0073e9SAndroid Build Coastguard Worker            np.testing.assert_allclose(result, ref_output, atol=1e-5)
2986*da0073e9SAndroid Build Coastguard Worker
2987*da0073e9SAndroid Build Coastguard Worker            # memory_key_padding_mask
2988*da0073e9SAndroid Build Coastguard Worker            key_padding_mask[0, 4] = 1
2989*da0073e9SAndroid Build Coastguard Worker            key_padding_mask[1, 3] = 1
2990*da0073e9SAndroid Build Coastguard Worker            key_padding_mask[1, 4] = 1
2991*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input, memory_input, memory_key_padding_mask=key_padding_mask)
2992*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[2.429757, 0.027358, -0.601351, -0.073816],
2993*da0073e9SAndroid Build Coastguard Worker                                                [2.432692, 0.028583, -0.599263, -0.073634]],
2994*da0073e9SAndroid Build Coastguard Worker                                               [[2.428247, 0.02662, -0.602419, -0.074123],
2995*da0073e9SAndroid Build Coastguard Worker                                                [2.432657, 0.029055, -0.599293, -0.072732]],
2996*da0073e9SAndroid Build Coastguard Worker                                               [[2.431515, 0.027687, -0.600096, -0.074459],
2997*da0073e9SAndroid Build Coastguard Worker                                                [2.433075, 0.028543, -0.598987, -0.073985]]]))
2998*da0073e9SAndroid Build Coastguard Worker            result = result.detach().numpy()
2999*da0073e9SAndroid Build Coastguard Worker            ref_output = ref_output.detach().numpy()
3000*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3001*da0073e9SAndroid Build Coastguard Worker            np.testing.assert_allclose(result, ref_output, atol=1e-5)
3002*da0073e9SAndroid Build Coastguard Worker
3003*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
3004*da0073e9SAndroid Build Coastguard Worker    def test_transformerdecoderlayer_gelu(self):
3005*da0073e9SAndroid Build Coastguard Worker        # this is a deterministic test for TransformerDecoderLayer with gelu activation
3006*da0073e9SAndroid Build Coastguard Worker        d_model = 4
3007*da0073e9SAndroid Build Coastguard Worker        nhead = 2
3008*da0073e9SAndroid Build Coastguard Worker        dim_feedforward = 16
3009*da0073e9SAndroid Build Coastguard Worker        dropout = 0.0
3010*da0073e9SAndroid Build Coastguard Worker        bsz = 2
3011*da0073e9SAndroid Build Coastguard Worker        seq_length = 5
3012*da0073e9SAndroid Build Coastguard Worker        tgt_length = 3
3013*da0073e9SAndroid Build Coastguard Worker
3014*da0073e9SAndroid Build Coastguard Worker        for activation, batch_first in product(('gelu', F.gelu, nn.GELU()), (True, False)):
3015*da0073e9SAndroid Build Coastguard Worker            def perm_fn(x):
3016*da0073e9SAndroid Build Coastguard Worker                return x.transpose(1, 0) if batch_first else x
3017*da0073e9SAndroid Build Coastguard Worker
3018*da0073e9SAndroid Build Coastguard Worker            model = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
3019*da0073e9SAndroid Build Coastguard Worker                                               activation, batch_first=batch_first)
3020*da0073e9SAndroid Build Coastguard Worker
3021*da0073e9SAndroid Build Coastguard Worker            # set constant weights of the model
3022*da0073e9SAndroid Build Coastguard Worker            for idx, p in enumerate(model.parameters()):
3023*da0073e9SAndroid Build Coastguard Worker                x = p.data
3024*da0073e9SAndroid Build Coastguard Worker                sz = x.view(-1).size(0)
3025*da0073e9SAndroid Build Coastguard Worker                shape = x.shape
3026*da0073e9SAndroid Build Coastguard Worker                x = torch.cos(torch.arange(0, sz).float().view(shape))
3027*da0073e9SAndroid Build Coastguard Worker                p.data.copy_(x)
3028*da0073e9SAndroid Build Coastguard Worker
3029*da0073e9SAndroid Build Coastguard Worker            # deterministic input
3030*da0073e9SAndroid Build Coastguard Worker            decoder_input = torch.tensor([[[20., 30., 40., 50.]]])
3031*da0073e9SAndroid Build Coastguard Worker            memory_input = torch.tensor([[[60., 70., 80., 90.]]])
3032*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input, memory_input)
3033*da0073e9SAndroid Build Coastguard Worker            ref_output = torch.tensor([[[2.306435, 0.095946, -0.675796, 0.10687]]])
3034*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0)
3035*da0073e9SAndroid Build Coastguard Worker
3036*da0073e9SAndroid Build Coastguard Worker            # deterministic input
3037*da0073e9SAndroid Build Coastguard Worker            decoder_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]],
3038*da0073e9SAndroid Build Coastguard Worker                                                  [[11., 12., 13., 14.]]]))
3039*da0073e9SAndroid Build Coastguard Worker            memory_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]]]))
3040*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input, memory_input)
3041*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[2.415448, 0.054389, -0.610932, -0.0156613]],
3042*da0073e9SAndroid Build Coastguard Worker                                               [[2.415448, 0.054389, -0.610932, -0.0156613]]]))
3043*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0)
3044*da0073e9SAndroid Build Coastguard Worker
3045*da0073e9SAndroid Build Coastguard Worker            # deterministic input
3046*da0073e9SAndroid Build Coastguard Worker            decoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]],
3047*da0073e9SAndroid Build Coastguard Worker                                                  [[5., 6., 7., 8.]]]))
3048*da0073e9SAndroid Build Coastguard Worker            memory_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]],
3049*da0073e9SAndroid Build Coastguard Worker                                                 [[11., 12., 13., 14.]]]))
3050*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input, memory_input)
3051*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[2.338531, 0.087709, -0.65776, 0.080646]],
3052*da0073e9SAndroid Build Coastguard Worker                                               [[2.338531, 0.087709, -0.65776, 0.080646]]]))
3053*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0)
3054*da0073e9SAndroid Build Coastguard Worker
3055*da0073e9SAndroid Build Coastguard Worker            # deterministic input
3056*da0073e9SAndroid Build Coastguard Worker            decoder_input = perm_fn(torch.tensor([[[0.4517, 0.6793, 0.5313, 0.0034],
3057*da0073e9SAndroid Build Coastguard Worker                                                   [0.2678, 0.3677, 0.4459, 0.7166]],
3058*da0073e9SAndroid Build Coastguard Worker                                                  [[0.8100, 0.3716, 0.4096, 0.1976],
3059*da0073e9SAndroid Build Coastguard Worker                                                   [0.6958, 0.8844, 0.6081, 0.8315]],
3060*da0073e9SAndroid Build Coastguard Worker                                                  [[0.0494, 0.9343, 0.5955, 0.3830],
3061*da0073e9SAndroid Build Coastguard Worker                                                   [0.5404, 0.3464, 0.9378, 0.6200]]]))
3062*da0073e9SAndroid Build Coastguard Worker            memory_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
3063*da0073e9SAndroid Build Coastguard Worker                                                  [0.5387, 0.1655, 0.3565, 0.0471]],
3064*da0073e9SAndroid Build Coastguard Worker                                                 [[0.8335, 0.2799, 0.5031, 0.2947],
3065*da0073e9SAndroid Build Coastguard Worker                                                  [0.1402, 0.0318, 0.7636, 0.1346]],
3066*da0073e9SAndroid Build Coastguard Worker                                                 [[0.6333, 0.9344, 0.1376, 0.9938],
3067*da0073e9SAndroid Build Coastguard Worker                                                  [0.8924, 0.2872, 0.6692, 0.2944]],
3068*da0073e9SAndroid Build Coastguard Worker                                                 [[0.9897, 0.6915, 0.3154, 0.1733],
3069*da0073e9SAndroid Build Coastguard Worker                                                  [0.8645, 0.3513, 0.3064, 0.0767]],
3070*da0073e9SAndroid Build Coastguard Worker                                                 [[0.8117, 0.2366, 0.4838, 0.7881],
3071*da0073e9SAndroid Build Coastguard Worker                                                  [0.3718, 0.4945, 0.9511, 0.0864]]]))
3072*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input, memory_input)
3073*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[2.42049104, 0.03443088, -0.60793706, -0.05436271],
3074*da0073e9SAndroid Build Coastguard Worker                                                [2.42210631, 0.03546578, -0.60679895, -0.05357488]],
3075*da0073e9SAndroid Build Coastguard Worker                                               [[2.41907674, 0.0336104, -0.60892977, -0.05490462],
3076*da0073e9SAndroid Build Coastguard Worker                                                [2.42216881, 0.03586554, -0.6067524, -0.05289126]],
3077*da0073e9SAndroid Build Coastguard Worker                                               [[2.42205716, 0.03488046, -0.60683681, -0.05460596],
3078*da0073e9SAndroid Build Coastguard Worker                                                [2.42240309, 0.0354595, -0.60659063, -0.05378816]]]))
3079*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0)
3080*da0073e9SAndroid Build Coastguard Worker
3081*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm(msg='Large numerical errors')
3082*da0073e9SAndroid Build Coastguard Worker    def test_transformerdecoder(self):
3083*da0073e9SAndroid Build Coastguard Worker        def get_a_test_layer(use_cuda, activation, batch_first=False):
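            # helper: build a TransformerDecoderLayer with fixed cosine-initialized
            # weights so the reference outputs below are deterministic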
3084*da0073e9SAndroid Build Coastguard Worker            d_model = 4
3085*da0073e9SAndroid Build Coastguard Worker            nhead = 2
3086*da0073e9SAndroid Build Coastguard Worker            dim_feedforward = 16
3087*da0073e9SAndroid Build Coastguard Worker            dropout = 0.0
3088*da0073e9SAndroid Build Coastguard Worker            device = torch.device("cuda" if use_cuda else "cpu")
3089*da0073e9SAndroid Build Coastguard Worker
3090*da0073e9SAndroid Build Coastguard Worker            layer = nn.TransformerDecoderLayer(
3091*da0073e9SAndroid Build Coastguard Worker                d_model,
3092*da0073e9SAndroid Build Coastguard Worker                nhead,
3093*da0073e9SAndroid Build Coastguard Worker                dim_feedforward=dim_feedforward,
3094*da0073e9SAndroid Build Coastguard Worker                dropout=dropout,
3095*da0073e9SAndroid Build Coastguard Worker                activation=activation,
3096*da0073e9SAndroid Build Coastguard Worker                batch_first=batch_first).to(device)
3097*da0073e9SAndroid Build Coastguard Worker
3098*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
3099*da0073e9SAndroid Build Coastguard Worker                # set constant weights of the model
3100*da0073e9SAndroid Build Coastguard Worker                for idx, p in enumerate(layer.parameters()):
3101*da0073e9SAndroid Build Coastguard Worker                    x = p.data
3102*da0073e9SAndroid Build Coastguard Worker                    sz = x.view(-1).size(0)
3103*da0073e9SAndroid Build Coastguard Worker                    shape = x.shape
3104*da0073e9SAndroid Build Coastguard Worker                    x = torch.cos(torch.arange(0, sz).float().view(shape))
3105*da0073e9SAndroid Build Coastguard Worker                    p.data.copy_(x)
3106*da0073e9SAndroid Build Coastguard Worker
3107*da0073e9SAndroid Build Coastguard Worker            return layer
3108*da0073e9SAndroid Build Coastguard Worker
3109*da0073e9SAndroid Build Coastguard Worker        # this is a deterministic test for TransformerDecoder
3110*da0073e9SAndroid Build Coastguard Worker        for batch_first in (False, True):
3111*da0073e9SAndroid Build Coastguard Worker            def perm_fn(x):
3112*da0073e9SAndroid Build Coastguard Worker                return x.transpose(1, 0) if batch_first else x
3113*da0073e9SAndroid Build Coastguard Worker            activation = F.relu
3114*da0073e9SAndroid Build Coastguard Worker            use_cuda = torch.cuda.is_available()
3115*da0073e9SAndroid Build Coastguard Worker            device = torch.device("cuda" if use_cuda else "cpu")
3116*da0073e9SAndroid Build Coastguard Worker
3117*da0073e9SAndroid Build Coastguard Worker            decoder_layer = get_a_test_layer(use_cuda=use_cuda, activation=activation,
3118*da0073e9SAndroid Build Coastguard Worker                                             batch_first=batch_first)
3119*da0073e9SAndroid Build Coastguard Worker
3120*da0073e9SAndroid Build Coastguard Worker            model = nn.TransformerDecoder(decoder_layer, 1).to(device)
3121*da0073e9SAndroid Build Coastguard Worker
3122*da0073e9SAndroid Build Coastguard Worker            # deterministic input
3123*da0073e9SAndroid Build Coastguard Worker            decoder_input = torch.tensor([[[20., 30., 40., 50.]]]).to(device)
3124*da0073e9SAndroid Build Coastguard Worker            memory_input = torch.tensor([[[60., 70., 80., 90.]]]).to(device)
3125*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input, memory_input)
3126*da0073e9SAndroid Build Coastguard Worker            ref_output = torch.tensor(
3127*da0073e9SAndroid Build Coastguard Worker                [[[2.314351, 0.094805, -0.671322, 0.101977]]]).to(device)
3128*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3129*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-3)
3130*da0073e9SAndroid Build Coastguard Worker
3131*da0073e9SAndroid Build Coastguard Worker            # deterministic input
3132*da0073e9SAndroid Build Coastguard Worker            decoder_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]],
3133*da0073e9SAndroid Build Coastguard Worker                                                  [[11., 12., 13., 14.]]])).to(device)
3134*da0073e9SAndroid Build Coastguard Worker            memory_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]]])).to(device)
3135*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input, memory_input)
3136*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[2.422245, 0.051716, -0.606338, -0.024756]],
3137*da0073e9SAndroid Build Coastguard Worker                                               [[2.422245, 0.051716, -0.606338, -0.024756]]]
3138*da0073e9SAndroid Build Coastguard Worker                                              )).to(device)
3139*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3140*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-4)
3141*da0073e9SAndroid Build Coastguard Worker
3142*da0073e9SAndroid Build Coastguard Worker            # deterministic input
3143*da0073e9SAndroid Build Coastguard Worker            decoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]],
3144*da0073e9SAndroid Build Coastguard Worker                                                  [[5., 6., 7., 8.]]])).to(device)
3145*da0073e9SAndroid Build Coastguard Worker            memory_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]],
3146*da0073e9SAndroid Build Coastguard Worker                                                 [[11., 12., 13., 14.]]])).to(device)
3147*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input, memory_input)
3148*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[2.343536, 0.085561, -0.654954, 0.074991]],
3149*da0073e9SAndroid Build Coastguard Worker                                               [[2.343536, 0.085561, -0.654954, 0.074991]]]
3150*da0073e9SAndroid Build Coastguard Worker                                              )).to(device)
3151*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3152*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-4)
3153*da0073e9SAndroid Build Coastguard Worker
3154*da0073e9SAndroid Build Coastguard Worker            # deterministic input
3155*da0073e9SAndroid Build Coastguard Worker            decoder_input = perm_fn(torch.tensor([[[0.4517, 0.6793, 0.5313, 0.0034],
3156*da0073e9SAndroid Build Coastguard Worker                                                   [0.2678, 0.3677, 0.4459, 0.7166]],
3157*da0073e9SAndroid Build Coastguard Worker                                                  [[0.8100, 0.3716, 0.4096, 0.1976],
3158*da0073e9SAndroid Build Coastguard Worker                                                   [0.6958, 0.8844, 0.6081, 0.8315]],
3159*da0073e9SAndroid Build Coastguard Worker                                                  [[0.0494, 0.9343, 0.5955, 0.3830],
3160*da0073e9SAndroid Build Coastguard Worker                                                   [0.5404, 0.3464, 0.9378, 0.6200]]]
3161*da0073e9SAndroid Build Coastguard Worker                                                 )).to(device)
3162*da0073e9SAndroid Build Coastguard Worker            memory_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
3163*da0073e9SAndroid Build Coastguard Worker                                                  [0.5387, 0.1655, 0.3565, 0.0471]],
3164*da0073e9SAndroid Build Coastguard Worker                                                 [[0.8335, 0.2799, 0.5031, 0.2947],
3165*da0073e9SAndroid Build Coastguard Worker                                                  [0.1402, 0.0318, 0.7636, 0.1346]],
3166*da0073e9SAndroid Build Coastguard Worker                                                 [[0.6333, 0.9344, 0.1376, 0.9938],
3167*da0073e9SAndroid Build Coastguard Worker                                                  [0.8924, 0.2872, 0.6692, 0.2944]],
3168*da0073e9SAndroid Build Coastguard Worker                                                 [[0.9897, 0.6915, 0.3154, 0.1733],
3169*da0073e9SAndroid Build Coastguard Worker                                                  [0.8645, 0.3513, 0.3064, 0.0767]],
3170*da0073e9SAndroid Build Coastguard Worker                                                 [[0.8117, 0.2366, 0.4838, 0.7881],
3171*da0073e9SAndroid Build Coastguard Worker                                                  [0.3718, 0.4945, 0.9511, 0.0864]]]
3172*da0073e9SAndroid Build Coastguard Worker                                                )).to(device)
3173*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input, memory_input)
3174*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[2.430065, 0.027862, -0.601136, -0.073096],
3175*da0073e9SAndroid Build Coastguard Worker                                                [2.431935, 0.028907, -0.599809, -0.072488]],
3176*da0073e9SAndroid Build Coastguard Worker                                               [[2.428457, 0.027053, -0.602275, -0.073462],
3177*da0073e9SAndroid Build Coastguard Worker                                                [2.431970, 0.029387, -0.599789, -0.071621]],
3178*da0073e9SAndroid Build Coastguard Worker                                               [[2.431934, 0.028196, -0.599802, -0.073809],
3179*da0073e9SAndroid Build Coastguard Worker                                                [2.432306, 0.028858, -0.599542, -0.072846]]]
3180*da0073e9SAndroid Build Coastguard Worker                                              )).to(device)
3181*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3182*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
3183*da0073e9SAndroid Build Coastguard Worker
3184*da0073e9SAndroid Build Coastguard Worker            # key_padding_mask
3185*da0073e9SAndroid Build Coastguard Worker            key_padding_mask = torch.zeros(2, 3).to(device) == 1
3186*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input, memory_input,
3187*da0073e9SAndroid Build Coastguard Worker                           tgt_key_padding_mask=key_padding_mask)
3188*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[2.430065, 0.027862, -0.601136, -0.073096],
3189*da0073e9SAndroid Build Coastguard Worker                                                [2.431935, 0.028907, -0.599809, -0.072488]],
3190*da0073e9SAndroid Build Coastguard Worker                                               [[2.428457, 0.027053, -0.602275, -0.073462],
3191*da0073e9SAndroid Build Coastguard Worker                                                [2.431970, 0.029387, -0.599789, -0.071621]],
3192*da0073e9SAndroid Build Coastguard Worker                                               [[2.431934, 0.028196, -0.599802, -0.073809],
3193*da0073e9SAndroid Build Coastguard Worker                                                [2.432306, 0.028858, -0.599542, -0.072846]]]
3194*da0073e9SAndroid Build Coastguard Worker                                              )).to(device)
3195*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3196*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
3197*da0073e9SAndroid Build Coastguard Worker
3198*da0073e9SAndroid Build Coastguard Worker            # key_padding_mask
3199*da0073e9SAndroid Build Coastguard Worker            key_padding_mask[0, 2] = 1
3200*da0073e9SAndroid Build Coastguard Worker            key_padding_mask[1, 1] = 1
3201*da0073e9SAndroid Build Coastguard Worker            key_padding_mask[1, 2] = 1
3202*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input, memory_input,
3203*da0073e9SAndroid Build Coastguard Worker                           tgt_key_padding_mask=key_padding_mask)
3204*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[2.430025, 0.027643, -0.601164, -0.073476],
3205*da0073e9SAndroid Build Coastguard Worker                                                [2.4323, 0.029375, -0.599553, -0.071881]],
3206*da0073e9SAndroid Build Coastguard Worker                                               [[2.428523, 0.026838, -0.602226, -0.07391],
3207*da0073e9SAndroid Build Coastguard Worker                                                [2.432634, 0.029842, -0.599318, -0.071253]],
3208*da0073e9SAndroid Build Coastguard Worker                                               [[2.432278, 0.028152, -0.599555, -0.074139],
3209*da0073e9SAndroid Build Coastguard Worker                                                [2.432659, 0.029244, -0.599294, -0.072382]]]
3210*da0073e9SAndroid Build Coastguard Worker                                              )).to(device)
3211*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3212*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
3213*da0073e9SAndroid Build Coastguard Worker
3214*da0073e9SAndroid Build Coastguard Worker            # memory_key_padding_mask
3215*da0073e9SAndroid Build Coastguard Worker            key_padding_mask = torch.zeros(2, 5).to(device) == 1
3216*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input, memory_input,
3217*da0073e9SAndroid Build Coastguard Worker                           memory_key_padding_mask=key_padding_mask)
3218*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[2.430065, 0.027862, -0.601136, -0.073096],
3219*da0073e9SAndroid Build Coastguard Worker                                                [2.431935, 0.028907, -0.599809, -0.072488]],
3220*da0073e9SAndroid Build Coastguard Worker                                               [[2.428457, 0.027053, -0.602275, -0.073462],
3221*da0073e9SAndroid Build Coastguard Worker                                                [2.431970, 0.029387, -0.599789, -0.071621]],
3222*da0073e9SAndroid Build Coastguard Worker                                               [[2.431934, 0.028196, -0.599802, -0.073809],
3223*da0073e9SAndroid Build Coastguard Worker                                                [2.432306, 0.028858, -0.599542, -0.072846]]]
3224*da0073e9SAndroid Build Coastguard Worker                                              )).to(device)
3225*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3226*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
3227*da0073e9SAndroid Build Coastguard Worker
3228*da0073e9SAndroid Build Coastguard Worker            # memory_key_padding_mask
3229*da0073e9SAndroid Build Coastguard Worker            key_padding_mask[0, 4] = 1
3230*da0073e9SAndroid Build Coastguard Worker            key_padding_mask[1, 3] = 1
3231*da0073e9SAndroid Build Coastguard Worker            key_padding_mask[1, 4] = 1
3232*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input,
3233*da0073e9SAndroid Build Coastguard Worker                           memory_input,
3234*da0073e9SAndroid Build Coastguard Worker                           memory_key_padding_mask=key_padding_mask)
3235*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[2.429757, 0.027358, -0.601351, -0.073816],
3236*da0073e9SAndroid Build Coastguard Worker                                                [2.432692, 0.028583, -0.599263, -0.073634]],
3237*da0073e9SAndroid Build Coastguard Worker                                               [[2.428247, 0.02662, -0.602419, -0.074123],
3238*da0073e9SAndroid Build Coastguard Worker                                                [2.432657, 0.029055, -0.599293, -0.072732]],
3239*da0073e9SAndroid Build Coastguard Worker                                               [[2.431515, 0.027687, -0.600096, -0.074459],
3240*da0073e9SAndroid Build Coastguard Worker                                                [2.433075, 0.028543, -0.598987, -0.073985]]]
3241*da0073e9SAndroid Build Coastguard Worker                                              )).to(device)
3242*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3243*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
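            # (Editorial note) A key_padding_mask is a boolean (or 0/1) tensor of shape
            # (batch, seq_len); truthy entries mark padded positions that attention
            # ignores. A minimal sketch, assuming a batch of 2 with 3 target steps and
            # lengths (2, 1) -- illustrative only, not executed by this test:
            #   mask = torch.zeros(2, 3, dtype=torch.bool)
            #   mask[0, 2] = True    # sample 0: last step is padding
            #   mask[1, 1:] = True   # sample 1: last two steps are padding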
3244*da0073e9SAndroid Build Coastguard Worker
3245*da0073e9SAndroid Build Coastguard Worker            # multiple layers no norm
3246*da0073e9SAndroid Build Coastguard Worker            model = nn.TransformerDecoder(decoder_layer, 2).to(device)
3247*da0073e9SAndroid Build Coastguard Worker
3248*da0073e9SAndroid Build Coastguard Worker            # deterministic input
3249*da0073e9SAndroid Build Coastguard Worker            decoder_input = torch.tensor([[[20., 30., 40., 50.]]]).to(device)
3250*da0073e9SAndroid Build Coastguard Worker            memory_input = torch.tensor([[[60., 70., 80., 90.]]]).to(device)
3251*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input, memory_input)
3252*da0073e9SAndroid Build Coastguard Worker            ref_output = torch.tensor(
3253*da0073e9SAndroid Build Coastguard Worker                [[[2.31316, 0.0950293, -0.671995, 0.102802]]]).to(device)
3254*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3255*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-3)
3256*da0073e9SAndroid Build Coastguard Worker
3257*da0073e9SAndroid Build Coastguard Worker            # multiple layers no norm
3258*da0073e9SAndroid Build Coastguard Worker            model = nn.TransformerDecoder(decoder_layer, 6).to(device)
3259*da0073e9SAndroid Build Coastguard Worker
3260*da0073e9SAndroid Build Coastguard Worker            # deterministic input
3261*da0073e9SAndroid Build Coastguard Worker            decoder_input = perm_fn(torch.tensor([[[0.4517, 0.6793, 0.5313, 0.0034],
3262*da0073e9SAndroid Build Coastguard Worker                                                   [0.2678, 0.3677, 0.4459, 0.7166]],
3263*da0073e9SAndroid Build Coastguard Worker                                                  [[0.8100, 0.3716, 0.4096, 0.1976],
3264*da0073e9SAndroid Build Coastguard Worker                                                   [0.6958, 0.8844, 0.6081, 0.8315]],
3265*da0073e9SAndroid Build Coastguard Worker                                                  [[0.0494, 0.9343, 0.5955, 0.3830],
3266*da0073e9SAndroid Build Coastguard Worker                                                   [0.5404, 0.3464, 0.9378, 0.6200]]]
3267*da0073e9SAndroid Build Coastguard Worker                                                 )).to(device)
3268*da0073e9SAndroid Build Coastguard Worker            memory_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
3269*da0073e9SAndroid Build Coastguard Worker                                                  [0.5387, 0.1655, 0.3565, 0.0471]],
3270*da0073e9SAndroid Build Coastguard Worker                                                 [[0.8335, 0.2799, 0.5031, 0.2947],
3271*da0073e9SAndroid Build Coastguard Worker                                                  [0.1402, 0.0318, 0.7636, 0.1346]],
3272*da0073e9SAndroid Build Coastguard Worker                                                 [[0.6333, 0.9344, 0.1376, 0.9938],
3273*da0073e9SAndroid Build Coastguard Worker                                                  [0.8924, 0.2872, 0.6692, 0.2944]],
3274*da0073e9SAndroid Build Coastguard Worker                                                 [[0.9897, 0.6915, 0.3154, 0.1733],
3275*da0073e9SAndroid Build Coastguard Worker                                                  [0.8645, 0.3513, 0.3064, 0.0767]],
3276*da0073e9SAndroid Build Coastguard Worker                                                 [[0.8117, 0.2366, 0.4838, 0.7881],
3277*da0073e9SAndroid Build Coastguard Worker                                                  [0.3718, 0.4945, 0.9511, 0.0864]]]
3278*da0073e9SAndroid Build Coastguard Worker                                                )).to(device)
3279*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input, memory_input)
3280*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[2.42794, 0.026164, -0.60263, -0.0747591],
3281*da0073e9SAndroid Build Coastguard Worker                                                [2.43113, 0.0279516, -0.600376, -0.0736896]],
3282*da0073e9SAndroid Build Coastguard Worker                                               [[2.42794, 0.026164, -0.60263, -0.0747591],
3283*da0073e9SAndroid Build Coastguard Worker                                                [2.43113, 0.0279516, -0.600376, -0.0736896]],
3284*da0073e9SAndroid Build Coastguard Worker                                               [[2.42794, 0.026164, -0.60263, -0.0747591],
3285*da0073e9SAndroid Build Coastguard Worker                                                [2.43113, 0.0279516, -0.600376, -0.0736896]]]
3286*da0073e9SAndroid Build Coastguard Worker                                              )).to(device)
3287*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3288*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
3289*da0073e9SAndroid Build Coastguard Worker
3290*da0073e9SAndroid Build Coastguard Worker            # multiple layers with norm
3291*da0073e9SAndroid Build Coastguard Worker            # d_model = 4
3292*da0073e9SAndroid Build Coastguard Worker            norm = nn.LayerNorm(4)
3293*da0073e9SAndroid Build Coastguard Worker            model = nn.TransformerDecoder(decoder_layer, 2, norm=norm).to(device)
3294*da0073e9SAndroid Build Coastguard Worker
3295*da0073e9SAndroid Build Coastguard Worker            # deterministic input
3296*da0073e9SAndroid Build Coastguard Worker            decoder_input = torch.tensor([[[20., 30., 40., 50.]]]).to(device)
3297*da0073e9SAndroid Build Coastguard Worker            memory_input = torch.tensor([[[60., 70., 80., 90.]]]).to(device)
3298*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input, memory_input)
3299*da0073e9SAndroid Build Coastguard Worker            ref_output = torch.tensor(
3300*da0073e9SAndroid Build Coastguard Worker                [[[1.66166, -0.326986, -1.01466, -0.320017]]]).to(device)
3301*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3302*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-3)
3303*da0073e9SAndroid Build Coastguard Worker
3304*da0073e9SAndroid Build Coastguard Worker            # multiple layers with norm
3305*da0073e9SAndroid Build Coastguard Worker            model = nn.TransformerDecoder(decoder_layer, 6, norm=norm).to(device)
3306*da0073e9SAndroid Build Coastguard Worker
3307*da0073e9SAndroid Build Coastguard Worker            # deterministic input
3308*da0073e9SAndroid Build Coastguard Worker            decoder_input = perm_fn(torch.tensor([[[0.4517, 0.6793, 0.5313, 0.0034],
3309*da0073e9SAndroid Build Coastguard Worker                                                   [0.2678, 0.3677, 0.4459, 0.7166]],
3310*da0073e9SAndroid Build Coastguard Worker                                                  [[0.8100, 0.3716, 0.4096, 0.1976],
3311*da0073e9SAndroid Build Coastguard Worker                                                   [0.6958, 0.8844, 0.6081, 0.8315]],
3312*da0073e9SAndroid Build Coastguard Worker                                                  [[0.0494, 0.9343, 0.5955, 0.3830],
3313*da0073e9SAndroid Build Coastguard Worker                                                   [0.5404, 0.3464, 0.9378, 0.6200]]]
3314*da0073e9SAndroid Build Coastguard Worker                                                 )).to(device)
3315*da0073e9SAndroid Build Coastguard Worker            memory_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
3316*da0073e9SAndroid Build Coastguard Worker                                                  [0.5387, 0.1655, 0.3565, 0.0471]],
3317*da0073e9SAndroid Build Coastguard Worker                                                 [[0.8335, 0.2799, 0.5031, 0.2947],
3318*da0073e9SAndroid Build Coastguard Worker                                                  [0.1402, 0.0318, 0.7636, 0.1346]],
3319*da0073e9SAndroid Build Coastguard Worker                                                 [[0.6333, 0.9344, 0.1376, 0.9938],
3320*da0073e9SAndroid Build Coastguard Worker                                                  [0.8924, 0.2872, 0.6692, 0.2944]],
3321*da0073e9SAndroid Build Coastguard Worker                                                 [[0.9897, 0.6915, 0.3154, 0.1733],
3322*da0073e9SAndroid Build Coastguard Worker                                                  [0.8645, 0.3513, 0.3064, 0.0767]],
3323*da0073e9SAndroid Build Coastguard Worker                                                 [[0.8117, 0.2366, 0.4838, 0.7881],
3324*da0073e9SAndroid Build Coastguard Worker                                                  [0.3718, 0.4945, 0.9511, 0.0864]]]
3325*da0073e9SAndroid Build Coastguard Worker                                                )).to(device)
3326*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input, memory_input)
3327*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[1.69559, -0.357291, -0.894741, -0.443553],
3328*da0073e9SAndroid Build Coastguard Worker                                                [1.69571, -0.357363, -0.894154, -0.444196]],
3329*da0073e9SAndroid Build Coastguard Worker                                               [[1.69559, -0.357291, -0.894741, -0.443553],
3330*da0073e9SAndroid Build Coastguard Worker                                                [1.69571, -0.357363, -0.894154, -0.444196]],
3331*da0073e9SAndroid Build Coastguard Worker                                               [[1.69559, -0.357291, -0.894741, -0.443553],
3332*da0073e9SAndroid Build Coastguard Worker                                                [1.69571, -0.357363, -0.894154, -0.444196]]]
3333*da0073e9SAndroid Build Coastguard Worker                                              )).to(device)
3334*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3335*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
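            # (Editorial note) The nn.LayerNorm(4) passed via norm= is applied over the
            # d_model dimension to the output of the final decoder layer, which is why
            # each reference row above sums to roughly zero.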
3336*da0073e9SAndroid Build Coastguard Worker
3337*da0073e9SAndroid Build Coastguard Worker            # gelu activation test cases
3338*da0073e9SAndroid Build Coastguard Worker            activation = "gelu"
3339*da0073e9SAndroid Build Coastguard Worker            use_cuda = torch.cuda.is_available()
3340*da0073e9SAndroid Build Coastguard Worker            device = torch.device("cuda" if use_cuda else "cpu")
3341*da0073e9SAndroid Build Coastguard Worker
3342*da0073e9SAndroid Build Coastguard Worker            decoder_layer = get_a_test_layer(use_cuda=use_cuda, activation=activation,
3343*da0073e9SAndroid Build Coastguard Worker                                             batch_first=batch_first)
3344*da0073e9SAndroid Build Coastguard Worker
3345*da0073e9SAndroid Build Coastguard Worker            model = nn.TransformerDecoder(decoder_layer, 1).to(device)
3346*da0073e9SAndroid Build Coastguard Worker
3347*da0073e9SAndroid Build Coastguard Worker            # deterministic input
3348*da0073e9SAndroid Build Coastguard Worker            decoder_input = torch.tensor([[[20., 30., 40., 50.]]]).to(device)
3349*da0073e9SAndroid Build Coastguard Worker            memory_input = torch.tensor([[[60., 70., 80., 90.]]]).to(device)
3350*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input, memory_input)
3351*da0073e9SAndroid Build Coastguard Worker            ref_output = torch.tensor([[[2.306435, 0.095946, -0.675796, 0.10687]]]).to(device)
3352*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3353*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-3)
3354*da0073e9SAndroid Build Coastguard Worker
3355*da0073e9SAndroid Build Coastguard Worker            # deterministic input
3356*da0073e9SAndroid Build Coastguard Worker            decoder_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]],
3357*da0073e9SAndroid Build Coastguard Worker                                                  [[11., 12., 13., 14.]]])).to(device)
3358*da0073e9SAndroid Build Coastguard Worker            memory_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]]])).to(device)
3359*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input, memory_input)
3360*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[2.415448, 0.054389, -0.610932, -0.0156613]],
3361*da0073e9SAndroid Build Coastguard Worker                                               [[2.415448, 0.054389, -0.610932, -0.0156613]]])).to(device)
3362*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3363*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-4)
3364*da0073e9SAndroid Build Coastguard Worker
3365*da0073e9SAndroid Build Coastguard Worker            # deterministic input
3366*da0073e9SAndroid Build Coastguard Worker            decoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]],
3367*da0073e9SAndroid Build Coastguard Worker                                                  [[5., 6., 7., 8.]]])).to(device)
3368*da0073e9SAndroid Build Coastguard Worker            memory_input = perm_fn(torch.tensor([[[9., 10., 11., 12.]],
3369*da0073e9SAndroid Build Coastguard Worker                                                 [[11., 12., 13., 14.]]])).to(device)
3370*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input, memory_input)
3371*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[2.338531, 0.087709, -0.65776, 0.080646]],
3372*da0073e9SAndroid Build Coastguard Worker                                               [[2.338531, 0.087709, -0.65776, 0.080646]]])).to(device)
3373*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3374*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-4)
3375*da0073e9SAndroid Build Coastguard Worker
3376*da0073e9SAndroid Build Coastguard Worker            # deterministic input
3377*da0073e9SAndroid Build Coastguard Worker            decoder_input = perm_fn(torch.tensor([[[0.4517, 0.6793, 0.5313, 0.0034],
3378*da0073e9SAndroid Build Coastguard Worker                                                   [0.2678, 0.3677, 0.4459, 0.7166]],
3379*da0073e9SAndroid Build Coastguard Worker                                                  [[0.8100, 0.3716, 0.4096, 0.1976],
3380*da0073e9SAndroid Build Coastguard Worker                                                   [0.6958, 0.8844, 0.6081, 0.8315]],
3381*da0073e9SAndroid Build Coastguard Worker                                                  [[0.0494, 0.9343, 0.5955, 0.3830],
3382*da0073e9SAndroid Build Coastguard Worker                                                   [0.5404, 0.3464, 0.9378, 0.6200]]]
3383*da0073e9SAndroid Build Coastguard Worker                                                 )).to(device)
3384*da0073e9SAndroid Build Coastguard Worker            memory_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
3385*da0073e9SAndroid Build Coastguard Worker                                                  [0.5387, 0.1655, 0.3565, 0.0471]],
3386*da0073e9SAndroid Build Coastguard Worker                                                 [[0.8335, 0.2799, 0.5031, 0.2947],
3387*da0073e9SAndroid Build Coastguard Worker                                                  [0.1402, 0.0318, 0.7636, 0.1346]],
3388*da0073e9SAndroid Build Coastguard Worker                                                 [[0.6333, 0.9344, 0.1376, 0.9938],
3389*da0073e9SAndroid Build Coastguard Worker                                                  [0.8924, 0.2872, 0.6692, 0.2944]],
3390*da0073e9SAndroid Build Coastguard Worker                                                 [[0.9897, 0.6915, 0.3154, 0.1733],
3391*da0073e9SAndroid Build Coastguard Worker                                                  [0.8645, 0.3513, 0.3064, 0.0767]],
3392*da0073e9SAndroid Build Coastguard Worker                                                 [[0.8117, 0.2366, 0.4838, 0.7881],
3393*da0073e9SAndroid Build Coastguard Worker                                                  [0.3718, 0.4945, 0.9511, 0.0864]]]
3394*da0073e9SAndroid Build Coastguard Worker                                                )).to(device)
3395*da0073e9SAndroid Build Coastguard Worker            result = model(decoder_input, memory_input)
3396*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[2.42049104, 0.03443088, -0.60793706, -0.05436271],
3397*da0073e9SAndroid Build Coastguard Worker                                                [2.42210631, 0.03546578, -0.60679895, -0.05357488]],
3398*da0073e9SAndroid Build Coastguard Worker                                               [[2.41907674, 0.0336104, -0.60892977, -0.05490462],
3399*da0073e9SAndroid Build Coastguard Worker                                                [2.42216881, 0.03586554, -0.6067524, -0.05289126]],
3400*da0073e9SAndroid Build Coastguard Worker                                               [[2.42205716, 0.03488046, -0.60683681, -0.05460596],
3401*da0073e9SAndroid Build Coastguard Worker                                                [2.42240309, 0.0354595, -0.60659063, -0.05378816]]]
3402*da0073e9SAndroid Build Coastguard Worker                                              )).to(device)
3403*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
3404*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
3405*da0073e9SAndroid Build Coastguard Worker
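    # Editorial sketch (illustrative only, not exercised by the suite): building a
    # gelu-activation decoder layer and a single-layer stack the way the checks
    # above do, with dropout disabled so a CPU forward pass is deterministic.
    @staticmethod
    def _example_gelu_decoder_forward():
        layer = nn.TransformerDecoderLayer(d_model=4, nhead=2, dim_feedforward=16,
                                           dropout=0.0, activation="gelu")
        decoder = nn.TransformerDecoder(layer, num_layers=1)
        tgt = torch.tensor([[[20., 30., 40., 50.]]])     # (tgt_len, batch, d_model)
        memory = torch.tensor([[[60., 70., 80., 90.]]])  # (src_len, batch, d_model)
        return decoder(tgt, memory)                      # shape (1, 1, 4)
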
3406*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not (TEST_CUDNN and TEST_MULTIGPU), 'CUDNN or multi-gpu not available')
3407*da0073e9SAndroid Build Coastguard Worker    def test_cudnn_rnn_dropout_states_device(self):
3408*da0073e9SAndroid Build Coastguard Worker        rnn = nn.RNN(10, 20, num_layers=2, dropout=.5)
3409*da0073e9SAndroid Build Coastguard Worker        device = 1
3410*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(5, 4, 10).cuda(device)
3411*da0073e9SAndroid Build Coastguard Worker        rnn.cuda(device)
3412*da0073e9SAndroid Build Coastguard Worker        hx = torch.randn(2, 4, 20).cuda(device)
3413*da0073e9SAndroid Build Coastguard Worker        output = rnn(input, hx)
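        # (Editorial note) The forward call above is the whole check: the cuDNN dropout
        # state must be created on the same non-default GPU as the module and inputs,
        # so running without a device-mismatch error is the pass condition.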
3414*da0073e9SAndroid Build Coastguard Worker
3415*da0073e9SAndroid Build Coastguard Worker    def test_cudnn_forward_exception(self):
3416*da0073e9SAndroid Build Coastguard Worker        rnns = [
3417*da0073e9SAndroid Build Coastguard Worker            (nn.LSTM(10, 20, batch_first=True), (torch.zeros(1, 2, 19), torch.zeros(1, 2, 19))),
3418*da0073e9SAndroid Build Coastguard Worker            (nn.LSTM(10, 20, batch_first=True, proj_size=10), (torch.zeros(1, 2, 19), torch.zeros(1, 2, 19))),
3419*da0073e9SAndroid Build Coastguard Worker            (nn.GRU(10, 20, batch_first=True), torch.zeros(1, 2, 19)),
3420*da0073e9SAndroid Build Coastguard Worker            (nn.RNN(10, 20, batch_first=True), torch.zeros(1, 2, 19)),
3421*da0073e9SAndroid Build Coastguard Worker        ]
3422*da0073e9SAndroid Build Coastguard Worker        x_wrong = torch.randn(2, 3, 3)
3423*da0073e9SAndroid Build Coastguard Worker        x_right = torch.randn(2, 3, 10)
3424*da0073e9SAndroid Build Coastguard Worker        for rnn, hidden in rnns:
3425*da0073e9SAndroid Build Coastguard Worker            self.assertRaisesRegex(RuntimeError, "Expected hidden.*size.*got", rnn, x_right, hidden)
3426*da0073e9SAndroid Build Coastguard Worker            self.assertRaisesRegex(RuntimeError, re.escape("input.size(-1) must be equal to input_size"), rnn, x_wrong)
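        # (Editorial note) Each module is called once with a well-shaped input but a
        # mis-shaped hidden state and once with a mis-shaped input and no hidden state;
        # both calls must raise RuntimeError with the quoted messages.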
3427*da0073e9SAndroid Build Coastguard Worker
3428*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
3429*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm
3430*da0073e9SAndroid Build Coastguard Worker    def test_cudnn_weight_format(self):
3431*da0073e9SAndroid Build Coastguard Worker        rnns = [
3432*da0073e9SAndroid Build Coastguard Worker            nn.LSTM(10, 20, batch_first=True),
3433*da0073e9SAndroid Build Coastguard Worker            nn.LSTM(10, 20, batch_first=True, proj_size=10),
3434*da0073e9SAndroid Build Coastguard Worker            nn.GRU(10, 20, batch_first=True),
3435*da0073e9SAndroid Build Coastguard Worker            nn.RNN(10, 20, batch_first=True)
3436*da0073e9SAndroid Build Coastguard Worker        ]
3437*da0073e9SAndroid Build Coastguard Worker        first_warn = True
3438*da0073e9SAndroid Build Coastguard Worker        for rnn in rnns:
3439*da0073e9SAndroid Build Coastguard Worker            rnn.cuda()
3440*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(5, 4, 10, requires_grad=True, device="cuda")
3441*da0073e9SAndroid Build Coastguard Worker            hx = torch.randn(1, 5, 20, requires_grad=True, device="cuda")
3442*da0073e9SAndroid Build Coastguard Worker            all_vars = [input, hx] + list(rnn.parameters())
3443*da0073e9SAndroid Build Coastguard Worker            if isinstance(rnn, nn.LSTM):
3444*da0073e9SAndroid Build Coastguard Worker                # LSTM with projections has different hx size
3445*da0073e9SAndroid Build Coastguard Worker                if rnn.proj_size > 0:
3446*da0073e9SAndroid Build Coastguard Worker                    hx = torch.randn(1, 5, 10, requires_grad=True, device="cuda")
3447*da0073e9SAndroid Build Coastguard Worker                    all_vars[1] = hx
3448*da0073e9SAndroid Build Coastguard Worker                cx = torch.randn(1, 5, 20, requires_grad=True, device="cuda")
3449*da0073e9SAndroid Build Coastguard Worker                all_vars[2:2] = [cx]
3450*da0073e9SAndroid Build Coastguard Worker                hx = (hx, cx)
3451*da0073e9SAndroid Build Coastguard Worker
3452*da0073e9SAndroid Build Coastguard Worker            output = rnn(input, hx)
3453*da0073e9SAndroid Build Coastguard Worker            output[0].sum().backward()
3454*da0073e9SAndroid Build Coastguard Worker            grads = [v.grad.data.clone() for v in all_vars]
3455*da0073e9SAndroid Build Coastguard Worker            for v in all_vars:
3456*da0073e9SAndroid Build Coastguard Worker                v.grad.data.zero_()
3457*da0073e9SAndroid Build Coastguard Worker
3458*da0073e9SAndroid Build Coastguard Worker            # After set_(), this weight no longer views the single contiguous chunk of memory cuDNN expects, which should trigger the warning checked below
3459*da0073e9SAndroid Build Coastguard Worker            weight = all_vars[4]
3460*da0073e9SAndroid Build Coastguard Worker            weight_data = weight.data.clone()
3461*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
3462*da0073e9SAndroid Build Coastguard Worker                weight.set_(weight_data)
3463*da0073e9SAndroid Build Coastguard Worker
3464*da0073e9SAndroid Build Coastguard Worker            for _ in range(2):
3465*da0073e9SAndroid Build Coastguard Worker                with warnings.catch_warnings(record=True) as w:
3466*da0073e9SAndroid Build Coastguard Worker                    output_noncontig = rnn(input, hx)
3467*da0073e9SAndroid Build Coastguard Worker                if first_warn:
3468*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(len(w), 1)
3469*da0073e9SAndroid Build Coastguard Worker                    self.assertIn('weights are not part of single contiguous chunk of memory', w[0].message.args[0])
3470*da0073e9SAndroid Build Coastguard Worker                    first_warn = False
3471*da0073e9SAndroid Build Coastguard Worker                    warnings.resetwarnings()
3472*da0073e9SAndroid Build Coastguard Worker                output_noncontig[0].sum().backward()
3473*da0073e9SAndroid Build Coastguard Worker                grads_noncontig = [v.grad.data.clone() for v in all_vars]
3474*da0073e9SAndroid Build Coastguard Worker                for v in all_vars:
3475*da0073e9SAndroid Build Coastguard Worker                    v.grad.data.zero_()
3476*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(output, output_noncontig)
3477*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(grads_noncontig, grads)
3478*da0073e9SAndroid Build Coastguard Worker
3479*da0073e9SAndroid Build Coastguard Worker            # Make sure these still share storage
3480*da0073e9SAndroid Build Coastguard Worker            weight_data[:] = 4
3481*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(weight_data, all_vars[4].data)
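            # (Editorial note, hedged) Calling rnn.flatten_parameters() here would
            # re-pack the weights into a single contiguous buffer and avoid the
            # "not part of single contiguous chunk of memory" warning asserted above.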
3482*da0073e9SAndroid Build Coastguard Worker
3483*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
3484*da0073e9SAndroid Build Coastguard Worker    def test_cudnn_weight_tying(self):
3485*da0073e9SAndroid Build Coastguard Worker        rnns = [
3486*da0073e9SAndroid Build Coastguard Worker            nn.LSTM(10, 20, batch_first=True, bidirectional=True),
3487*da0073e9SAndroid Build Coastguard Worker            nn.LSTM(10, 20, batch_first=True, bidirectional=True, proj_size=10),
3488*da0073e9SAndroid Build Coastguard Worker            nn.GRU(10, 20, batch_first=True, bidirectional=True),
3489*da0073e9SAndroid Build Coastguard Worker            nn.RNN(10, 20, batch_first=True, bidirectional=True)
3490*da0073e9SAndroid Build Coastguard Worker        ]
3491*da0073e9SAndroid Build Coastguard Worker        for rnn in rnns:
3492*da0073e9SAndroid Build Coastguard Worker            rnn.bias_ih_l0_reverse = rnn.bias_ih_l0
3493*da0073e9SAndroid Build Coastguard Worker            rnn.cuda()
3494*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(5, 4, 10, requires_grad=True, device="cuda")
3495*da0073e9SAndroid Build Coastguard Worker            hx = torch.randn(2, 5, 20, requires_grad=True, device="cuda")
3496*da0073e9SAndroid Build Coastguard Worker            all_vars = [input, hx] + list(rnn.parameters())
3497*da0073e9SAndroid Build Coastguard Worker            opt = torch.optim.SGD(rnn.parameters(), lr=0.1)
3498*da0073e9SAndroid Build Coastguard Worker            opt.zero_grad()
3499*da0073e9SAndroid Build Coastguard Worker            if isinstance(rnn, nn.LSTM):
3500*da0073e9SAndroid Build Coastguard Worker                # LSTM with projections has different hx size
3501*da0073e9SAndroid Build Coastguard Worker                if rnn.proj_size > 0:
3502*da0073e9SAndroid Build Coastguard Worker                    hx = torch.randn(2, 5, 10, requires_grad=True, device="cuda")
3503*da0073e9SAndroid Build Coastguard Worker                    all_vars[1] = hx
3504*da0073e9SAndroid Build Coastguard Worker                cx = torch.randn(2, 5, 20, requires_grad=True, device="cuda")
3505*da0073e9SAndroid Build Coastguard Worker                all_vars[2:2] = [cx]
3506*da0073e9SAndroid Build Coastguard Worker                hx = (hx, cx)
3507*da0073e9SAndroid Build Coastguard Worker
3508*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True) as w:
3509*da0073e9SAndroid Build Coastguard Worker                output = rnn(input, hx)
3510*da0073e9SAndroid Build Coastguard Worker            output[0].sum().backward()
3511*da0073e9SAndroid Build Coastguard Worker
3512*da0073e9SAndroid Build Coastguard Worker            opt.step()
3513*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True) as w:
3514*da0073e9SAndroid Build Coastguard Worker                output_cuda = rnn(input, hx)
3515*da0073e9SAndroid Build Coastguard Worker            rnn.cpu()
3516*da0073e9SAndroid Build Coastguard Worker            hx = (hx[0].cpu(), hx[1].cpu()) if isinstance(rnn, nn.LSTM) else hx.cpu()
3517*da0073e9SAndroid Build Coastguard Worker            output_cpu = rnn(input.cpu(), hx)
3518*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output_cuda, output_cpu)
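            # (Editorial note) Tying bias_ih_l0_reverse to bias_ih_l0 aliases two
            # parameters, which typically trips the non-contiguous-weights path; the
            # test only requires CUDA and CPU outputs to agree after an optimizer step.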
3519*da0073e9SAndroid Build Coastguard Worker
3520*da0073e9SAndroid Build Coastguard Worker
3521*da0073e9SAndroid Build Coastguard Worker    def test_transformer_args_check(self):
3522*da0073e9SAndroid Build Coastguard Worker        model_name = 'Transformer'
3523*da0073e9SAndroid Build Coastguard Worker        d_model = 128
3524*da0073e9SAndroid Build Coastguard Worker        nhead = 4
3525*da0073e9SAndroid Build Coastguard Worker        num_encoder_layers = 2
3526*da0073e9SAndroid Build Coastguard Worker        num_decoder_layers = 3
3527*da0073e9SAndroid Build Coastguard Worker        dim_feedforward = 65
3528*da0073e9SAndroid Build Coastguard Worker        dropout = 0.3
3529*da0073e9SAndroid Build Coastguard Worker        bsz = 3
3530*da0073e9SAndroid Build Coastguard Worker        seq_len = 35
3531*da0073e9SAndroid Build Coastguard Worker        tgt_len = 15
3532*da0073e9SAndroid Build Coastguard Worker        activations = [F.relu, F.gelu]
3533*da0073e9SAndroid Build Coastguard Worker
3534*da0073e9SAndroid Build Coastguard Worker        wrong_bsz = 7
3535*da0073e9SAndroid Build Coastguard Worker        wrong_d_model = 63
3536*da0073e9SAndroid Build Coastguard Worker        wrong_nhead = 5
3537*da0073e9SAndroid Build Coastguard Worker        wrong_activation = "abc"
3538*da0073e9SAndroid Build Coastguard Worker
3539*da0073e9SAndroid Build Coastguard Worker        def test(encoder_input_shape, decoder_input_shape,
3540*da0073e9SAndroid Build Coastguard Worker                 src_mask_len=None, tgt_mask_len=None, memory_mask_size=None,
3541*da0073e9SAndroid Build Coastguard Worker                 src_key_padding_mask_size=None, tgt_key_padding_mask_size=None,
3542*da0073e9SAndroid Build Coastguard Worker                 memory_key_padding_mask_size=None,
3543*da0073e9SAndroid Build Coastguard Worker                 src_is_causal=False, tgt_is_causal=False,
3544*da0073e9SAndroid Build Coastguard Worker                 memory_is_causal=False):
3545*da0073e9SAndroid Build Coastguard Worker
3546*da0073e9SAndroid Build Coastguard Worker            encoder_input = torch.randn(encoder_input_shape)
3547*da0073e9SAndroid Build Coastguard Worker            decoder_input = torch.randn(decoder_input_shape)
3548*da0073e9SAndroid Build Coastguard Worker            model = getattr(nn, model_name)(d_model, nhead, num_encoder_layers,
3549*da0073e9SAndroid Build Coastguard Worker                                            num_decoder_layers, dim_feedforward, dropout)
3550*da0073e9SAndroid Build Coastguard Worker
3551*da0073e9SAndroid Build Coastguard Worker            if src_mask_len is not None:
3552*da0073e9SAndroid Build Coastguard Worker                src_mask = model.generate_square_subsequent_mask(src_mask_len)
3553*da0073e9SAndroid Build Coastguard Worker            else:
3554*da0073e9SAndroid Build Coastguard Worker                src_mask = None
3555*da0073e9SAndroid Build Coastguard Worker
3556*da0073e9SAndroid Build Coastguard Worker            if tgt_mask_len is not None:
3557*da0073e9SAndroid Build Coastguard Worker                tgt_mask = model.generate_square_subsequent_mask(tgt_mask_len)
3558*da0073e9SAndroid Build Coastguard Worker            else:
3559*da0073e9SAndroid Build Coastguard Worker                tgt_mask = None
3560*da0073e9SAndroid Build Coastguard Worker
3561*da0073e9SAndroid Build Coastguard Worker            if memory_mask_size is not None:
3562*da0073e9SAndroid Build Coastguard Worker                memory_mask = torch.rand(memory_mask_size)
3563*da0073e9SAndroid Build Coastguard Worker            else:
3564*da0073e9SAndroid Build Coastguard Worker                memory_mask = None
3565*da0073e9SAndroid Build Coastguard Worker
3566*da0073e9SAndroid Build Coastguard Worker            if src_key_padding_mask_size is not None:
3567*da0073e9SAndroid Build Coastguard Worker                src_key_padding_mask = torch.rand(src_key_padding_mask_size) >= 0.5
3568*da0073e9SAndroid Build Coastguard Worker            else:
3569*da0073e9SAndroid Build Coastguard Worker                src_key_padding_mask = None
3570*da0073e9SAndroid Build Coastguard Worker
3571*da0073e9SAndroid Build Coastguard Worker            if tgt_key_padding_mask_size is not None:
3572*da0073e9SAndroid Build Coastguard Worker                tgt_key_padding_mask = torch.rand(tgt_key_padding_mask_size) >= 0.5
3573*da0073e9SAndroid Build Coastguard Worker            else:
3574*da0073e9SAndroid Build Coastguard Worker                tgt_key_padding_mask = None
3575*da0073e9SAndroid Build Coastguard Worker
3576*da0073e9SAndroid Build Coastguard Worker            if memory_key_padding_mask_size is not None:
3577*da0073e9SAndroid Build Coastguard Worker                memory_key_padding_mask = torch.rand(memory_key_padding_mask_size) >= 0.5
3578*da0073e9SAndroid Build Coastguard Worker            else:
3579*da0073e9SAndroid Build Coastguard Worker                memory_key_padding_mask = None
3580*da0073e9SAndroid Build Coastguard Worker
3581*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(RuntimeError):
3582*da0073e9SAndroid Build Coastguard Worker                model(encoder_input, decoder_input,
3583*da0073e9SAndroid Build Coastguard Worker                      src_mask=src_mask,
3584*da0073e9SAndroid Build Coastguard Worker                      tgt_mask=tgt_mask,
3585*da0073e9SAndroid Build Coastguard Worker                      memory_mask=memory_mask,
3586*da0073e9SAndroid Build Coastguard Worker                      src_key_padding_mask=src_key_padding_mask,
3587*da0073e9SAndroid Build Coastguard Worker                      tgt_key_padding_mask=tgt_key_padding_mask,
3588*da0073e9SAndroid Build Coastguard Worker                      memory_key_padding_mask=memory_key_padding_mask,
3589*da0073e9SAndroid Build Coastguard Worker                      src_is_causal=src_is_causal,
3590*da0073e9SAndroid Build Coastguard Worker                      tgt_is_causal=tgt_is_causal,
3591*da0073e9SAndroid Build Coastguard Worker                      memory_is_causal=memory_is_causal)
3592*da0073e9SAndroid Build Coastguard Worker
3593*da0073e9SAndroid Build Coastguard Worker
3594*da0073e9SAndroid Build Coastguard Worker        correct_encoder_input_shape = (seq_len, bsz, d_model)
3595*da0073e9SAndroid Build Coastguard Worker        correct_decoder_input_shape = (tgt_len, bsz, d_model)
3596*da0073e9SAndroid Build Coastguard Worker
3597*da0073e9SAndroid Build Coastguard Worker        def update_shape(shape, dim, new_dim_size):
3598*da0073e9SAndroid Build Coastguard Worker            new_shape = list(shape)
3599*da0073e9SAndroid Build Coastguard Worker            new_shape[dim] = new_dim_size
3600*da0073e9SAndroid Build Coastguard Worker            return tuple(new_shape)
3601*da0073e9SAndroid Build Coastguard Worker
3602*da0073e9SAndroid Build Coastguard Worker        # Incorrect encoder_input batch size
3603*da0073e9SAndroid Build Coastguard Worker        encoder_input_shape = update_shape(correct_encoder_input_shape, 1, wrong_bsz)
3604*da0073e9SAndroid Build Coastguard Worker        decoder_input_shape = correct_decoder_input_shape
3605*da0073e9SAndroid Build Coastguard Worker        test(encoder_input_shape, decoder_input_shape)
3606*da0073e9SAndroid Build Coastguard Worker
3607*da0073e9SAndroid Build Coastguard Worker        # Incorrect decoder_input batch size
3608*da0073e9SAndroid Build Coastguard Worker        encoder_input_shape = correct_encoder_input_shape
3609*da0073e9SAndroid Build Coastguard Worker        decoder_input_shape = update_shape(correct_decoder_input_shape, 1, wrong_bsz)
3610*da0073e9SAndroid Build Coastguard Worker        test(encoder_input_shape, decoder_input_shape)
3611*da0073e9SAndroid Build Coastguard Worker
3612*da0073e9SAndroid Build Coastguard Worker        # Incorrect encoder_input input size
3613*da0073e9SAndroid Build Coastguard Worker        encoder_input_shape = update_shape(correct_encoder_input_shape, 2, wrong_d_model)
3614*da0073e9SAndroid Build Coastguard Worker        decoder_input_shape = correct_decoder_input_shape
3615*da0073e9SAndroid Build Coastguard Worker        test(encoder_input_shape, decoder_input_shape)
3616*da0073e9SAndroid Build Coastguard Worker
3617*da0073e9SAndroid Build Coastguard Worker        # Incorrect decoder_input input size
3618*da0073e9SAndroid Build Coastguard Worker        encoder_input_shape = correct_encoder_input_shape
3619*da0073e9SAndroid Build Coastguard Worker        decoder_input_shape = update_shape(correct_decoder_input_shape, 2, wrong_d_model)
3620*da0073e9SAndroid Build Coastguard Worker        test(encoder_input_shape, decoder_input_shape)
3621*da0073e9SAndroid Build Coastguard Worker
3622*da0073e9SAndroid Build Coastguard Worker        # Incorrect nhead
3623*da0073e9SAndroid Build Coastguard Worker        encoder_input_shape = correct_encoder_input_shape
3624*da0073e9SAndroid Build Coastguard Worker        decoder_input_shape = correct_decoder_input_shape
3625*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(AssertionError):
3626*da0073e9SAndroid Build Coastguard Worker            model = getattr(nn, model_name)(d_model, wrong_nhead, num_encoder_layers,
3627*da0073e9SAndroid Build Coastguard Worker                                            num_decoder_layers, dim_feedforward, dropout)
3628*da0073e9SAndroid Build Coastguard Worker
3629*da0073e9SAndroid Build Coastguard Worker        # Incorrect src_mask
3630*da0073e9SAndroid Build Coastguard Worker        encoder_input_shape = correct_encoder_input_shape
3631*da0073e9SAndroid Build Coastguard Worker        decoder_input_shape = correct_decoder_input_shape
3632*da0073e9SAndroid Build Coastguard Worker        wrong_src_mask_size = seq_len + 1
3633*da0073e9SAndroid Build Coastguard Worker        test(encoder_input_shape, decoder_input_shape, src_mask_len=wrong_src_mask_size)
3634*da0073e9SAndroid Build Coastguard Worker
3635*da0073e9SAndroid Build Coastguard Worker        # Incorrect tgt_mask
3636*da0073e9SAndroid Build Coastguard Worker        encoder_input_shape = correct_encoder_input_shape
3637*da0073e9SAndroid Build Coastguard Worker        decoder_input_shape = correct_decoder_input_shape
3638*da0073e9SAndroid Build Coastguard Worker        wrong_tgt_mask_size = tgt_len + 1
3639*da0073e9SAndroid Build Coastguard Worker        test(encoder_input_shape, decoder_input_shape, tgt_mask_len=wrong_tgt_mask_size)
3640*da0073e9SAndroid Build Coastguard Worker
3641*da0073e9SAndroid Build Coastguard Worker        # Incorrect memory_mask
3642*da0073e9SAndroid Build Coastguard Worker        encoder_input_shape = correct_encoder_input_shape
3643*da0073e9SAndroid Build Coastguard Worker        decoder_input_shape = correct_decoder_input_shape
3644*da0073e9SAndroid Build Coastguard Worker        wrong_tgt_mask_size = tgt_len + 1
3645*da0073e9SAndroid Build Coastguard Worker        test(encoder_input_shape, decoder_input_shape,
3646*da0073e9SAndroid Build Coastguard Worker             memory_mask_size=(wrong_tgt_mask_size, wrong_src_mask_size))
3647*da0073e9SAndroid Build Coastguard Worker
3648*da0073e9SAndroid Build Coastguard Worker        # Incorrect src_key_padding_mask
3649*da0073e9SAndroid Build Coastguard Worker        encoder_input_shape = correct_encoder_input_shape
3650*da0073e9SAndroid Build Coastguard Worker        decoder_input_shape = correct_decoder_input_shape
3651*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(AssertionError):
3652*da0073e9SAndroid Build Coastguard Worker            test(encoder_input_shape, decoder_input_shape,
3653*da0073e9SAndroid Build Coastguard Worker                 src_key_padding_mask_size=(wrong_bsz, wrong_src_mask_size))
3654*da0073e9SAndroid Build Coastguard Worker
3655*da0073e9SAndroid Build Coastguard Worker        # Incorrect tgt_key_padding_mask
3656*da0073e9SAndroid Build Coastguard Worker        encoder_input_shape = correct_encoder_input_shape
3657*da0073e9SAndroid Build Coastguard Worker        decoder_input_shape = correct_decoder_input_shape
3658*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(AssertionError):
3659*da0073e9SAndroid Build Coastguard Worker            test(encoder_input_shape, decoder_input_shape,
3660*da0073e9SAndroid Build Coastguard Worker                 tgt_key_padding_mask_size=(wrong_bsz, wrong_tgt_mask_size))
3661*da0073e9SAndroid Build Coastguard Worker
3662*da0073e9SAndroid Build Coastguard Worker        # Incorrect memory_key_padding_mask
3663*da0073e9SAndroid Build Coastguard Worker        encoder_input_shape = correct_encoder_input_shape
3664*da0073e9SAndroid Build Coastguard Worker        decoder_input_shape = correct_decoder_input_shape
3665*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(AssertionError):
3666*da0073e9SAndroid Build Coastguard Worker            test(encoder_input_shape, decoder_input_shape,
3667*da0073e9SAndroid Build Coastguard Worker                 memory_key_padding_mask_size=(wrong_bsz, wrong_src_mask_size))
3668*da0073e9SAndroid Build Coastguard Worker
3669*da0073e9SAndroid Build Coastguard Worker        # Correct activations
3670*da0073e9SAndroid Build Coastguard Worker        for activation in activations:
3671*da0073e9SAndroid Build Coastguard Worker            model = getattr(nn, model_name)(d_model, nhead, num_encoder_layers, num_decoder_layers,
3672*da0073e9SAndroid Build Coastguard Worker                                            dim_feedforward, dropout, activation)
3673*da0073e9SAndroid Build Coastguard Worker        # Incorrect activation
3674*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
3675*da0073e9SAndroid Build Coastguard Worker            model = getattr(nn, model_name)(d_model, nhead, num_encoder_layers, num_decoder_layers,
3676*da0073e9SAndroid Build Coastguard Worker                                            dim_feedforward, dropout, wrong_activation)
3677*da0073e9SAndroid Build Coastguard Worker
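    # Editorial sketch (illustrative only, not part of the original suite): the happy
    # path that test_transformer_args_check never calls directly, i.e. a Transformer
    # forward whose input shapes and mask sizes are all mutually consistent.
    @staticmethod
    def _example_transformer_happy_path():
        model = nn.Transformer(d_model=16, nhead=4, num_encoder_layers=1,
                               num_decoder_layers=1, dim_feedforward=32, dropout=0.0)
        src = torch.randn(5, 2, 16)  # (src_len, batch, d_model)
        tgt = torch.randn(3, 2, 16)  # (tgt_len, batch, d_model)
        tgt_mask = model.generate_square_subsequent_mask(3)
        return model(src, tgt, tgt_mask=tgt_mask)  # shape (3, 2, 16)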
3678*da0073e9SAndroid Build Coastguard Worker
3679*da0073e9SAndroid Build Coastguard Worker    def test_transformer_layer_args_check(self):
3680*da0073e9SAndroid Build Coastguard Worker        model_names = ['TransformerEncoderLayer', 'TransformerDecoderLayer']
3681*da0073e9SAndroid Build Coastguard Worker        d_model = 128
3682*da0073e9SAndroid Build Coastguard Worker        nhead = 4
3683*da0073e9SAndroid Build Coastguard Worker        dim_feedforward = 65
3684*da0073e9SAndroid Build Coastguard Worker        dropout = 0.3
3685*da0073e9SAndroid Build Coastguard Worker        bsz = 3
3686*da0073e9SAndroid Build Coastguard Worker        seq_len = 35
3687*da0073e9SAndroid Build Coastguard Worker        tgt_len = 15
3688*da0073e9SAndroid Build Coastguard Worker        activations = [F.relu, F.gelu]
3689*da0073e9SAndroid Build Coastguard Worker
3690*da0073e9SAndroid Build Coastguard Worker        wrong_activation = "abc"
3691*da0073e9SAndroid Build Coastguard Worker
3692*da0073e9SAndroid Build Coastguard Worker        encoder_input_shape = (seq_len, bsz, d_model)
3693*da0073e9SAndroid Build Coastguard Worker        decoder_input_shape = (tgt_len, bsz, d_model)
3694*da0073e9SAndroid Build Coastguard Worker
3695*da0073e9SAndroid Build Coastguard Worker        encoder_input = torch.randn(encoder_input_shape)
3696*da0073e9SAndroid Build Coastguard Worker        decoder_input = torch.randn(decoder_input_shape)
3697*da0073e9SAndroid Build Coastguard Worker
3698*da0073e9SAndroid Build Coastguard Worker        for model_name in model_names:
3699*da0073e9SAndroid Build Coastguard Worker            for activation in activations:
3700*da0073e9SAndroid Build Coastguard Worker                model = getattr(nn, model_name)(d_model, nhead, dim_feedforward,
3701*da0073e9SAndroid Build Coastguard Worker                                                dropout, activation)
3702*da0073e9SAndroid Build Coastguard Worker        # Incorrect activation
3703*da0073e9SAndroid Build Coastguard Worker        for model_name in model_names:
3704*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(RuntimeError):
3705*da0073e9SAndroid Build Coastguard Worker                model = getattr(nn, model_name)(d_model, nhead, dim_feedforward,
3706*da0073e9SAndroid Build Coastguard Worker                                                dropout, wrong_activation)
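        # (Editorial note) Both layer classes must accept the relu and gelu callables
        # above; only the unknown activation string "abc" is expected to raise.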
3707*da0073e9SAndroid Build Coastguard Worker
3708*da0073e9SAndroid Build Coastguard Worker    def test_rnn_args_check(self):
3709*da0073e9SAndroid Build Coastguard Worker        input_size = 3
3710*da0073e9SAndroid Build Coastguard Worker        hidden_size = 5
3711*da0073e9SAndroid Build Coastguard Worker        num_layers = 2
3712*da0073e9SAndroid Build Coastguard Worker        batch_size = 4
3713*da0073e9SAndroid Build Coastguard Worker        seq_len = 6
3714*da0073e9SAndroid Build Coastguard Worker        num_directions = 1
3715*da0073e9SAndroid Build Coastguard Worker        bad_size = 7  # prime, so it cannot accidentally match or evenly divide any expected dimension.
3716*da0073e9SAndroid Build Coastguard Worker
3717*da0073e9SAndroid Build Coastguard Worker        def test(input_shape, hidden_shape, mode):
3718*da0073e9SAndroid Build Coastguard Worker            for input, hidden in get_inputs(input_shape, hidden_shape, mode):
3719*da0073e9SAndroid Build Coastguard Worker                model = getattr(nn, mode)(input_size, hidden_size, num_layers)
3720*da0073e9SAndroid Build Coastguard Worker                self.assertRaises(RuntimeError, lambda: model(input, hidden))
3721*da0073e9SAndroid Build Coastguard Worker
3722*da0073e9SAndroid Build Coastguard Worker        correct_input_shape = (seq_len, batch_size, input_size)
3723*da0073e9SAndroid Build Coastguard Worker        correct_hidden_shape = (num_layers * num_directions, batch_size, hidden_size)
3724*da0073e9SAndroid Build Coastguard Worker
3725*da0073e9SAndroid Build Coastguard Worker        def update_shape(shape, dim, new_dim_size):
3726*da0073e9SAndroid Build Coastguard Worker            new_shape = list(shape)
3727*da0073e9SAndroid Build Coastguard Worker            new_shape[dim] = new_dim_size
3728*da0073e9SAndroid Build Coastguard Worker            return tuple(new_shape)
3729*da0073e9SAndroid Build Coastguard Worker
3730*da0073e9SAndroid Build Coastguard Worker        def get_inputs(input_shape, hidden_shape, mode):
3731*da0073e9SAndroid Build Coastguard Worker            '''returns list( tuple(input, hidden) )
3732*da0073e9SAndroid Build Coastguard Worker            where input, hidden are inputs to a model'''
3733*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(input_shape)
3734*da0073e9SAndroid Build Coastguard Worker            hidden = torch.randn(hidden_shape)
3735*da0073e9SAndroid Build Coastguard Worker            if mode != 'LSTM':
3736*da0073e9SAndroid Build Coastguard Worker                return [(input, hidden)]
3737*da0073e9SAndroid Build Coastguard Worker            if hidden_shape == correct_hidden_shape:
3738*da0073e9SAndroid Build Coastguard Worker                return [(input, (hidden, hidden))]
3739*da0073e9SAndroid Build Coastguard Worker            good_hidden = torch.randn(correct_hidden_shape)
3740*da0073e9SAndroid Build Coastguard Worker            return [
3741*da0073e9SAndroid Build Coastguard Worker                (input, (hidden, good_hidden)),
3742*da0073e9SAndroid Build Coastguard Worker                (input, (good_hidden, hidden)),
3743*da0073e9SAndroid Build Coastguard Worker            ]
3744*da0073e9SAndroid Build Coastguard Worker
3745*da0073e9SAndroid Build Coastguard Worker        rnn_modes = ['RNN', 'GRU', 'LSTM']
3746*da0073e9SAndroid Build Coastguard Worker        for mode in rnn_modes:
3747*da0073e9SAndroid Build Coastguard Worker            # Incorrect input batch size
3748*da0073e9SAndroid Build Coastguard Worker            input_shape = update_shape(correct_input_shape, 1, bad_size)
3749*da0073e9SAndroid Build Coastguard Worker            hidden_shape = correct_hidden_shape
3750*da0073e9SAndroid Build Coastguard Worker            test(input_shape, hidden_shape, mode)
3751*da0073e9SAndroid Build Coastguard Worker
3752*da0073e9SAndroid Build Coastguard Worker            # Incorrect hidden batch size
3753*da0073e9SAndroid Build Coastguard Worker            input_shape = correct_input_shape
3754*da0073e9SAndroid Build Coastguard Worker            hidden_shape = update_shape(correct_hidden_shape, 1, bad_size)
3755*da0073e9SAndroid Build Coastguard Worker            test(input_shape, hidden_shape, mode)
3756*da0073e9SAndroid Build Coastguard Worker
3757*da0073e9SAndroid Build Coastguard Worker            # Incorrect input size
3758*da0073e9SAndroid Build Coastguard Worker            input_shape = update_shape(correct_input_shape, 2, bad_size)
3759*da0073e9SAndroid Build Coastguard Worker            hidden_shape = correct_hidden_shape
3760*da0073e9SAndroid Build Coastguard Worker            test(input_shape, hidden_shape, mode)
3761*da0073e9SAndroid Build Coastguard Worker
3762*da0073e9SAndroid Build Coastguard Worker            # Incorrect hidden size
3763*da0073e9SAndroid Build Coastguard Worker            input_shape = correct_input_shape
3764*da0073e9SAndroid Build Coastguard Worker            hidden_shape = update_shape(correct_hidden_shape, 2, bad_size)
3765*da0073e9SAndroid Build Coastguard Worker            test(input_shape, hidden_shape, mode)
3766*da0073e9SAndroid Build Coastguard Worker
3767*da0073e9SAndroid Build Coastguard Worker            # Incorrect hidden dim 0 (num_layers * num_directions)
3768*da0073e9SAndroid Build Coastguard Worker            input_shape = correct_input_shape
3769*da0073e9SAndroid Build Coastguard Worker            hidden_shape = update_shape(correct_hidden_shape, 0, bad_size)
3770*da0073e9SAndroid Build Coastguard Worker            test(input_shape, hidden_shape, mode)
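        # (Editorial note) Exactly one dimension is corrupted per case, so each expected
        # RuntimeError can be attributed to a single cause: input batch, hidden batch,
        # input feature size, hidden size, or the num_layers * num_directions dimension.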
3771*da0073e9SAndroid Build Coastguard Worker
3772*da0073e9SAndroid Build Coastguard Worker    def test_projections_lstm_args_check(self):
3773*da0073e9SAndroid Build Coastguard Worker        input_size = 3
3774*da0073e9SAndroid Build Coastguard Worker        hidden_size = 5
3775*da0073e9SAndroid Build Coastguard Worker        proj_size = 2
3776*da0073e9SAndroid Build Coastguard Worker        num_layers = 2
3777*da0073e9SAndroid Build Coastguard Worker        batch_size = 4
3778*da0073e9SAndroid Build Coastguard Worker        seq_len = 6
3779*da0073e9SAndroid Build Coastguard Worker        num_directions = 1
3780*da0073e9SAndroid Build Coastguard Worker        bad_size = 7  # prime, so it cannot accidentally match or evenly divide any expected dimension.
3781*da0073e9SAndroid Build Coastguard Worker
3782*da0073e9SAndroid Build Coastguard Worker        def test(input_shape, hidden_h_shape, hidden_c_shape):
3783*da0073e9SAndroid Build Coastguard Worker            for input, hidden in get_inputs(input_shape, hidden_h_shape, hidden_c_shape):
3784*da0073e9SAndroid Build Coastguard Worker                model = nn.LSTM(input_size, hidden_size, num_layers, proj_size=proj_size)
3785*da0073e9SAndroid Build Coastguard Worker                self.assertRaises(RuntimeError, lambda: model(input, hidden))
3786*da0073e9SAndroid Build Coastguard Worker
3787*da0073e9SAndroid Build Coastguard Worker        correct_input_shape = (seq_len, batch_size, input_size)
3788*da0073e9SAndroid Build Coastguard Worker        correct_hidden_h_shape = (num_layers * num_directions, batch_size, proj_size)
3789*da0073e9SAndroid Build Coastguard Worker        correct_hidden_c_shape = (num_layers * num_directions, batch_size, hidden_size)
3790*da0073e9SAndroid Build Coastguard Worker
3791*da0073e9SAndroid Build Coastguard Worker        def update_shape(shape, dim, new_dim_size):
3792*da0073e9SAndroid Build Coastguard Worker            new_shape = list(shape)
3793*da0073e9SAndroid Build Coastguard Worker            new_shape[dim] = new_dim_size
3794*da0073e9SAndroid Build Coastguard Worker            return tuple(new_shape)
3795*da0073e9SAndroid Build Coastguard Worker
3796*da0073e9SAndroid Build Coastguard Worker        def get_inputs(input_shape, hidden_h_shape, hidden_c_shape):
3797*da0073e9SAndroid Build Coastguard Worker            '''returns list( tuple(input, hidden) )
3798*da0073e9SAndroid Build Coastguard Worker            where input, hidden are inputs to a model'''
3799*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(input_shape)
3800*da0073e9SAndroid Build Coastguard Worker            hidden_h = torch.randn(hidden_h_shape)
3801*da0073e9SAndroid Build Coastguard Worker            hidden_c = torch.randn(hidden_c_shape)
3802*da0073e9SAndroid Build Coastguard Worker            return [(input, (hidden_h, hidden_c))]
3803*da0073e9SAndroid Build Coastguard Worker
3804*da0073e9SAndroid Build Coastguard Worker        # Incorrect input batch size
3805*da0073e9SAndroid Build Coastguard Worker        input_shape = update_shape(correct_input_shape, 1, bad_size)
3806*da0073e9SAndroid Build Coastguard Worker        test(input_shape, correct_hidden_h_shape, correct_hidden_c_shape)
3807*da0073e9SAndroid Build Coastguard Worker
3808*da0073e9SAndroid Build Coastguard Worker        # Incorrect hidden batch size
3809*da0073e9SAndroid Build Coastguard Worker        input_shape = correct_input_shape
3810*da0073e9SAndroid Build Coastguard Worker        hidden_h_shape = update_shape(correct_hidden_h_shape, 1, bad_size)
3811*da0073e9SAndroid Build Coastguard Worker        hidden_c_shape = update_shape(correct_hidden_c_shape, 1, bad_size)
3812*da0073e9SAndroid Build Coastguard Worker        test(input_shape, hidden_h_shape, hidden_c_shape)
3813*da0073e9SAndroid Build Coastguard Worker
3814*da0073e9SAndroid Build Coastguard Worker        # Incorrect input size
3815*da0073e9SAndroid Build Coastguard Worker        input_shape = update_shape(correct_input_shape, 2, bad_size)
3816*da0073e9SAndroid Build Coastguard Worker        test(input_shape, correct_hidden_h_shape, correct_hidden_c_shape)
3817*da0073e9SAndroid Build Coastguard Worker
3818*da0073e9SAndroid Build Coastguard Worker        # Incorrect hidden size
3819*da0073e9SAndroid Build Coastguard Worker        input_shape = correct_input_shape
3820*da0073e9SAndroid Build Coastguard Worker        hidden_h_shape = update_shape(correct_hidden_h_shape, 2, bad_size)
3821*da0073e9SAndroid Build Coastguard Worker        hidden_c_shape = update_shape(correct_hidden_c_shape, 2, bad_size)
3822*da0073e9SAndroid Build Coastguard Worker        test(input_shape, hidden_h_shape, hidden_c_shape)
3823*da0073e9SAndroid Build Coastguard Worker
3824*da0073e9SAndroid Build Coastguard Worker        # Incorrect hidden[0]
3825*da0073e9SAndroid Build Coastguard Worker        input_shape = correct_input_shape
3826*da0073e9SAndroid Build Coastguard Worker        hidden_h_shape = update_shape(correct_hidden_h_shape, 0, bad_size)
3827*da0073e9SAndroid Build Coastguard Worker        hidden_c_shape = update_shape(correct_hidden_c_shape, 0, bad_size)
3828*da0073e9SAndroid Build Coastguard Worker        test(input_shape, hidden_h_shape, hidden_c_shape)
3829*da0073e9SAndroid Build Coastguard Worker
3830*da0073e9SAndroid Build Coastguard Worker        # Incorrect proj size = hidden size
3831*da0073e9SAndroid Build Coastguard Worker        input_shape = correct_input_shape
3832*da0073e9SAndroid Build Coastguard Worker        hidden_h_shape = update_shape(correct_hidden_h_shape, 0, hidden_size)
3833*da0073e9SAndroid Build Coastguard Worker        hidden_c_shape = correct_hidden_c_shape
3834*da0073e9SAndroid Build Coastguard Worker        test(input_shape, hidden_h_shape, hidden_c_shape)
3835*da0073e9SAndroid Build Coastguard Worker
3836*da0073e9SAndroid Build Coastguard Worker        # Incorrect proj size != hidden size
3837*da0073e9SAndroid Build Coastguard Worker        input_shape = correct_input_shape
3838*da0073e9SAndroid Build Coastguard Worker        hidden_h_shape = update_shape(correct_hidden_h_shape, 0, bad_size)
3839*da0073e9SAndroid Build Coastguard Worker        hidden_c_shape = correct_hidden_c_shape
3840*da0073e9SAndroid Build Coastguard Worker        test(input_shape, hidden_h_shape, hidden_c_shape)
3841*da0073e9SAndroid Build Coastguard Worker
3842*da0073e9SAndroid Build Coastguard Worker        # Incorrect cell size != hidden size
3843*da0073e9SAndroid Build Coastguard Worker        input_shape = correct_input_shape
3844*da0073e9SAndroid Build Coastguard Worker        hidden_h_shape = correct_hidden_h_shape
3845*da0073e9SAndroid Build Coastguard Worker        hidden_c_shape = update_shape(correct_hidden_c_shape, 0, bad_size)
3846*da0073e9SAndroid Build Coastguard Worker        test(input_shape, hidden_h_shape, hidden_c_shape)
3847*da0073e9SAndroid Build Coastguard Worker
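    # The next test exercises the device checks in the RNN forward path: the input,
    # the hidden state, and the module parameters must all live on the same device,
    # otherwise forward raises a RuntimeError. A minimal sketch of the failure mode
    # being asserted (assuming a CUDA device is available):
    #
    #   rnn = nn.GRU(3, 5, 2)               # parameters stay on CPU
    #   x = torch.randn(6, 4, 3).cuda()     # input moved to cuda:0
    #   rnn(x)                              # raises: input and parameters on different devices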
3848*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
3849*da0073e9SAndroid Build Coastguard Worker    def test_rnn_check_device(self):
3850*da0073e9SAndroid Build Coastguard Worker        import copy
3851*da0073e9SAndroid Build Coastguard Worker        input_size = 3
3852*da0073e9SAndroid Build Coastguard Worker        hidden_size = 5
3853*da0073e9SAndroid Build Coastguard Worker        num_layers = 2
3854*da0073e9SAndroid Build Coastguard Worker        batch_size = 4
3855*da0073e9SAndroid Build Coastguard Worker        seq_len = 6
3856*da0073e9SAndroid Build Coastguard Worker        num_directions = 1
3857*da0073e9SAndroid Build Coastguard Worker
3858*da0073e9SAndroid Build Coastguard Worker        correct_input_shape = (seq_len, batch_size, input_size)
3859*da0073e9SAndroid Build Coastguard Worker        correct_hidden_shape = (num_layers * num_directions, batch_size, hidden_size)
3860*da0073e9SAndroid Build Coastguard Worker        rnn_modes = ['RNN', 'GRU', 'LSTM']
3861*da0073e9SAndroid Build Coastguard Worker
3862*da0073e9SAndroid Build Coastguard Worker        for mode in rnn_modes:
3863*da0073e9SAndroid Build Coastguard Worker            model = getattr(nn, mode)(input_size, hidden_size, num_layers)
3864*da0073e9SAndroid Build Coastguard Worker            model_cuda = copy.deepcopy(model).to('cuda:0')
3865*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(correct_input_shape)
3866*da0073e9SAndroid Build Coastguard Worker            hidden = torch.randn(correct_hidden_shape)
3867*da0073e9SAndroid Build Coastguard Worker
3868*da0073e9SAndroid Build Coastguard Worker            # input and weights are not at the same device
3869*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError,
3870*da0073e9SAndroid Build Coastguard Worker                                        "Input and parameter tensors are not at the same device"):
3871*da0073e9SAndroid Build Coastguard Worker                model(input.to('cuda:0'))
3872*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError,
3873*da0073e9SAndroid Build Coastguard Worker                                        "Input and parameter tensors are not at the same device"):
3874*da0073e9SAndroid Build Coastguard Worker                model_cuda(input)
3875*da0073e9SAndroid Build Coastguard Worker
3876*da0073e9SAndroid Build Coastguard Worker            # input and hiddens are not at the same device
3877*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError,
3878*da0073e9SAndroid Build Coastguard Worker                                        r"Input and hidden tensors are not at the same device"):
3879*da0073e9SAndroid Build Coastguard Worker                if mode == 'LSTM':
3880*da0073e9SAndroid Build Coastguard Worker                    model(input, (hidden.to('cuda:0'), hidden.to('cuda:0')))
3881*da0073e9SAndroid Build Coastguard Worker                else:
3882*da0073e9SAndroid Build Coastguard Worker                    model(input, (hidden.to('cuda:0')))
3883*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError,
3884*da0073e9SAndroid Build Coastguard Worker                                        r"Input and hidden tensors are not at the same device"):
3885*da0073e9SAndroid Build Coastguard Worker                if mode == 'LSTM':
3886*da0073e9SAndroid Build Coastguard Worker                    model_cuda(input.to('cuda:0'), (hidden, hidden))
3887*da0073e9SAndroid Build Coastguard Worker                else:
3888*da0073e9SAndroid Build Coastguard Worker                    model_cuda(input.to('cuda:0'), (hidden))
3889*da0073e9SAndroid Build Coastguard Worker
3890*da0073e9SAndroid Build Coastguard Worker            # hidden tensors are not at the same CUDA device
3891*da0073e9SAndroid Build Coastguard Worker            if mode == 'LSTM':
3892*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError,
3893*da0073e9SAndroid Build Coastguard Worker                                            "Input and hidden tensors are not at the same device"):
3894*da0073e9SAndroid Build Coastguard Worker                    model(input.to('cuda:0'), (hidden.to('cuda:0'), hidden.to('cuda:1')))
3895*da0073e9SAndroid Build Coastguard Worker
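    # Same device checks as above, but for an LSTM with projections. Note the
    # asymmetric state shapes set up below: with proj_size, h_0 has shape
    # (num_layers * num_directions, batch, proj_size) while c_0 keeps
    # (num_layers * num_directions, batch, hidden_size).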
3896*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
3897*da0073e9SAndroid Build Coastguard Worker    def test_projections_lstm_check_device(self):
3898*da0073e9SAndroid Build Coastguard Worker        input_size = 3
3899*da0073e9SAndroid Build Coastguard Worker        hidden_size = 5
3900*da0073e9SAndroid Build Coastguard Worker        proj_size = 2
3901*da0073e9SAndroid Build Coastguard Worker        num_layers = 2
3902*da0073e9SAndroid Build Coastguard Worker        batch_size = 4
3903*da0073e9SAndroid Build Coastguard Worker        seq_len = 6
3904*da0073e9SAndroid Build Coastguard Worker        num_directions = 1
3905*da0073e9SAndroid Build Coastguard Worker
3906*da0073e9SAndroid Build Coastguard Worker        correct_input_shape = (seq_len, batch_size, input_size)
3907*da0073e9SAndroid Build Coastguard Worker        correct_hidden_h_shape = (num_layers * num_directions, batch_size, proj_size)
3908*da0073e9SAndroid Build Coastguard Worker        correct_hidden_c_shape = (num_layers * num_directions, batch_size, hidden_size)
3909*da0073e9SAndroid Build Coastguard Worker
3910*da0073e9SAndroid Build Coastguard Worker        model = nn.LSTM(input_size, hidden_size, num_layers, proj_size=proj_size)
3911*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(correct_input_shape)
3912*da0073e9SAndroid Build Coastguard Worker        hidden_h = torch.randn(correct_hidden_h_shape)
3913*da0073e9SAndroid Build Coastguard Worker        hidden_c = torch.randn(correct_hidden_c_shape)
3914*da0073e9SAndroid Build Coastguard Worker
3915*da0073e9SAndroid Build Coastguard Worker        # input and weights are not at the same device
3916*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError,
3917*da0073e9SAndroid Build Coastguard Worker                                    "Input and parameter tensors are not at the same device"):
3918*da0073e9SAndroid Build Coastguard Worker            model(input.to('cuda:0'))
3919*da0073e9SAndroid Build Coastguard Worker
3920*da0073e9SAndroid Build Coastguard Worker        # input and hiddens are not at the same device
3921*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError,
3922*da0073e9SAndroid Build Coastguard Worker                                    r"Input and hidden tensors are not at the same device"):
3923*da0073e9SAndroid Build Coastguard Worker            model(input, (hidden_h.to('cuda:0'), hidden_c.to('cuda:0')))
3924*da0073e9SAndroid Build Coastguard Worker
3925*da0073e9SAndroid Build Coastguard Worker        # hidden tensors are not at the same CUDA device
3926*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError,
3927*da0073e9SAndroid Build Coastguard Worker                                    "Input and hidden tensors are not at the same device"):
3928*da0073e9SAndroid Build Coastguard Worker            model(input.to('cuda:0'), (hidden_h.to('cuda:0'), hidden_c.to('cuda:1')))
3929*da0073e9SAndroid Build Coastguard Worker
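    # The two tests below verify that omitting the initial hidden state is equivalent
    # to passing explicit zero tensors of the expected shapes, for plain RNN/GRU/LSTM
    # and for LSTM with projections respectively.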
3930*da0073e9SAndroid Build Coastguard Worker    def test_rnn_initial_hidden_state(self):
3931*da0073e9SAndroid Build Coastguard Worker        rnn_modes = ['RNN', 'GRU', 'LSTM']
3932*da0073e9SAndroid Build Coastguard Worker        for mode in rnn_modes:
3933*da0073e9SAndroid Build Coastguard Worker            rnn = getattr(nn, mode)(30, 20, 2)
3934*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(10, 32, 30)
3935*da0073e9SAndroid Build Coastguard Worker            hidden = torch.zeros(2, 32, 20)
3936*da0073e9SAndroid Build Coastguard Worker
3937*da0073e9SAndroid Build Coastguard Worker            if mode == 'LSTM':
3938*da0073e9SAndroid Build Coastguard Worker                hidden = (hidden, hidden)
3939*da0073e9SAndroid Build Coastguard Worker            output1, hidden1 = rnn(input, hidden)
3940*da0073e9SAndroid Build Coastguard Worker            output2, hidden2 = rnn(input)
3941*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output1, output2)
3942*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(hidden1, hidden2)
3943*da0073e9SAndroid Build Coastguard Worker
3944*da0073e9SAndroid Build Coastguard Worker    def test_projections_lstm_initial_hidden_state(self):
3945*da0073e9SAndroid Build Coastguard Worker        for bidir in [False, True]:
3946*da0073e9SAndroid Build Coastguard Worker            rnn = nn.LSTM(30, 20, 2, bidirectional=bidir, proj_size=10)
3947*da0073e9SAndroid Build Coastguard Worker            num_dirs = 2 if bidir else 1
3948*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(10, 32, 30)
3949*da0073e9SAndroid Build Coastguard Worker            hidden_h = torch.zeros(2 * num_dirs, 32, 10)
3950*da0073e9SAndroid Build Coastguard Worker            hidden_c = torch.zeros(2 * num_dirs, 32, 20)
3951*da0073e9SAndroid Build Coastguard Worker            hidden = (hidden_h, hidden_c)
3952*da0073e9SAndroid Build Coastguard Worker            output1, hidden1 = rnn(input, hidden)
3953*da0073e9SAndroid Build Coastguard Worker            output2, hidden2 = rnn(input)
3954*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output1, output2)
3955*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(hidden1, hidden2)
3956*da0073e9SAndroid Build Coastguard Worker
3957*da0073e9SAndroid Build Coastguard Worker    def test_projections_errors_on_gru_and_rnn(self):
3958*da0073e9SAndroid Build Coastguard Worker        error_msg = "proj_size argument is only supported for LSTM, not RNN or GRU"
3959*da0073e9SAndroid Build Coastguard Worker        for mode in ['RNN', 'GRU']:
3960*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(ValueError, error_msg):
3961*da0073e9SAndroid Build Coastguard Worker                rnn = getattr(nn, mode)(30, 20, 2, proj_size=10)
3962*da0073e9SAndroid Build Coastguard Worker
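    # _test_RNN_cpu_vs_cudnn runs the same forward and backward pass on CPU and on
    # cuDNN with identical weights and inputs, then compares outputs, final hidden and
    # cell states, and all gradients (input, hidden state, and per-layer weight grads)
    # within atol=5e-5. The sweep covers bias/no-bias, uni/bidirectional, batch_first,
    # non-contiguous tensors, packed variable-length sequences, both RNN
    # nonlinearities, and LSTMs with projections.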
3963*da0073e9SAndroid Build Coastguard Worker    def _test_RNN_cpu_vs_cudnn(self, dropout, dtype=torch.double):
3964*da0073e9SAndroid Build Coastguard Worker
3965*da0073e9SAndroid Build Coastguard Worker        def forward_backward(cuda, rnn, input_val, grad_output, weights_val, hx_val, grad_hy,
3966*da0073e9SAndroid Build Coastguard Worker                             cx_val=None, grad_cy=None):
3967*da0073e9SAndroid Build Coastguard Worker            is_lstm = isinstance(rnn, nn.LSTM)
3968*da0073e9SAndroid Build Coastguard Worker
3969*da0073e9SAndroid Build Coastguard Worker            for x_layer, y_layer in zip(rnn.all_weights, weights_val):
3970*da0073e9SAndroid Build Coastguard Worker                for x, y in zip(x_layer, y_layer):
3971*da0073e9SAndroid Build Coastguard Worker                    x.data.copy_(y.data)
3972*da0073e9SAndroid Build Coastguard Worker
3973*da0073e9SAndroid Build Coastguard Worker            if isinstance(input_val, rnn_utils.PackedSequence):
3974*da0073e9SAndroid Build Coastguard Worker                input = rnn_utils.PackedSequence(
3975*da0073e9SAndroid Build Coastguard Worker                    input_val.data.data.requires_grad_(True), input_val.batch_sizes)
3976*da0073e9SAndroid Build Coastguard Worker                input_var = input.data
3977*da0073e9SAndroid Build Coastguard Worker            else:
3978*da0073e9SAndroid Build Coastguard Worker                input = input_val.clone().requires_grad_(True)
3979*da0073e9SAndroid Build Coastguard Worker                input_var = input
3980*da0073e9SAndroid Build Coastguard Worker            if is_lstm:
3981*da0073e9SAndroid Build Coastguard Worker                if cx_val is None:
3982*da0073e9SAndroid Build Coastguard Worker                    hx = (hx_val.clone().requires_grad_(True),
3983*da0073e9SAndroid Build Coastguard Worker                          hx_val.add(1).requires_grad_(True))
3984*da0073e9SAndroid Build Coastguard Worker                else:
3985*da0073e9SAndroid Build Coastguard Worker                    hx = (hx_val.clone().requires_grad_(True),
3986*da0073e9SAndroid Build Coastguard Worker                          cx_val.add(1).requires_grad_(True))
3987*da0073e9SAndroid Build Coastguard Worker            else:
3988*da0073e9SAndroid Build Coastguard Worker                hx = hx_val.clone().requires_grad_(True)
3989*da0073e9SAndroid Build Coastguard Worker
3990*da0073e9SAndroid Build Coastguard Worker            if cuda:
3991*da0073e9SAndroid Build Coastguard Worker                rnn.cuda()
3992*da0073e9SAndroid Build Coastguard Worker                input_var.data = input_var.data.cuda()
3993*da0073e9SAndroid Build Coastguard Worker                if is_lstm:
3994*da0073e9SAndroid Build Coastguard Worker                    hx[0].data = hx[0].data.cuda()
3995*da0073e9SAndroid Build Coastguard Worker                    hx[1].data = hx[1].data.cuda()
3996*da0073e9SAndroid Build Coastguard Worker                else:
3997*da0073e9SAndroid Build Coastguard Worker                    hx.data = hx.data.cuda()
3998*da0073e9SAndroid Build Coastguard Worker                grad_hy = grad_hy.cuda()
3999*da0073e9SAndroid Build Coastguard Worker                if grad_cy is not None:
4000*da0073e9SAndroid Build Coastguard Worker                    grad_cy = grad_cy.cuda()
4001*da0073e9SAndroid Build Coastguard Worker                grad_output = grad_output.cuda()
4002*da0073e9SAndroid Build Coastguard Worker
4003*da0073e9SAndroid Build Coastguard Worker            output, hy = rnn(input, hx)
4004*da0073e9SAndroid Build Coastguard Worker
4005*da0073e9SAndroid Build Coastguard Worker            if isinstance(output, rnn_utils.PackedSequence):
4006*da0073e9SAndroid Build Coastguard Worker                output = output.data
4007*da0073e9SAndroid Build Coastguard Worker
4008*da0073e9SAndroid Build Coastguard Worker            if is_lstm:
4009*da0073e9SAndroid Build Coastguard Worker                if grad_cy is None:
4010*da0073e9SAndroid Build Coastguard Worker                    torch.autograd.backward([output, hy[0], hy[1]], [grad_output, grad_hy, grad_hy + 1])
4011*da0073e9SAndroid Build Coastguard Worker                else:
4012*da0073e9SAndroid Build Coastguard Worker                    torch.autograd.backward([output, hy[0], hy[1]], [grad_output, grad_hy, grad_cy + 1])
4013*da0073e9SAndroid Build Coastguard Worker            else:
4014*da0073e9SAndroid Build Coastguard Worker                torch.autograd.backward([output, hy], [grad_output, grad_hy])
4015*da0073e9SAndroid Build Coastguard Worker
4016*da0073e9SAndroid Build Coastguard Worker            return {'output': output.data,
4017*da0073e9SAndroid Build Coastguard Worker                    'hy': hy[0].data if is_lstm else hy.data,
4018*da0073e9SAndroid Build Coastguard Worker                    'weights': rnn.all_weights,
4019*da0073e9SAndroid Build Coastguard Worker                    'grad_input': input_var.grad.data,
4020*da0073e9SAndroid Build Coastguard Worker                    'grad_hx': hx[0].grad.data if is_lstm else hx.grad.data,
4021*da0073e9SAndroid Build Coastguard Worker                    'cy': hy[1].data if is_lstm else None,
4022*da0073e9SAndroid Build Coastguard Worker                    'grad_cx': hx[1].grad.data if is_lstm else None}
4023*da0073e9SAndroid Build Coastguard Worker
4024*da0073e9SAndroid Build Coastguard Worker        input_size = 10
4025*da0073e9SAndroid Build Coastguard Worker        hidden_size = 6
4026*da0073e9SAndroid Build Coastguard Worker        proj_size = 3
4027*da0073e9SAndroid Build Coastguard Worker        num_layers = 2
4028*da0073e9SAndroid Build Coastguard Worker        seq_length = 7
4029*da0073e9SAndroid Build Coastguard Worker        batch = 6
4030*da0073e9SAndroid Build Coastguard Worker
4031*da0073e9SAndroid Build Coastguard Worker        def make_noncontig(tensor):
4032*da0073e9SAndroid Build Coastguard Worker            ndim = tensor.dim()
4033*da0073e9SAndroid Build Coastguard Worker            return torch.stack([tensor.clone().zero_(), tensor], ndim).select(ndim, 1)
4034*da0073e9SAndroid Build Coastguard Worker
4035*da0073e9SAndroid Build Coastguard Worker        def compare_cpu_gpu(outputs_cpu, outputs_gpu):
4036*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(list(outputs_cpu.keys()), list(outputs_gpu.keys()))
4037*da0073e9SAndroid Build Coastguard Worker            for key in outputs_cpu.keys():
4038*da0073e9SAndroid Build Coastguard Worker                if key != 'weights':
4039*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(outputs_cpu[key], outputs_gpu[key], atol=5e-5, rtol=0, msg=key)
4040*da0073e9SAndroid Build Coastguard Worker
4041*da0073e9SAndroid Build Coastguard Worker            # check grad weights separately, since all_weights is a nested list of per-layer parameters
4042*da0073e9SAndroid Build Coastguard Worker            for cpu_layer_weight, gpu_layer_weight in zip(outputs_cpu['weights'], outputs_gpu['weights']):
4043*da0073e9SAndroid Build Coastguard Worker                for (cpu_weight, gpu_weight) in zip(cpu_layer_weight, gpu_layer_weight):
4044*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(cpu_weight.grad.data, gpu_weight.grad.data, atol=5e-5, rtol=0)
4045*da0073e9SAndroid Build Coastguard Worker
4046*da0073e9SAndroid Build Coastguard Worker        for module in (nn.RNN, nn.LSTM, nn.GRU):
4047*da0073e9SAndroid Build Coastguard Worker            for bias, bidirectional, batch_first, contig, variable_len, lens_as_tensor \
4048*da0073e9SAndroid Build Coastguard Worker                    in product((True, False), repeat=6):
4049*da0073e9SAndroid Build Coastguard Worker
4050*da0073e9SAndroid Build Coastguard Worker                num_directions = 2 if bidirectional else 1
4051*da0073e9SAndroid Build Coastguard Worker                if batch_first:
4052*da0073e9SAndroid Build Coastguard Worker                    input_val = torch.randn(batch, seq_length, input_size, dtype=dtype)
4053*da0073e9SAndroid Build Coastguard Worker                    grad_output = torch.randn(batch, seq_length, hidden_size * num_directions, dtype=dtype)
4054*da0073e9SAndroid Build Coastguard Worker                else:
4055*da0073e9SAndroid Build Coastguard Worker                    input_val = torch.randn(seq_length, batch, input_size, dtype=dtype)
4056*da0073e9SAndroid Build Coastguard Worker                    grad_output = torch.randn(seq_length, batch, hidden_size * num_directions, dtype=dtype)
4057*da0073e9SAndroid Build Coastguard Worker
4058*da0073e9SAndroid Build Coastguard Worker                hx_val = torch.randn(num_layers * num_directions, batch, hidden_size, dtype=dtype)
4059*da0073e9SAndroid Build Coastguard Worker                grad_hy = torch.randn(num_layers * num_directions, batch, hidden_size, dtype=dtype)
4060*da0073e9SAndroid Build Coastguard Worker
4061*da0073e9SAndroid Build Coastguard Worker                if not contig:
4062*da0073e9SAndroid Build Coastguard Worker                    grad_output = make_noncontig(grad_output)
4063*da0073e9SAndroid Build Coastguard Worker                    grad_hy = make_noncontig(grad_hy)
4064*da0073e9SAndroid Build Coastguard Worker                    input_var = make_noncontig(input_val)
4065*da0073e9SAndroid Build Coastguard Worker                    hx_val = make_noncontig(hx_val)
4066*da0073e9SAndroid Build Coastguard Worker
4067*da0073e9SAndroid Build Coastguard Worker                if variable_len:
4068*da0073e9SAndroid Build Coastguard Worker                    lengths = [7, 5, 5, 2, 1, 1]
4069*da0073e9SAndroid Build Coastguard Worker                    if lens_as_tensor:
4070*da0073e9SAndroid Build Coastguard Worker                        lengths = torch.tensor(lengths, dtype=torch.long)
4071*da0073e9SAndroid Build Coastguard Worker                    input_val = rnn_utils.pack_padded_sequence(input_val, lengths, batch_first=batch_first)
4072*da0073e9SAndroid Build Coastguard Worker                    grad_output = rnn_utils.pack_padded_sequence(grad_output, lengths, batch_first=batch_first).data
4073*da0073e9SAndroid Build Coastguard Worker
4074*da0073e9SAndroid Build Coastguard Worker                rnn = module(input_size,
4075*da0073e9SAndroid Build Coastguard Worker                             hidden_size,
4076*da0073e9SAndroid Build Coastguard Worker                             num_layers,
4077*da0073e9SAndroid Build Coastguard Worker                             bias=bias,
4078*da0073e9SAndroid Build Coastguard Worker                             dropout=dropout,
4079*da0073e9SAndroid Build Coastguard Worker                             bidirectional=bidirectional,
4080*da0073e9SAndroid Build Coastguard Worker                             batch_first=batch_first).to(dtype)
4081*da0073e9SAndroid Build Coastguard Worker
4082*da0073e9SAndroid Build Coastguard Worker                outputs_cpu = forward_backward(
4083*da0073e9SAndroid Build Coastguard Worker                    False, rnn, input_val, grad_output, rnn.all_weights, hx_val, grad_hy)
4084*da0073e9SAndroid Build Coastguard Worker
4085*da0073e9SAndroid Build Coastguard Worker                rnn_gpu = module(input_size,
4086*da0073e9SAndroid Build Coastguard Worker                                 hidden_size,
4087*da0073e9SAndroid Build Coastguard Worker                                 num_layers,
4088*da0073e9SAndroid Build Coastguard Worker                                 bias=bias,
4089*da0073e9SAndroid Build Coastguard Worker                                 dropout=dropout,
4090*da0073e9SAndroid Build Coastguard Worker                                 bidirectional=bidirectional,
4091*da0073e9SAndroid Build Coastguard Worker                                 batch_first=batch_first).to(dtype)
4092*da0073e9SAndroid Build Coastguard Worker
4093*da0073e9SAndroid Build Coastguard Worker                outputs_gpu = forward_backward(
4094*da0073e9SAndroid Build Coastguard Worker                    True, rnn_gpu, input_val, grad_output, rnn.all_weights, hx_val, grad_hy)
4095*da0073e9SAndroid Build Coastguard Worker
4096*da0073e9SAndroid Build Coastguard Worker                compare_cpu_gpu(outputs_cpu, outputs_gpu)
4097*da0073e9SAndroid Build Coastguard Worker
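        # Extra sweep over the elementwise RNN nonlinearities. Note that this block
        # reuses the `bias` and `num_directions` values left over from the last
        # iteration of the product() loop above.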
4098*da0073e9SAndroid Build Coastguard Worker        for nonlinearity in ('tanh', 'relu'):
4099*da0073e9SAndroid Build Coastguard Worker            hx_val = torch.randn(num_layers, batch, hidden_size, dtype=dtype)
4100*da0073e9SAndroid Build Coastguard Worker            input_val = torch.randn(seq_length, batch, input_size, dtype=dtype)
4101*da0073e9SAndroid Build Coastguard Worker            grad_output = torch.randn(
4102*da0073e9SAndroid Build Coastguard Worker                seq_length, batch, hidden_size * num_directions, dtype=dtype)
4103*da0073e9SAndroid Build Coastguard Worker            grad_hy = torch.randn(
4104*da0073e9SAndroid Build Coastguard Worker                num_layers * num_directions, batch, hidden_size, dtype=dtype)
4105*da0073e9SAndroid Build Coastguard Worker
4106*da0073e9SAndroid Build Coastguard Worker            rnn = nn.RNN(input_size, hidden_size, num_layers, bias=bias, nonlinearity=nonlinearity).to(dtype)
4107*da0073e9SAndroid Build Coastguard Worker            outputs_cpu = forward_backward(False, rnn, input_val, grad_output, rnn.all_weights, hx_val, grad_hy)
4108*da0073e9SAndroid Build Coastguard Worker
4109*da0073e9SAndroid Build Coastguard Worker            rnn_gpu = nn.RNN(input_size, hidden_size, num_layers, bias=bias, nonlinearity=nonlinearity).to(dtype)
4110*da0073e9SAndroid Build Coastguard Worker            outputs_gpu = forward_backward(True, rnn_gpu, input_val, grad_output, rnn.all_weights, hx_val, grad_hy)
4111*da0073e9SAndroid Build Coastguard Worker
4112*da0073e9SAndroid Build Coastguard Worker            compare_cpu_gpu(outputs_cpu, outputs_gpu)
4113*da0073e9SAndroid Build Coastguard Worker
4114*da0073e9SAndroid Build Coastguard Worker        # checking LSTM with projections
4115*da0073e9SAndroid Build Coastguard Worker        for bias, bidirectional, batch_first, contig, variable_len, lens_as_tensor \
4116*da0073e9SAndroid Build Coastguard Worker                in product((True, False), repeat=6):
4117*da0073e9SAndroid Build Coastguard Worker            num_directions = 2 if bidirectional else 1
4118*da0073e9SAndroid Build Coastguard Worker            if batch_first:
4119*da0073e9SAndroid Build Coastguard Worker                input_val = torch.randn(batch, seq_length, input_size, dtype=dtype)
4120*da0073e9SAndroid Build Coastguard Worker                grad_output = torch.randn(batch, seq_length, proj_size * num_directions, dtype=dtype)
4121*da0073e9SAndroid Build Coastguard Worker            else:
4122*da0073e9SAndroid Build Coastguard Worker                input_val = torch.randn(seq_length, batch, input_size, dtype=dtype)
4123*da0073e9SAndroid Build Coastguard Worker                grad_output = torch.randn(seq_length, batch, proj_size * num_directions, dtype=dtype)
4124*da0073e9SAndroid Build Coastguard Worker
4125*da0073e9SAndroid Build Coastguard Worker            hx_val = torch.randn(num_layers * num_directions, batch, proj_size, dtype=dtype)
4126*da0073e9SAndroid Build Coastguard Worker            cx_val = torch.randn(num_layers * num_directions, batch, hidden_size, dtype=dtype)
4127*da0073e9SAndroid Build Coastguard Worker            grad_hy = torch.randn(num_layers * num_directions, batch, proj_size, dtype=dtype)
4128*da0073e9SAndroid Build Coastguard Worker            grad_cy = torch.randn(num_layers * num_directions, batch, hidden_size, dtype=dtype)
4129*da0073e9SAndroid Build Coastguard Worker
4130*da0073e9SAndroid Build Coastguard Worker            if not contig:
4131*da0073e9SAndroid Build Coastguard Worker                grad_output = make_noncontig(grad_output)
4132*da0073e9SAndroid Build Coastguard Worker                grad_hy = make_noncontig(grad_hy)
4133*da0073e9SAndroid Build Coastguard Worker                grad_cy = make_noncontig(grad_cy)
4134*da0073e9SAndroid Build Coastguard Worker                input_var = make_noncontig(input_val)
4135*da0073e9SAndroid Build Coastguard Worker                hx_val = make_noncontig(hx_val)
4136*da0073e9SAndroid Build Coastguard Worker                cx_val = make_noncontig(cx_val)
4137*da0073e9SAndroid Build Coastguard Worker
4138*da0073e9SAndroid Build Coastguard Worker            if variable_len:
4139*da0073e9SAndroid Build Coastguard Worker                lengths = [7, 5, 5, 2, 1, 1]
4140*da0073e9SAndroid Build Coastguard Worker                if lens_as_tensor:
4141*da0073e9SAndroid Build Coastguard Worker                    lengths = torch.tensor(lengths, dtype=torch.long)
4142*da0073e9SAndroid Build Coastguard Worker                input_val = rnn_utils.pack_padded_sequence(input_val, lengths, batch_first=batch_first)
4143*da0073e9SAndroid Build Coastguard Worker                grad_output = rnn_utils.pack_padded_sequence(grad_output, lengths, batch_first=batch_first).data
4144*da0073e9SAndroid Build Coastguard Worker
4145*da0073e9SAndroid Build Coastguard Worker            rnn = nn.LSTM(input_size,
4146*da0073e9SAndroid Build Coastguard Worker                          hidden_size,
4147*da0073e9SAndroid Build Coastguard Worker                          num_layers,
4148*da0073e9SAndroid Build Coastguard Worker                          bias=bias,
4149*da0073e9SAndroid Build Coastguard Worker                          dropout=dropout,
4150*da0073e9SAndroid Build Coastguard Worker                          bidirectional=bidirectional,
4151*da0073e9SAndroid Build Coastguard Worker                          batch_first=batch_first,
4152*da0073e9SAndroid Build Coastguard Worker                          proj_size=proj_size).to(dtype)
4153*da0073e9SAndroid Build Coastguard Worker
4154*da0073e9SAndroid Build Coastguard Worker            outputs_cpu = forward_backward(
4155*da0073e9SAndroid Build Coastguard Worker                False, rnn, input_val, grad_output, rnn.all_weights,
4156*da0073e9SAndroid Build Coastguard Worker                hx_val, grad_hy, cx_val, grad_cy)
4157*da0073e9SAndroid Build Coastguard Worker
4158*da0073e9SAndroid Build Coastguard Worker            rnn_gpu = nn.LSTM(input_size,
4159*da0073e9SAndroid Build Coastguard Worker                              hidden_size,
4160*da0073e9SAndroid Build Coastguard Worker                              num_layers,
4161*da0073e9SAndroid Build Coastguard Worker                              bias=bias,
4162*da0073e9SAndroid Build Coastguard Worker                              dropout=dropout,
4163*da0073e9SAndroid Build Coastguard Worker                              bidirectional=bidirectional,
4164*da0073e9SAndroid Build Coastguard Worker                              batch_first=batch_first,
4165*da0073e9SAndroid Build Coastguard Worker                              proj_size=proj_size).to(dtype)
4166*da0073e9SAndroid Build Coastguard Worker
4167*da0073e9SAndroid Build Coastguard Worker            outputs_gpu = forward_backward(
4168*da0073e9SAndroid Build Coastguard Worker                True, rnn_gpu, input_val, grad_output, rnn.all_weights,
4169*da0073e9SAndroid Build Coastguard Worker                hx_val, grad_hy, cx_val, grad_cy)
4170*da0073e9SAndroid Build Coastguard Worker            compare_cpu_gpu(outputs_cpu, outputs_gpu)
4171*da0073e9SAndroid Build Coastguard Worker
4172*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
4173*da0073e9SAndroid Build Coastguard Worker    def test_RNN_cpu_vs_cudnn_no_dropout(self):
4174*da0073e9SAndroid Build Coastguard Worker        dtype = torch.double
4175*da0073e9SAndroid Build Coastguard Worker        self._test_RNN_cpu_vs_cudnn(0, dtype)
4176*da0073e9SAndroid Build Coastguard Worker
4177*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
4178*da0073e9SAndroid Build Coastguard Worker    def test_RNN_cpu_vs_cudnn_with_dropout(self):
4179*da0073e9SAndroid Build Coastguard Worker        # Because of dropout randomness, outputs can only be compared exactly for dropout=0 and dropout=1
4180*da0073e9SAndroid Build Coastguard Worker        self._test_RNN_cpu_vs_cudnn(1)
4181*da0073e9SAndroid Build Coastguard Worker
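    # The next test applies torch.nn.utils.weight_norm to a cuDNN RNN weight
    # (including the projection weight weight_hr_l0 when proj_size is set) and checks
    # that the output is unchanged both after adding and after removing the
    # reparametrization.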
4182*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
4183*da0073e9SAndroid Build Coastguard Worker    def test_RNN_cudnn_weight_norm(self):
4184*da0073e9SAndroid Build Coastguard Worker        input_size = 10
4185*da0073e9SAndroid Build Coastguard Worker        hidden_size = 6
4186*da0073e9SAndroid Build Coastguard Worker        num_layers = 2
4187*da0073e9SAndroid Build Coastguard Worker        seq_length = 7
4188*da0073e9SAndroid Build Coastguard Worker        batch = 6
4189*da0073e9SAndroid Build Coastguard Worker
4190*da0073e9SAndroid Build Coastguard Worker        # runs the module on CPU first to acquire the expected output to compare against
4191*da0073e9SAndroid Build Coastguard Worker        def check_weight_norm(m, name):
4192*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(seq_length, batch, input_size)
4193*da0073e9SAndroid Build Coastguard Worker            expected_output = m(input)
4194*da0073e9SAndroid Build Coastguard Worker
4195*da0073e9SAndroid Build Coastguard Worker            # adds weight normalization
4196*da0073e9SAndroid Build Coastguard Worker            m = torch.nn.utils.weight_norm(m, name=name)
4197*da0073e9SAndroid Build Coastguard Worker
4198*da0073e9SAndroid Build Coastguard Worker            # moves to CUDA
4199*da0073e9SAndroid Build Coastguard Worker            m = m.cuda()
4200*da0073e9SAndroid Build Coastguard Worker            input = input.cuda()
4201*da0073e9SAndroid Build Coastguard Worker
4202*da0073e9SAndroid Build Coastguard Worker            # reset the warning filter; otherwise subsequent warnings would be hidden, and further tests rely on them
4203*da0073e9SAndroid Build Coastguard Worker            warnings.simplefilter("always")
4204*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m(input), expected_output)
4205*da0073e9SAndroid Build Coastguard Worker
4206*da0073e9SAndroid Build Coastguard Worker            # remove weight norm
4207*da0073e9SAndroid Build Coastguard Worker            m = torch.nn.utils.remove_weight_norm(m, name=name)
4208*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m(input), expected_output)
4209*da0073e9SAndroid Build Coastguard Worker
4210*da0073e9SAndroid Build Coastguard Worker        check_weight_norm(nn.LSTM(input_size, hidden_size, num_layers), 'weight_hh_l0')
4211*da0073e9SAndroid Build Coastguard Worker        check_weight_norm(nn.LSTM(input_size, hidden_size, num_layers, proj_size=3), 'weight_hr_l0')
4212*da0073e9SAndroid Build Coastguard Worker
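    # test_partial_flat_weights covers an LSTM whose weight_hh_l0 attribute has been
    # deleted: moving the module to CUDA must still work, and once the attribute is
    # restored the module must reproduce the original output.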
4213*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
4214*da0073e9SAndroid Build Coastguard Worker    def test_partial_flat_weights(self):
4215*da0073e9SAndroid Build Coastguard Worker        input_size = 10
4216*da0073e9SAndroid Build Coastguard Worker        hidden_size = 6
4217*da0073e9SAndroid Build Coastguard Worker        num_layers = 2
4218*da0073e9SAndroid Build Coastguard Worker
4219*da0073e9SAndroid Build Coastguard Worker        m = nn.LSTM(input_size, hidden_size, num_layers)
4220*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn(3, 2, 10)
4221*da0073e9SAndroid Build Coastguard Worker        out_expected = m(inp)
4222*da0073e9SAndroid Build Coastguard Worker        # deletes an attribute of original LSTM
4223*da0073e9SAndroid Build Coastguard Worker        weight_orig = m.weight_hh_l0
4224*da0073e9SAndroid Build Coastguard Worker        del m.weight_hh_l0
4225*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(hasattr(m, "weight_hh_l0"))
4226*da0073e9SAndroid Build Coastguard Worker        # verifies that moving to CUDA with only some attributes defined
4227*da0073e9SAndroid Build Coastguard Worker        # does not throw an error
4228*da0073e9SAndroid Build Coastguard Worker        m.cuda()
4229*da0073e9SAndroid Build Coastguard Worker        # recompute the weight and make sure that module can be used
4230*da0073e9SAndroid Build Coastguard Worker        m.weight_hh_l0 = weight_orig.cuda()
4231*da0073e9SAndroid Build Coastguard Worker        inp = inp.cuda()
4232*da0073e9SAndroid Build Coastguard Worker        # reset the warning filter; otherwise subsequent warnings would be hidden, and further tests rely on them
4233*da0073e9SAndroid Build Coastguard Worker        warnings.simplefilter("always")
4234*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m(inp)[0].cpu(), out_expected[0])
4235*da0073e9SAndroid Build Coastguard Worker
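    # test_RNN_dropout relies on simple arithmetic to detect where cuDNN applies
    # dropout: with all weights equal to 1, relu activations, a single all-ones input
    # of size 10 and zero initial state, every first-layer unit outputs 10, so each of
    # the 1000 second-layer units sees 10 * 1000 = 10000 when nothing is dropped.
    # With dropout p in training mode, each first-layer activation is either dropped
    # or rescaled to 10 / (1 - p), so output * (1 - p) must be (close to) a multiple
    # of 10; that is what the denorm_mod check below asserts.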
4236*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
4237*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
4238*da0073e9SAndroid Build Coastguard Worker    def test_RNN_dropout(self):
4239*da0073e9SAndroid Build Coastguard Worker        # checking the assumption that cuDNN sticks dropout in between
4240*da0073e9SAndroid Build Coastguard Worker        # RNN layers
4241*da0073e9SAndroid Build Coastguard Worker        for p in (0, 0.276, 0.731, 1):
4242*da0073e9SAndroid Build Coastguard Worker            for train in (True, False):
4243*da0073e9SAndroid Build Coastguard Worker                for cuda in (True, False):
4244*da0073e9SAndroid Build Coastguard Worker                    rnn = nn.RNN(10, 1000, 2, bias=False, dropout=p, nonlinearity='relu')
4245*da0073e9SAndroid Build Coastguard Worker                    if cuda:
4246*da0073e9SAndroid Build Coastguard Worker                        rnn.cuda()
4247*da0073e9SAndroid Build Coastguard Worker
4248*da0073e9SAndroid Build Coastguard Worker                    if train:
4249*da0073e9SAndroid Build Coastguard Worker                        rnn.train()
4250*da0073e9SAndroid Build Coastguard Worker                    else:
4251*da0073e9SAndroid Build Coastguard Worker                        rnn.eval()
4252*da0073e9SAndroid Build Coastguard Worker                    rnn.weight_ih_l0.data.fill_(1)
4253*da0073e9SAndroid Build Coastguard Worker                    rnn.weight_hh_l0.data.fill_(1)
4254*da0073e9SAndroid Build Coastguard Worker                    rnn.weight_ih_l1.data.fill_(1)
4255*da0073e9SAndroid Build Coastguard Worker                    rnn.weight_hh_l1.data.fill_(1)
4256*da0073e9SAndroid Build Coastguard Worker                    input = torch.ones(1, 1, 10)
4257*da0073e9SAndroid Build Coastguard Worker                    hx = torch.zeros(2, 1, 1000)
4258*da0073e9SAndroid Build Coastguard Worker                    if cuda:
4259*da0073e9SAndroid Build Coastguard Worker                        input = input.cuda()
4260*da0073e9SAndroid Build Coastguard Worker                        hx = hx.cuda()
4261*da0073e9SAndroid Build Coastguard Worker
4262*da0073e9SAndroid Build Coastguard Worker                    output, hy = rnn(input, hx)
4263*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(output.data.min(), output.data.max())
4264*da0073e9SAndroid Build Coastguard Worker                    output_val = output.data[0][0][0]
4265*da0073e9SAndroid Build Coastguard Worker                    if p == 0 or not train:
4266*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(output_val, 10000)
4267*da0073e9SAndroid Build Coastguard Worker                    elif p == 1:
4268*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(output_val, 0)
4269*da0073e9SAndroid Build Coastguard Worker                    else:
4270*da0073e9SAndroid Build Coastguard Worker                        self.assertGreater(output_val, 8000)
4271*da0073e9SAndroid Build Coastguard Worker                        self.assertLess(output_val, 12000)
4272*da0073e9SAndroid Build Coastguard Worker                        denorm_mod = (output_val * (1 - p)) % 10
4273*da0073e9SAndroid Build Coastguard Worker                        self.assertLess(min(denorm_mod, 10 - denorm_mod), 1e-2)
4274*da0073e9SAndroid Build Coastguard Worker
4275*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(hy[0].data.min(), hy[0].data.max())
4276*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(hy[1].data.min(), hy[1].data.max())
4277*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(hy.data[0][0][0], 10)
4278*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(hy.data[1][0][0], output_val)
4279*da0073e9SAndroid Build Coastguard Worker
4280*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
4281*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
4282*da0073e9SAndroid Build Coastguard Worker    def test_error_RNN_seq_len_zero(self):
4283*da0073e9SAndroid Build Coastguard Worker        # checking error message when RNN has seq_len = 0
4284*da0073e9SAndroid Build Coastguard Worker        for module in (nn.RNN, nn.LSTM, nn.GRU):
4285*da0073e9SAndroid Build Coastguard Worker            for bidirectional in [True, False]:
4286*da0073e9SAndroid Build Coastguard Worker                for device in get_all_device_types():
4287*da0073e9SAndroid Build Coastguard Worker                    input = torch.ones(0, 10, 5)
4288*da0073e9SAndroid Build Coastguard Worker                    rnn = module(5, 6, bidirectional=bidirectional)
4289*da0073e9SAndroid Build Coastguard Worker                    if device == 'cuda':
4290*da0073e9SAndroid Build Coastguard Worker                        rnn.cuda()
4291*da0073e9SAndroid Build Coastguard Worker                        input = input.cuda()
4292*da0073e9SAndroid Build Coastguard Worker
4293*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(RuntimeError, "Expected sequence length to be larger than 0 in RNN"):
4294*da0073e9SAndroid Build Coastguard Worker                        rnn(input)
4295*da0073e9SAndroid Build Coastguard Worker
4296*da0073e9SAndroid Build Coastguard Worker    def test_RNN_input_size_zero(self):
4297*da0073e9SAndroid Build Coastguard Worker        for module in (nn.RNN, nn.LSTM, nn.GRU):
4298*da0073e9SAndroid Build Coastguard Worker            for device in get_all_device_types():
4299*da0073e9SAndroid Build Coastguard Worker                input = torch.zeros((5, 0, 3))
4300*da0073e9SAndroid Build Coastguard Worker                rnn = module(input_size=3, hidden_size=4)
4301*da0073e9SAndroid Build Coastguard Worker                if device == 'cuda':
4302*da0073e9SAndroid Build Coastguard Worker                    rnn.cuda()
4303*da0073e9SAndroid Build Coastguard Worker                    input = input.cuda()
4304*da0073e9SAndroid Build Coastguard Worker                outs = rnn(input)
4305*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(outs[0].shape, torch.Size([5, 0, 4]))
4306*da0073e9SAndroid Build Coastguard Worker                # Check that backward does not cause a hard error
4307*da0073e9SAndroid Build Coastguard Worker                outs[0].sum().backward()
4308*da0073e9SAndroid Build Coastguard Worker
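    # test_RNN_dropout_state round-trips a dropout RNN through torch.save/torch.load
    # and re-runs it on the same input: with dropout inactive (p == 0 or eval mode)
    # all runs must match, and with dropout active they must differ, i.e. reloading
    # should not make the dropout pattern deterministic.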
4309*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
4310*da0073e9SAndroid Build Coastguard Worker    def test_RNN_dropout_state(self):
4311*da0073e9SAndroid Build Coastguard Worker        for p in (0, 0.1234):
4312*da0073e9SAndroid Build Coastguard Worker            for train in (True, False):
4313*da0073e9SAndroid Build Coastguard Worker                for cuda in (True, False):
4314*da0073e9SAndroid Build Coastguard Worker                    rnn = nn.RNN(100, 100, 2, bias=False, dropout=p, nonlinearity='relu')
4315*da0073e9SAndroid Build Coastguard Worker                    if cuda:
4316*da0073e9SAndroid Build Coastguard Worker                        rnn.cuda()
4317*da0073e9SAndroid Build Coastguard Worker
4318*da0073e9SAndroid Build Coastguard Worker                    if train:
4319*da0073e9SAndroid Build Coastguard Worker                        rnn.train()
4320*da0073e9SAndroid Build Coastguard Worker                    else:
4321*da0073e9SAndroid Build Coastguard Worker                        rnn.eval()
4322*da0073e9SAndroid Build Coastguard Worker                    input = torch.rand(1, 1, 100)
4323*da0073e9SAndroid Build Coastguard Worker                    hx = torch.rand(2, 1, 100)
4324*da0073e9SAndroid Build Coastguard Worker                    if cuda:
4325*da0073e9SAndroid Build Coastguard Worker                        input = input.cuda()
4326*da0073e9SAndroid Build Coastguard Worker                        hx = hx.cuda()
4327*da0073e9SAndroid Build Coastguard Worker
4328*da0073e9SAndroid Build Coastguard Worker                    output1, hy1 = rnn(input, hx)
4329*da0073e9SAndroid Build Coastguard Worker                    output2, hy2 = rnn(input, hx)
4330*da0073e9SAndroid Build Coastguard Worker
4331*da0073e9SAndroid Build Coastguard Worker                    buf = io.BytesIO()
4332*da0073e9SAndroid Build Coastguard Worker                    rnn_pickle = torch.save(rnn, buf)
4333*da0073e9SAndroid Build Coastguard Worker                    buf.seek(0)
4334*da0073e9SAndroid Build Coastguard Worker                    # weights_only=False as this is legacy code that saves the model
4335*da0073e9SAndroid Build Coastguard Worker                    rnn2 = torch.load(buf, weights_only=False)
4336*da0073e9SAndroid Build Coastguard Worker                    rnn2.flatten_parameters()
4337*da0073e9SAndroid Build Coastguard Worker                    output3, hy3 = rnn2(input, hx)
4338*da0073e9SAndroid Build Coastguard Worker
4339*da0073e9SAndroid Build Coastguard Worker                    if p == 0 or not train:
4340*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(output1, output2)
4341*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(output1, output3)
4342*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(hy1, hy2)
4343*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(hy1, hy3)
4344*da0073e9SAndroid Build Coastguard Worker                    else:
4345*da0073e9SAndroid Build Coastguard Worker                        self.assertNotEqual(output1, output2)
4346*da0073e9SAndroid Build Coastguard Worker                        self.assertNotEqual(output1, output3)
4347*da0073e9SAndroid Build Coastguard Worker                        self.assertNotEqual(hy1, hy2)
4348*da0073e9SAndroid Build Coastguard Worker                        self.assertNotEqual(hy1, hy3)
4349*da0073e9SAndroid Build Coastguard Worker
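    # test_RNN_change_dropout mutates rnn.dropout between calls on the same module:
    # two calls with the same p only match when p is 0 or 1 or the module is in eval
    # mode, and in eval mode the output stays identical across all dropout values.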
4350*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
4351*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
4352*da0073e9SAndroid Build Coastguard Worker    def test_RNN_change_dropout(self):
4353*da0073e9SAndroid Build Coastguard Worker        for train, cuda in product((True, False), repeat=2):
4354*da0073e9SAndroid Build Coastguard Worker            rnn = nn.RNN(100, 100, 2, dropout=0, nonlinearity='relu')
4355*da0073e9SAndroid Build Coastguard Worker            input = torch.rand(3, 2, 100)
4356*da0073e9SAndroid Build Coastguard Worker            if cuda:
4357*da0073e9SAndroid Build Coastguard Worker                input.data = input.data.cuda()
4358*da0073e9SAndroid Build Coastguard Worker                rnn.cuda()
4359*da0073e9SAndroid Build Coastguard Worker
4360*da0073e9SAndroid Build Coastguard Worker            if train:
4361*da0073e9SAndroid Build Coastguard Worker                rnn.train()
4362*da0073e9SAndroid Build Coastguard Worker            else:
4363*da0073e9SAndroid Build Coastguard Worker                rnn.eval()
4364*da0073e9SAndroid Build Coastguard Worker
4365*da0073e9SAndroid Build Coastguard Worker            prev_output = None
4366*da0073e9SAndroid Build Coastguard Worker            for p in (0, 0.5, 0, 0.7, 0.2, 1, 0.2, 0):
4367*da0073e9SAndroid Build Coastguard Worker                rnn.dropout = p
4368*da0073e9SAndroid Build Coastguard Worker                output1, hy1 = rnn(input)
4369*da0073e9SAndroid Build Coastguard Worker                output2, hy2 = rnn(input)
4370*da0073e9SAndroid Build Coastguard Worker
4371*da0073e9SAndroid Build Coastguard Worker                if p == 0 or p == 1 or not train:
4372*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(output1, output2)
4373*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(hy1, hy2)
4374*da0073e9SAndroid Build Coastguard Worker                else:
4375*da0073e9SAndroid Build Coastguard Worker                    self.assertNotEqual(output1, output2)
4376*da0073e9SAndroid Build Coastguard Worker                    self.assertNotEqual(hy1, hy2)
4377*da0073e9SAndroid Build Coastguard Worker
4378*da0073e9SAndroid Build Coastguard Worker                if prev_output is not None:
4379*da0073e9SAndroid Build Coastguard Worker                    if not train:
4380*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(output1.data, prev_output)
4381*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(output2.data, prev_output)
4382*da0073e9SAndroid Build Coastguard Worker                    else:
4383*da0073e9SAndroid Build Coastguard Worker                        self.assertNotEqual(output1.data, prev_output)
4384*da0073e9SAndroid Build Coastguard Worker                        self.assertNotEqual(output2.data, prev_output)
4385*da0073e9SAndroid Build Coastguard Worker                prev_output = output1.data
4386*da0073e9SAndroid Build Coastguard Worker
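    # test_inplace_thnn runs backward through in-place activation modules and checks
    # that the grad_output tensor handed to backward() is not modified in the process.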
4387*da0073e9SAndroid Build Coastguard Worker    def test_inplace_thnn(self):
4388*da0073e9SAndroid Build Coastguard Worker        modules = [nn.ReLU, nn.ELU, nn.SELU, nn.CELU, nn.RReLU]
4389*da0073e9SAndroid Build Coastguard Worker        for mod in modules:
4390*da0073e9SAndroid Build Coastguard Worker            r = mod(inplace=True)
4391*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(5, 5, requires_grad=True)
4392*da0073e9SAndroid Build Coastguard Worker            output = r(input + 0)
4393*da0073e9SAndroid Build Coastguard Worker            grad_output = torch.randn(5, 5)
4394*da0073e9SAndroid Build Coastguard Worker            grad_output_clone = grad_output.clone()
4395*da0073e9SAndroid Build Coastguard Worker            output.backward(grad_output)
4396*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(grad_output, grad_output_clone)
4397*da0073e9SAndroid Build Coastguard Worker
4398*da0073e9SAndroid Build Coastguard Worker
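    # The pixel shuffle tests below validate PixelShuffle / PixelUnshuffle element by
    # element. For an upscale factor r, the relationship being verified is
    #   output[..., c, h, w] == input[..., c * r**2 + r * (h % r) + (w % r), h // r, w // r]
    # and PixelUnshuffle is additionally checked to be the exact inverse of PixelShuffle.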
4399*da0073e9SAndroid Build Coastguard Worker    def test_pixel_shuffle_unshuffle(self):
4400*da0073e9SAndroid Build Coastguard Worker        def _test_pixel_shuffle_unshuffle_helper(num_input_dims, valid_channels_dim=True,
4401*da0073e9SAndroid Build Coastguard Worker                                                 upscale_factor=None):
4402*da0073e9SAndroid Build Coastguard Worker            # Function to imperatively ensure pixels are shuffled to the correct locations.
4403*da0073e9SAndroid Build Coastguard Worker            # Used to validate the batch operations in pixel_shuffle.
4404*da0073e9SAndroid Build Coastguard Worker            def _verify_pixel_shuffle(input, output, upscale_factor):
4405*da0073e9SAndroid Build Coastguard Worker                for c in range(output.size(-3)):
4406*da0073e9SAndroid Build Coastguard Worker                    for h in range(output.size(-2)):
4407*da0073e9SAndroid Build Coastguard Worker                        for w in range(output.size(-1)):
4408*da0073e9SAndroid Build Coastguard Worker                            height_idx = h // upscale_factor
4409*da0073e9SAndroid Build Coastguard Worker                            width_idx = w // upscale_factor
4410*da0073e9SAndroid Build Coastguard Worker                            channel_idx = (upscale_factor * (h % upscale_factor)) + (w % upscale_factor) + \
4411*da0073e9SAndroid Build Coastguard Worker                                          (c * upscale_factor ** 2)
4412*da0073e9SAndroid Build Coastguard Worker                            self.assertEqual(output[..., c, h, w], input[..., channel_idx, height_idx, width_idx])
4413*da0073e9SAndroid Build Coastguard Worker
4414*da0073e9SAndroid Build Coastguard Worker            upscale_factor = random.randint(2, 5) if upscale_factor is None else upscale_factor
4415*da0073e9SAndroid Build Coastguard Worker            # If valid_channels_dim=False, add 1 to make channels dim indivisible by upscale_factor ** 2.
4416*da0073e9SAndroid Build Coastguard Worker            channels = random.randint(1, 4) * upscale_factor ** 2 + (0 if valid_channels_dim else 1)
4417*da0073e9SAndroid Build Coastguard Worker            height = random.randint(5, 10)
4418*da0073e9SAndroid Build Coastguard Worker            width = random.randint(5, 10)
4419*da0073e9SAndroid Build Coastguard Worker
4420*da0073e9SAndroid Build Coastguard Worker            if num_input_dims == 1:
4421*da0073e9SAndroid Build Coastguard Worker                input = torch.rand(channels, requires_grad=True)
4422*da0073e9SAndroid Build Coastguard Worker            elif num_input_dims == 2:
4423*da0073e9SAndroid Build Coastguard Worker                input = torch.rand(height, width, requires_grad=True)
4424*da0073e9SAndroid Build Coastguard Worker            else:
4425*da0073e9SAndroid Build Coastguard Worker                batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
4426*da0073e9SAndroid Build Coastguard Worker                input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True)
4427*da0073e9SAndroid Build Coastguard Worker            ps = nn.PixelShuffle(upscale_factor)
4428*da0073e9SAndroid Build Coastguard Worker            pus = nn.PixelUnshuffle(downscale_factor=upscale_factor)
4429*da0073e9SAndroid Build Coastguard Worker
4430*da0073e9SAndroid Build Coastguard Worker            if num_input_dims >= 3 and valid_channels_dim and upscale_factor > 0:
4431*da0073e9SAndroid Build Coastguard Worker                output = ps(input)
4432*da0073e9SAndroid Build Coastguard Worker                _verify_pixel_shuffle(input, output, upscale_factor)
4433*da0073e9SAndroid Build Coastguard Worker                output.backward(output.data)
4434*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(input.data, input.grad.data)
4435*da0073e9SAndroid Build Coastguard Worker
4436*da0073e9SAndroid Build Coastguard Worker                # Ensure unshuffle properly inverts shuffle.
4437*da0073e9SAndroid Build Coastguard Worker                unshuffle_output = pus(output)
4438*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(input, unshuffle_output)
4439*da0073e9SAndroid Build Coastguard Worker            else:
4440*da0073e9SAndroid Build Coastguard Worker                self.assertRaises(RuntimeError, lambda: ps(input))
4441*da0073e9SAndroid Build Coastguard Worker
4442*da0073e9SAndroid Build Coastguard Worker        def _test_pixel_unshuffle_error_case_helper(num_input_dims, valid_height_dim=True, valid_width_dim=True,
4443*da0073e9SAndroid Build Coastguard Worker                                                    downscale_factor=None):
4444*da0073e9SAndroid Build Coastguard Worker            downscale_factor = random.randint(2, 5) if downscale_factor is None else downscale_factor
4445*da0073e9SAndroid Build Coastguard Worker            channels = random.randint(1, 4)
4446*da0073e9SAndroid Build Coastguard Worker            # If valid_height_dim=False, add 1 to make height dim indivisible by downscale_factor.
4447*da0073e9SAndroid Build Coastguard Worker            height = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_height_dim else 1)
4448*da0073e9SAndroid Build Coastguard Worker            # If valid_width_dim=False, add 1 to make width dim indivisible by downscale_factor.
4449*da0073e9SAndroid Build Coastguard Worker            width = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_width_dim else 1)
4450*da0073e9SAndroid Build Coastguard Worker
4451*da0073e9SAndroid Build Coastguard Worker            if num_input_dims == 1:
4452*da0073e9SAndroid Build Coastguard Worker                input = torch.rand(channels, requires_grad=True)
4453*da0073e9SAndroid Build Coastguard Worker            elif num_input_dims == 2:
4454*da0073e9SAndroid Build Coastguard Worker                input = torch.rand(height, width, requires_grad=True)
4455*da0073e9SAndroid Build Coastguard Worker            else:
4456*da0073e9SAndroid Build Coastguard Worker                batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
4457*da0073e9SAndroid Build Coastguard Worker                input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True)
4458*da0073e9SAndroid Build Coastguard Worker
4459*da0073e9SAndroid Build Coastguard Worker            pus = nn.PixelUnshuffle(downscale_factor)
4460*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: pus(input))
4461*da0073e9SAndroid Build Coastguard Worker
4462*da0073e9SAndroid Build Coastguard Worker        def _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims):
4463*da0073e9SAndroid Build Coastguard Worker            # For 1D - 2D, this is an error case.
4464*da0073e9SAndroid Build Coastguard Worker            # For 3D - 5D, this is a success case for pixel_shuffle + pixel_unshuffle.
4465*da0073e9SAndroid Build Coastguard Worker            _test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims)
4466*da0073e9SAndroid Build Coastguard Worker
4467*da0073e9SAndroid Build Coastguard Worker            # Error cases for pixel_shuffle.
4468*da0073e9SAndroid Build Coastguard Worker            _test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims, valid_channels_dim=False)
4469*da0073e9SAndroid Build Coastguard Worker            _test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims, upscale_factor=0)
4470*da0073e9SAndroid Build Coastguard Worker            _test_pixel_shuffle_unshuffle_helper(num_input_dims=num_input_dims, upscale_factor=-2)
4471*da0073e9SAndroid Build Coastguard Worker
4472*da0073e9SAndroid Build Coastguard Worker            # Error cases for pixel_unshuffle.
4473*da0073e9SAndroid Build Coastguard Worker            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_height_dim=False)
4474*da0073e9SAndroid Build Coastguard Worker            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_width_dim=False)
4475*da0073e9SAndroid Build Coastguard Worker            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0)
4476*da0073e9SAndroid Build Coastguard Worker            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2)
4477*da0073e9SAndroid Build Coastguard Worker
4478*da0073e9SAndroid Build Coastguard Worker        def test_pixel_shuffle_unshuffle_1D():
4479*da0073e9SAndroid Build Coastguard Worker            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=1)
4480*da0073e9SAndroid Build Coastguard Worker
4481*da0073e9SAndroid Build Coastguard Worker        def test_pixel_shuffle_unshuffle_2D():
4482*da0073e9SAndroid Build Coastguard Worker            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=2)
4483*da0073e9SAndroid Build Coastguard Worker
4484*da0073e9SAndroid Build Coastguard Worker        def test_pixel_shuffle_unshuffle_3D():
4485*da0073e9SAndroid Build Coastguard Worker            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=3)
4486*da0073e9SAndroid Build Coastguard Worker
4487*da0073e9SAndroid Build Coastguard Worker        def test_pixel_shuffle_unshuffle_4D():
4488*da0073e9SAndroid Build Coastguard Worker            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=4)
4489*da0073e9SAndroid Build Coastguard Worker
4490*da0073e9SAndroid Build Coastguard Worker        def test_pixel_shuffle_unshuffle_5D():
4491*da0073e9SAndroid Build Coastguard Worker            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=5)
4492*da0073e9SAndroid Build Coastguard Worker
4493*da0073e9SAndroid Build Coastguard Worker        test_pixel_shuffle_unshuffle_1D()
4494*da0073e9SAndroid Build Coastguard Worker        test_pixel_shuffle_unshuffle_2D()
4495*da0073e9SAndroid Build Coastguard Worker        test_pixel_shuffle_unshuffle_3D()
4496*da0073e9SAndroid Build Coastguard Worker        test_pixel_shuffle_unshuffle_4D()
4497*da0073e9SAndroid Build Coastguard Worker        test_pixel_shuffle_unshuffle_5D()
4498*da0073e9SAndroid Build Coastguard Worker
4499*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
4500*da0073e9SAndroid Build Coastguard Worker    def test_pixel_shuffle_nhwc_cpu(self):
4501*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3, 18, 4, 4, device='cpu')
4502*da0073e9SAndroid Build Coastguard Worker        input = input.contiguous(memory_format=torch.channels_last).requires_grad_()
4503*da0073e9SAndroid Build Coastguard Worker        grad = torch.randn(3, 18, 4, 4, device='cpu')
4504*da0073e9SAndroid Build Coastguard Worker        ps = torch.nn.PixelShuffle(3)
4505*da0073e9SAndroid Build Coastguard Worker        pus = torch.nn.PixelUnshuffle(3)
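        # With upscale factor 3, ps maps (3, 18, 4, 4) -> (3, 2, 12, 12) and pus maps that back to (3, 18, 4, 4).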
4506*da0073e9SAndroid Build Coastguard Worker
4507*da0073e9SAndroid Build Coastguard Worker        ref_input = input.detach().clone().contiguous().requires_grad_(True)
4508*da0073e9SAndroid Build Coastguard Worker        ref_grad = grad.detach().clone().contiguous()
4509*da0073e9SAndroid Build Coastguard Worker        ref_ps = torch.nn.PixelShuffle(3)
4510*da0073e9SAndroid Build Coastguard Worker        ref_pus = torch.nn.PixelUnshuffle(3)
4511*da0073e9SAndroid Build Coastguard Worker
4512*da0073e9SAndroid Build Coastguard Worker        out = pus(ps(input))
4513*da0073e9SAndroid Build Coastguard Worker        out.backward(grad)
4514*da0073e9SAndroid Build Coastguard Worker        ref_out = ref_pus(ref_ps(ref_input))
4515*da0073e9SAndroid Build Coastguard Worker        ref_out.backward(ref_grad)
4516*da0073e9SAndroid Build Coastguard Worker
4517*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
4518*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(ref_out.is_contiguous())
4519*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, ref_out)
4520*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(input.grad, ref_input.grad)
4521*da0073e9SAndroid Build Coastguard Worker
4522*da0073e9SAndroid Build Coastguard Worker    # These tests should be OpInfo'd
4523*da0073e9SAndroid Build Coastguard Worker    def test_elu_inplace_on_view(self):
4524*da0073e9SAndroid Build Coastguard Worker        v = torch.tensor([1.0, -1.0, 1.0, -1.0], requires_grad=True, dtype=torch.double)
4525*da0073e9SAndroid Build Coastguard Worker
4526*da0073e9SAndroid Build Coastguard Worker        def func(root):
4527*da0073e9SAndroid Build Coastguard Worker            x = root.clone()
4528*da0073e9SAndroid Build Coastguard Worker            view = x.narrow(0, 1, 2)
4529*da0073e9SAndroid Build Coastguard Worker            res = F.elu(view, inplace=True)
4530*da0073e9SAndroid Build Coastguard Worker            self.assertIs(res, view)
4531*da0073e9SAndroid Build Coastguard Worker            return x
4532*da0073e9SAndroid Build Coastguard Worker
4533*da0073e9SAndroid Build Coastguard Worker        gradcheck(func, [v])
4534*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(func, [v])
4535*da0073e9SAndroid Build Coastguard Worker
4536*da0073e9SAndroid Build Coastguard Worker    def test_elu_inplace_gradgrad(self):
4537*da0073e9SAndroid Build Coastguard Worker        v = torch.randn(8, requires_grad=True, dtype=torch.double)
4538*da0073e9SAndroid Build Coastguard Worker
4539*da0073e9SAndroid Build Coastguard Worker        def func(root):
4540*da0073e9SAndroid Build Coastguard Worker            x = root.clone()
4541*da0073e9SAndroid Build Coastguard Worker            return F.elu(x, inplace=True)
4542*da0073e9SAndroid Build Coastguard Worker
4543*da0073e9SAndroid Build Coastguard Worker        gradcheck(func, [v])
4544*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(func, [v])
4545*da0073e9SAndroid Build Coastguard Worker
4546*da0073e9SAndroid Build Coastguard Worker    def test_relu_inplace_on_view(self):
4547*da0073e9SAndroid Build Coastguard Worker        v = torch.tensor([1.0, -1.0, 1.0, -1.0], requires_grad=True, dtype=torch.double)
4548*da0073e9SAndroid Build Coastguard Worker
4549*da0073e9SAndroid Build Coastguard Worker        def func(root):
4550*da0073e9SAndroid Build Coastguard Worker            x = root.clone()
4551*da0073e9SAndroid Build Coastguard Worker            view = x.narrow(0, 1, 2)
4552*da0073e9SAndroid Build Coastguard Worker            res = F.relu(view, inplace=True)
4553*da0073e9SAndroid Build Coastguard Worker            self.assertIs(res, view)
4554*da0073e9SAndroid Build Coastguard Worker            return x
4555*da0073e9SAndroid Build Coastguard Worker
4556*da0073e9SAndroid Build Coastguard Worker        gradcheck(func, [v])
4557*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(func, [v])
4558*da0073e9SAndroid Build Coastguard Worker
4559*da0073e9SAndroid Build Coastguard Worker    def test_PReLU_backward_requires_grad_false(self):
4560*da0073e9SAndroid Build Coastguard Worker        devices = ['cpu']
4561*da0073e9SAndroid Build Coastguard Worker        devices += ['cuda'] if TEST_CUDA else []
4562*da0073e9SAndroid Build Coastguard Worker        for d in devices:
4563*da0073e9SAndroid Build Coastguard Worker            m = nn.PReLU().to(d)
4564*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(2, 3, 4, 5, device=d, requires_grad=False)
4565*da0073e9SAndroid Build Coastguard Worker            y = m(x)
4566*da0073e9SAndroid Build Coastguard Worker            y.mean().backward()
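            # x was created with requires_grad=False, so backward() must leave x.grad as None.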
4567*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, None)
4568*da0073e9SAndroid Build Coastguard Worker
4569*da0073e9SAndroid Build Coastguard Worker    def test_bce_loss_always_nonnegative(self):
4570*da0073e9SAndroid Build Coastguard Worker        target = torch.ones(5)
4571*da0073e9SAndroid Build Coastguard Worker        input = torch.ones(5)
4572*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)
4573*da0073e9SAndroid Build Coastguard Worker
4574*da0073e9SAndroid Build Coastguard Worker        target = torch.zeros(5)
4575*da0073e9SAndroid Build Coastguard Worker        input = torch.zeros(5)
4576*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)
4577*da0073e9SAndroid Build Coastguard Worker
4578*da0073e9SAndroid Build Coastguard Worker    def test_bce_with_logits_raises_if_target_and_input_are_different_size(self):
4579*da0073e9SAndroid Build Coastguard Worker        target = torch.rand(5)
4580*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(5, 1)
4581*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
4582*da0073e9SAndroid Build Coastguard Worker            nn.BCEWithLogitsLoss()(input, target)
4583*da0073e9SAndroid Build Coastguard Worker
4584*da0073e9SAndroid Build Coastguard Worker        target = torch.rand(5, 1)
4585*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(5)
4586*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
4587*da0073e9SAndroid Build Coastguard Worker            nn.BCEWithLogitsLoss()(input, target)
4588*da0073e9SAndroid Build Coastguard Worker
4589*da0073e9SAndroid Build Coastguard Worker    def test_bce_with_logits_gives_same_result_as_sigmoid_and_bce_loss(self):
4590*da0073e9SAndroid Build Coastguard Worker        sigmoid = nn.Sigmoid()
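        # BCEWithLogitsLoss(x, y) is mathematically BCELoss(sigmoid(x), y), just fused for numerical stability.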
4591*da0073e9SAndroid Build Coastguard Worker
4592*da0073e9SAndroid Build Coastguard Worker        target = torch.rand(64, 4)
4593*da0073e9SAndroid Build Coastguard Worker        output = torch.rand(64, 4) - 0.5
4594*da0073e9SAndroid Build Coastguard Worker
4595*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nn.BCEWithLogitsLoss()(output, target), nn.BCELoss()(sigmoid(output), target))
4596*da0073e9SAndroid Build Coastguard Worker
4597*da0073e9SAndroid Build Coastguard Worker        weight = torch.rand(4)
4598*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nn.BCEWithLogitsLoss(weight)(output, target), nn.BCELoss(weight)(sigmoid(output), target))
4599*da0073e9SAndroid Build Coastguard Worker
4600*da0073e9SAndroid Build Coastguard Worker        target = torch.zeros(4, 1, dtype=torch.float)
4601*da0073e9SAndroid Build Coastguard Worker        output = torch.empty(4, 1, dtype=torch.float).fill_(-100)
4602*da0073e9SAndroid Build Coastguard Worker
4603*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nn.BCEWithLogitsLoss()(output, target), nn.BCELoss()(sigmoid(output), target))
4604*da0073e9SAndroid Build Coastguard Worker
4605*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nn.BCEWithLogitsLoss(reduction='none')(output, target),
4606*da0073e9SAndroid Build Coastguard Worker                         nn.BCELoss(reduction='none')(sigmoid(output), target))
4607*da0073e9SAndroid Build Coastguard Worker
4608*da0073e9SAndroid Build Coastguard Worker        weight = torch.rand(1, dtype=torch.float)
4609*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nn.BCEWithLogitsLoss(weight)(output, target), nn.BCELoss(weight)(sigmoid(output), target))
4610*da0073e9SAndroid Build Coastguard Worker
4611*da0073e9SAndroid Build Coastguard Worker    def test_bce_loss_input_range(self):
4612*da0073e9SAndroid Build Coastguard Worker        bceloss = nn.BCELoss()
4613*da0073e9SAndroid Build Coastguard Worker
4614*da0073e9SAndroid Build Coastguard Worker        target = torch.rand(25, 25)
4615*da0073e9SAndroid Build Coastguard Worker        output_valid = torch.rand(25, 25)
4616*da0073e9SAndroid Build Coastguard Worker        output_too_negative = output_valid - 1.0
4617*da0073e9SAndroid Build Coastguard Worker        output_too_positive = output_valid + 1.0
4618*da0073e9SAndroid Build Coastguard Worker
4619*da0073e9SAndroid Build Coastguard Worker        loss_valid = bceloss(output_valid, target)
4620*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'between 0 and 1'):
4621*da0073e9SAndroid Build Coastguard Worker            loss_too_negative = bceloss(output_too_negative, target)
4622*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'between 0 and 1'):
4623*da0073e9SAndroid Build Coastguard Worker            loss_too_positive = bceloss(output_too_positive, target)
4624*da0073e9SAndroid Build Coastguard Worker
4625*da0073e9SAndroid Build Coastguard Worker    def test_bce_loss_size_mismatch(self):
4626*da0073e9SAndroid Build Coastguard Worker        bceloss = nn.BCELoss()
4627*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(25)
4628*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(25, 1)
4629*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, r'Using a target size \('):
4630*da0073e9SAndroid Build Coastguard Worker            bceloss(a, b)
4631*da0073e9SAndroid Build Coastguard Worker
4632*da0073e9SAndroid Build Coastguard Worker    def test_bce_with_logits_gives_same_result_as_sigmoid_and_bce_loss_large_tensors_with_grad(self):
4633*da0073e9SAndroid Build Coastguard Worker        x_size = 1024
4634*da0073e9SAndroid Build Coastguard Worker        y_size = 256
4635*da0073e9SAndroid Build Coastguard Worker        target = torch.rand(x_size, y_size)
4636*da0073e9SAndroid Build Coastguard Worker
4637*da0073e9SAndroid Build Coastguard Worker        for reduction in ['none', 'mean', 'sum']:
4638*da0073e9SAndroid Build Coastguard Worker            output_sig = torch.rand(x_size, y_size) - 0.5
4639*da0073e9SAndroid Build Coastguard Worker            output_logits = output_sig.clone().detach()
4640*da0073e9SAndroid Build Coastguard Worker
4641*da0073e9SAndroid Build Coastguard Worker            output_sig.requires_grad = True
4642*da0073e9SAndroid Build Coastguard Worker            output_logits.requires_grad = True
4643*da0073e9SAndroid Build Coastguard Worker            weight = torch.rand(y_size)
4644*da0073e9SAndroid Build Coastguard Worker
4645*da0073e9SAndroid Build Coastguard Worker            loss_sig = nn.BCELoss(weight, reduction=reduction)(
4646*da0073e9SAndroid Build Coastguard Worker                torch.sigmoid(output_sig), target
4647*da0073e9SAndroid Build Coastguard Worker            )
4648*da0073e9SAndroid Build Coastguard Worker            loss_logits = nn.BCEWithLogitsLoss(weight, reduction=reduction)(
4649*da0073e9SAndroid Build Coastguard Worker                output_logits, target
4650*da0073e9SAndroid Build Coastguard Worker            )
4651*da0073e9SAndroid Build Coastguard Worker
4652*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(loss_logits, loss_sig)
4653*da0073e9SAndroid Build Coastguard Worker
4654*da0073e9SAndroid Build Coastguard Worker            if reduction == 'none':
4655*da0073e9SAndroid Build Coastguard Worker                grad = torch.rand(x_size, y_size)
4656*da0073e9SAndroid Build Coastguard Worker                loss_sig.backward(grad)
4657*da0073e9SAndroid Build Coastguard Worker                loss_logits.backward(grad)
4658*da0073e9SAndroid Build Coastguard Worker            else:
4659*da0073e9SAndroid Build Coastguard Worker                loss_sig.backward()
4660*da0073e9SAndroid Build Coastguard Worker                loss_logits.backward()
4661*da0073e9SAndroid Build Coastguard Worker
4662*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output_sig.grad, output_logits.grad)
4663*da0073e9SAndroid Build Coastguard Worker
4664*da0073e9SAndroid Build Coastguard Worker    def test_bce_with_logits_has_correct_forward_grad(self):
4665*da0073e9SAndroid Build Coastguard Worker        output = torch.randn(3, 5, requires_grad=True, dtype=torch.double)
4666*da0073e9SAndroid Build Coastguard Worker        target = torch.randn(3, 5, dtype=torch.double)
4667*da0073e9SAndroid Build Coastguard Worker        for reduction in ('sum', 'mean', 'none'):
4668*da0073e9SAndroid Build Coastguard Worker            gradcheck(lambda self, target: nn.BCEWithLogitsLoss(reduction=reduction)(self, target),
4669*da0073e9SAndroid Build Coastguard Worker                      (output, target), check_forward_ad=True)
4670*da0073e9SAndroid Build Coastguard Worker
4671*da0073e9SAndroid Build Coastguard Worker    def test_bce_with_logits_has_correct_grad_at_zero(self):
4672*da0073e9SAndroid Build Coastguard Worker        output = torch.zeros(3, 1, requires_grad=True)
4673*da0073e9SAndroid Build Coastguard Worker        target = torch.zeros(3, 1)
4674*da0073e9SAndroid Build Coastguard Worker        nn.BCEWithLogitsLoss(reduction='sum')(output, target).backward()
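        # Per element, d/dx BCEWithLogitsLoss(x, y) = sigmoid(x) - y, which is 0.5 - 0 = 0.5 at x = 0, y = 0.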
4675*da0073e9SAndroid Build Coastguard Worker        expected_grad = torch.empty(3, 1).fill_(0.5)
4676*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.grad, expected_grad)
4677*da0073e9SAndroid Build Coastguard Worker
4678*da0073e9SAndroid Build Coastguard Worker    def test_bce_with_logits_broadcasts_weights(self):
4679*da0073e9SAndroid Build Coastguard Worker        target = torch.rand(16, 4)
4680*da0073e9SAndroid Build Coastguard Worker        output = torch.rand(16, 4) - 0.5
4681*da0073e9SAndroid Build Coastguard Worker
4682*da0073e9SAndroid Build Coastguard Worker        weight = torch.rand(4)
4683*da0073e9SAndroid Build Coastguard Worker        out1 = nn.BCEWithLogitsLoss(weight)(output, target)
4684*da0073e9SAndroid Build Coastguard Worker
4685*da0073e9SAndroid Build Coastguard Worker        weight = weight.expand(16, 4).contiguous()
4686*da0073e9SAndroid Build Coastguard Worker        out2 = nn.BCEWithLogitsLoss(weight)(output, target)
4687*da0073e9SAndroid Build Coastguard Worker
4688*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out1, out2)
4689*da0073e9SAndroid Build Coastguard Worker
4690*da0073e9SAndroid Build Coastguard Worker        weight = torch.rand(16, 1)
4691*da0073e9SAndroid Build Coastguard Worker        out1 = nn.BCEWithLogitsLoss(weight)(output, target)
4692*da0073e9SAndroid Build Coastguard Worker
4693*da0073e9SAndroid Build Coastguard Worker        weight = weight.expand(16, 4).contiguous()
4694*da0073e9SAndroid Build Coastguard Worker        out2 = nn.BCEWithLogitsLoss(weight)(output, target)
4695*da0073e9SAndroid Build Coastguard Worker
4696*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out1, out2)
4697*da0073e9SAndroid Build Coastguard Worker
4698*da0073e9SAndroid Build Coastguard Worker    def test_bce_with_logits_ones_in_pos_weights_are_the_same_as_none(self):
4699*da0073e9SAndroid Build Coastguard Worker        target = torch.rand(64, 4)
4700*da0073e9SAndroid Build Coastguard Worker        output = torch.rand(64, 4) - 0.5
4701*da0073e9SAndroid Build Coastguard Worker        pos_weight = torch.ones(64, 4)
4702*da0073e9SAndroid Build Coastguard Worker
4703*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nn.BCEWithLogitsLoss()(output, target),
4704*da0073e9SAndroid Build Coastguard Worker                         nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target))
4705*da0073e9SAndroid Build Coastguard Worker
4706*da0073e9SAndroid Build Coastguard Worker    def test_bce_with_logits_broadcasts_pos_weights(self):
4707*da0073e9SAndroid Build Coastguard Worker        target = torch.rand(64, 4)
4708*da0073e9SAndroid Build Coastguard Worker        output = torch.rand(64, 4) - 0.5
4709*da0073e9SAndroid Build Coastguard Worker        pos_weight = torch.rand(4)
4710*da0073e9SAndroid Build Coastguard Worker        out1 = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)
4711*da0073e9SAndroid Build Coastguard Worker
4712*da0073e9SAndroid Build Coastguard Worker        pos_weight1 = pos_weight.expand(1, 4)
4713*da0073e9SAndroid Build Coastguard Worker        out2 = nn.BCEWithLogitsLoss(pos_weight=pos_weight1)(output, target)
4714*da0073e9SAndroid Build Coastguard Worker
4715*da0073e9SAndroid Build Coastguard Worker        pos_weight2 = pos_weight.expand(64, 4)
4716*da0073e9SAndroid Build Coastguard Worker        out3 = nn.BCEWithLogitsLoss(pos_weight=pos_weight2)(output, target)
4717*da0073e9SAndroid Build Coastguard Worker
4718*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out1, out2)
4719*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out1, out3)
4720*da0073e9SAndroid Build Coastguard Worker
4721*da0073e9SAndroid Build Coastguard Worker    def test_bce_with_logits_with_pos_weight_has_correct_grad_at_zero(self):
4722*da0073e9SAndroid Build Coastguard Worker        output = torch.zeros(3, 1, requires_grad=True)
4723*da0073e9SAndroid Build Coastguard Worker        target = torch.zeros(3, 1)
4724*da0073e9SAndroid Build Coastguard Worker        pos_weight = torch.ones(3, 1)
4725*da0073e9SAndroid Build Coastguard Worker        nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='sum')(output, target).backward()
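        # With y = 0 the pos_weight term drops out of the gradient, so d/dx is again sigmoid(0) = 0.5.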
4726*da0073e9SAndroid Build Coastguard Worker        expected_grad = torch.empty(3, 1).fill_(0.5)
4727*da0073e9SAndroid Build Coastguard Worker        grad = output.grad
4728*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(grad, expected_grad)
4729*da0073e9SAndroid Build Coastguard Worker
4730*da0073e9SAndroid Build Coastguard Worker    def test_bce_with_logits_stability(self):
4731*da0073e9SAndroid Build Coastguard Worker        output = torch.tensor([0., -120.])
4732*da0073e9SAndroid Build Coastguard Worker        target = torch.tensor([0., 1.])
4733*da0073e9SAndroid Build Coastguard Worker        pos_weight = torch.tensor([1., 1.])
4734*da0073e9SAndroid Build Coastguard Worker
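        # A naive sigmoid-then-log underflows here: sigmoid(-120.) is 0 in float32, so BCE on it would be inf.
        # The fused loss uses the log-sum-exp formulation and should stay finite.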
4735*da0073e9SAndroid Build Coastguard Worker        out1 = nn.BCEWithLogitsLoss()(output, target)
4736*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.isfinite(out1).all().item())
4737*da0073e9SAndroid Build Coastguard Worker
4738*da0073e9SAndroid Build Coastguard Worker        out2 = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)
4739*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.isfinite(out2).all().item())
4740*da0073e9SAndroid Build Coastguard Worker
4741*da0073e9SAndroid Build Coastguard Worker    def test_bce_loss_broadcasts_weights(self):
4742*da0073e9SAndroid Build Coastguard Worker        sigmoid = nn.Sigmoid()
4743*da0073e9SAndroid Build Coastguard Worker        target = torch.rand(16, 4)
4744*da0073e9SAndroid Build Coastguard Worker        output = torch.rand(16, 4) - 0.5
4745*da0073e9SAndroid Build Coastguard Worker
4746*da0073e9SAndroid Build Coastguard Worker        weight = torch.rand(4)
4747*da0073e9SAndroid Build Coastguard Worker        out1 = nn.BCELoss(weight)(sigmoid(output), target)
4748*da0073e9SAndroid Build Coastguard Worker
4749*da0073e9SAndroid Build Coastguard Worker        weight = weight.expand(16, 4).contiguous()
4750*da0073e9SAndroid Build Coastguard Worker        out2 = nn.BCELoss(weight)(sigmoid(output), target)
4751*da0073e9SAndroid Build Coastguard Worker
4752*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out1, out2)
4753*da0073e9SAndroid Build Coastguard Worker
4754*da0073e9SAndroid Build Coastguard Worker        weight = torch.rand(16, 1)
4755*da0073e9SAndroid Build Coastguard Worker        out1 = nn.BCELoss(weight)(sigmoid(output), target)
4756*da0073e9SAndroid Build Coastguard Worker
4757*da0073e9SAndroid Build Coastguard Worker        weight = weight.expand(16, 4).contiguous()
4758*da0073e9SAndroid Build Coastguard Worker        out2 = nn.BCELoss(weight)(sigmoid(output), target)
4759*da0073e9SAndroid Build Coastguard Worker
4760*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out1, out2)
4761*da0073e9SAndroid Build Coastguard Worker
4762*da0073e9SAndroid Build Coastguard Worker    def test_hardtanh_inplace_gradgrad(self):
4763*da0073e9SAndroid Build Coastguard Worker        v = torch.randn(8, requires_grad=True, dtype=torch.double)
4764*da0073e9SAndroid Build Coastguard Worker
4765*da0073e9SAndroid Build Coastguard Worker        def func(root):
4766*da0073e9SAndroid Build Coastguard Worker            x = root.clone()
4767*da0073e9SAndroid Build Coastguard Worker            return F.hardtanh(x, inplace=True)
4768*da0073e9SAndroid Build Coastguard Worker
4769*da0073e9SAndroid Build Coastguard Worker        gradcheck(func, [v])
4770*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(func, [v])
4771*da0073e9SAndroid Build Coastguard Worker
4772*da0073e9SAndroid Build Coastguard Worker    # test hardtanh backward for large tensor
4773*da0073e9SAndroid Build Coastguard Worker    def test_hardtanh_backward(self):
4774*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(128, 10000, requires_grad=True)
4775*da0073e9SAndroid Build Coastguard Worker        grad = torch.randn(128, 10000)
4776*da0073e9SAndroid Build Coastguard Worker        z = torch.zeros(128, 10000)
4777*da0073e9SAndroid Build Coastguard Worker        y = F.hardtanh(x)
4778*da0073e9SAndroid Build Coastguard Worker        y.backward(grad)
4779*da0073e9SAndroid Build Coastguard Worker        # ref backward path for hardtanh
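        # hardtanh's gradient is 1 strictly inside (-1, 1) and 0 elsewhere, so grad passes through only there.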
4780*da0073e9SAndroid Build Coastguard Worker        mask = (x > -1) & (x < 1)
4781*da0073e9SAndroid Build Coastguard Worker        x_grad_ref = torch.where(mask, grad, z)
4782*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, x_grad_ref)
4783*da0073e9SAndroid Build Coastguard Worker
4784*da0073e9SAndroid Build Coastguard Worker    def test_batchnorm_nhwc_cpu(self):
4785*da0073e9SAndroid Build Coastguard Worker        def helper(self, mod, size, dtype, mixed_dtype=False, format=torch.channels_last, precision=None):
4786*da0073e9SAndroid Build Coastguard Worker            channels = size[1]
4787*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(size, dtype=dtype, device='cpu', requires_grad=True)
4788*da0073e9SAndroid Build Coastguard Worker            input = input.contiguous(memory_format=format).to(dtype)
4789*da0073e9SAndroid Build Coastguard Worker            input.retain_grad()
4790*da0073e9SAndroid Build Coastguard Worker            grad = torch.randn(size, dtype=dtype, device='cpu')
4791*da0073e9SAndroid Build Coastguard Worker            grad = grad.contiguous(memory_format=format)
4792*da0073e9SAndroid Build Coastguard Worker            bn = mod(channels).cpu().to(dtype)
4793*da0073e9SAndroid Build Coastguard Worker            bn.weight.data.uniform_()
4794*da0073e9SAndroid Build Coastguard Worker            bn.bias.data.uniform_()
4795*da0073e9SAndroid Build Coastguard Worker
4796*da0073e9SAndroid Build Coastguard Worker            ref_input = input.detach().clone().contiguous().requires_grad_(True)
4797*da0073e9SAndroid Build Coastguard Worker            ref_grad = grad.detach().clone().contiguous()
4798*da0073e9SAndroid Build Coastguard Worker            ref_bn = mod(channels).cpu().to(dtype)
4799*da0073e9SAndroid Build Coastguard Worker            ref_bn.load_state_dict(bn.state_dict())
4800*da0073e9SAndroid Build Coastguard Worker
4801*da0073e9SAndroid Build Coastguard Worker            if mixed_dtype:
4802*da0073e9SAndroid Build Coastguard Worker                bn.float()
4803*da0073e9SAndroid Build Coastguard Worker                ref_bn.float()
4804*da0073e9SAndroid Build Coastguard Worker
4805*da0073e9SAndroid Build Coastguard Worker            out = bn(input)
4806*da0073e9SAndroid Build Coastguard Worker            out.backward(grad)
4807*da0073e9SAndroid Build Coastguard Worker            ref_out = ref_bn(ref_input)
4808*da0073e9SAndroid Build Coastguard Worker            ref_out.backward(ref_grad)
4809*da0073e9SAndroid Build Coastguard Worker
4810*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(out.is_contiguous(memory_format=format))
4811*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(ref_out.is_contiguous())
4812*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, ref_out)
4813*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(bn.weight.grad, ref_bn.weight.grad, atol=precision, rtol=precision)
4814*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(bn.bias.grad, ref_bn.bias.grad)
4815*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input.grad, ref_input.grad)
4816*da0073e9SAndroid Build Coastguard Worker
4817*da0073e9SAndroid Build Coastguard Worker        # test NC11 and N1HW; test mixed dtype
4818*da0073e9SAndroid Build Coastguard Worker        for shape in [(4, 8, 10, 10), (4, 1, 9, 9), (4, 9, 1, 1)]:
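        # mixed_dtype=True keeps the input in the reduced-precision dtype but runs it through a float32 module.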
4819*da0073e9SAndroid Build Coastguard Worker            for dtype in [torch.float, torch.bfloat16, torch.float16]:
4820*da0073e9SAndroid Build Coastguard Worker                for mixed_dtype in [False, True]:
4821*da0073e9SAndroid Build Coastguard Worker                    if dtype == torch.float:
4822*da0073e9SAndroid Build Coastguard Worker                        mixed_dtype = False
4823*da0073e9SAndroid Build Coastguard Worker                    helper(self, nn.BatchNorm2d, shape, dtype, mixed_dtype, torch.channels_last)
4824*da0073e9SAndroid Build Coastguard Worker
4825*da0073e9SAndroid Build Coastguard Worker        precisions = {torch.float: 1e-4, torch.bfloat16: 1e-4, torch.float16: None}
4826*da0073e9SAndroid Build Coastguard Worker        for shape in [(4, 8, 2, 10, 10), (4, 1, 2, 9, 9), (4, 9, 1, 1, 1)]:
4827*da0073e9SAndroid Build Coastguard Worker            for dtype in [torch.float, torch.bfloat16, torch.float16]:
4828*da0073e9SAndroid Build Coastguard Worker                for mixed_dtype in [False, True]:
4829*da0073e9SAndroid Build Coastguard Worker                    if dtype == torch.float:
4830*da0073e9SAndroid Build Coastguard Worker                        mixed_dtype = False
4831*da0073e9SAndroid Build Coastguard Worker                    helper(self, nn.BatchNorm3d, shape, dtype, mixed_dtype, torch.channels_last_3d, precisions[dtype])
4832*da0073e9SAndroid Build Coastguard Worker
4833*da0073e9SAndroid Build Coastguard Worker    @parametrize_test(
4834*da0073e9SAndroid Build Coastguard Worker        'bn_module',
4835*da0073e9SAndroid Build Coastguard Worker        [
4836*da0073e9SAndroid Build Coastguard Worker            subtest(torch.nn.BatchNorm2d, name="BatchNorm2d"),
4837*da0073e9SAndroid Build Coastguard Worker            subtest(torch.nn.SyncBatchNorm, name="SyncBatchNorm"),
4838*da0073e9SAndroid Build Coastguard Worker        ],
4839*da0073e9SAndroid Build Coastguard Worker    )
4840*da0073e9SAndroid Build Coastguard Worker    def test_batchnorm_non_contig_cpu(self, bn_module):
4841*da0073e9SAndroid Build Coastguard Worker        def helper(self, dtype):
4842*da0073e9SAndroid Build Coastguard Worker            input = torch.arange(6, dtype=torch.float).reshape(1, 3, 2, 1).cpu()
4843*da0073e9SAndroid Build Coastguard Worker            input = input.permute(0, 2, 1, 3)
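            # permute returns a non-contiguous view (no copy), exercising the non-contiguous input path.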
4844*da0073e9SAndroid Build Coastguard Worker
4845*da0073e9SAndroid Build Coastguard Worker            bn = bn_module(2).cpu().float().eval()
4846*da0073e9SAndroid Build Coastguard Worker            bn.weight.data.uniform_()
4847*da0073e9SAndroid Build Coastguard Worker            bn.bias.data.uniform_()
4848*da0073e9SAndroid Build Coastguard Worker
4849*da0073e9SAndroid Build Coastguard Worker            ref_input = input.detach().clone().contiguous()
4850*da0073e9SAndroid Build Coastguard Worker            ref_bn = nn.BatchNorm2d(2).cpu().float().eval()
4851*da0073e9SAndroid Build Coastguard Worker            ref_bn.load_state_dict(bn.state_dict())
4852*da0073e9SAndroid Build Coastguard Worker
4853*da0073e9SAndroid Build Coastguard Worker            out = bn(input)
4854*da0073e9SAndroid Build Coastguard Worker            ref_out = ref_bn(ref_input)
4855*da0073e9SAndroid Build Coastguard Worker
4856*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
4857*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(ref_out.is_contiguous())
4858*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, ref_out)
4859*da0073e9SAndroid Build Coastguard Worker
4860*da0073e9SAndroid Build Coastguard Worker            input_bf = torch.arange(24, dtype=dtype).reshape(1, 3, 2, 4)
4861*da0073e9SAndroid Build Coastguard Worker            input_bf = input_bf.permute(0, 2, 1, 3)
4862*da0073e9SAndroid Build Coastguard Worker            input_f = input_bf.float()
4863*da0073e9SAndroid Build Coastguard Worker            bn_mix = bn_module(2).float().eval()
4864*da0073e9SAndroid Build Coastguard Worker            ref_bn_f = deepcopy(bn_mix)
4865*da0073e9SAndroid Build Coastguard Worker            out_bf = bn_mix(input_bf)
4866*da0073e9SAndroid Build Coastguard Worker            ref_out_bf = ref_bn_f(input_f)
4867*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ref_out_bf, out_bf.float(), atol=0.05, rtol=0.05)
4868*da0073e9SAndroid Build Coastguard Worker
4869*da0073e9SAndroid Build Coastguard Worker        helper(self, torch.bfloat16)
4870*da0073e9SAndroid Build Coastguard Worker        helper(self, torch.float16)
4871*da0073e9SAndroid Build Coastguard Worker
4872*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
4873*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
4874*da0073e9SAndroid Build Coastguard Worker    def test_batchnorm_cudnn_nhwc(self):
4875*da0073e9SAndroid Build Coastguard Worker        def run_test(input, grad_output):
4876*da0073e9SAndroid Build Coastguard Worker            c = input.size(1)
4877*da0073e9SAndroid Build Coastguard Worker            mod = nn.BatchNorm2d(c).cuda().float()
4878*da0073e9SAndroid Build Coastguard Worker            mod.weight.data.uniform_()
4879*da0073e9SAndroid Build Coastguard Worker            mod.bias.data.uniform_()
4880*da0073e9SAndroid Build Coastguard Worker            ref_input = input.detach().clone().contiguous().requires_grad_(True)
4881*da0073e9SAndroid Build Coastguard Worker            ref_grad = grad_output.detach().clone().contiguous()
4882*da0073e9SAndroid Build Coastguard Worker            ref_mod = nn.BatchNorm2d(c).cuda().float()
4883*da0073e9SAndroid Build Coastguard Worker            ref_mod.load_state_dict(mod.state_dict())
4884*da0073e9SAndroid Build Coastguard Worker            out = mod(input)
4885*da0073e9SAndroid Build Coastguard Worker            out.backward(grad_output)
4886*da0073e9SAndroid Build Coastguard Worker            ref_out = ref_mod(ref_input)
4887*da0073e9SAndroid Build Coastguard Worker            ref_out.backward(ref_grad)
4888*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
4889*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(ref_out.is_contiguous())
4890*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, ref_out)
4891*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(mod.weight.grad, ref_mod.weight.grad)
4892*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(mod.bias.grad, ref_mod.bias.grad)
4893*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input.grad, ref_input.grad)
4894*da0073e9SAndroid Build Coastguard Worker
4895*da0073e9SAndroid Build Coastguard Worker        input = torch.randint(1, 10, (4, 8, 2, 2), dtype=torch.float32, device="cuda")
4896*da0073e9SAndroid Build Coastguard Worker        input = input.contiguous(memory_format=torch.channels_last).detach().requires_grad_()
4897*da0073e9SAndroid Build Coastguard Worker
4898*da0073e9SAndroid Build Coastguard Worker        grad = torch.randint(1, 10, (4, 8, 2, 2), dtype=torch.float32, device="cuda")
4899*da0073e9SAndroid Build Coastguard Worker        grad = grad.contiguous(memory_format=torch.channels_last)
4900*da0073e9SAndroid Build Coastguard Worker        run_test(input, grad)
4901*da0073e9SAndroid Build Coastguard Worker        # see #42588: grad is channels_last contiguous, but grad.suggest_memory_format() (rightly) returns "contiguous",
4902*da0073e9SAndroid Build Coastguard Worker        # not channels_last
4903*da0073e9SAndroid Build Coastguard Worker        input = torch.randint(1, 10, (2, 8, 8, 1), dtype=torch.float32, device="cuda")
4904*da0073e9SAndroid Build Coastguard Worker        input = input.contiguous(memory_format=torch.channels_last).detach().requires_grad_()
4905*da0073e9SAndroid Build Coastguard Worker        grad = torch.randint(1, 10, (2, 8, 8, 1), dtype=torch.float32, device="cuda")
4906*da0073e9SAndroid Build Coastguard Worker        grad = grad.permute(0, 2, 1, 3)
4907*da0073e9SAndroid Build Coastguard Worker        run_test(input, grad)
4908*da0073e9SAndroid Build Coastguard Worker
4909*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
4910*da0073e9SAndroid Build Coastguard Worker    def test_batchnorm_cudnn_half(self):
4911*da0073e9SAndroid Build Coastguard Worker        # THNN
4912*da0073e9SAndroid Build Coastguard Worker        input = torch.randint(1, 10, (2, 3, 2, 2), dtype=torch.half, device="cuda", requires_grad=True)
4913*da0073e9SAndroid Build Coastguard Worker        m = nn.BatchNorm2d(3).half().cuda()
4914*da0073e9SAndroid Build Coastguard Worker        thnn_output = m(input)
4915*da0073e9SAndroid Build Coastguard Worker        thnn_output.sum().backward()
4916*da0073e9SAndroid Build Coastguard Worker        thnn_input_grad = input.grad.data.clone()
4917*da0073e9SAndroid Build Coastguard Worker        self.assertEqualTypeString(thnn_output, input)
4918*da0073e9SAndroid Build Coastguard Worker        # cuDNN
4919*da0073e9SAndroid Build Coastguard Worker        if TEST_CUDNN:
4920*da0073e9SAndroid Build Coastguard Worker            input.grad = None
4921*da0073e9SAndroid Build Coastguard Worker            m = m.float()
4922*da0073e9SAndroid Build Coastguard Worker            cudnn_output = m(input)
4923*da0073e9SAndroid Build Coastguard Worker            cudnn_output.sum().backward()
4924*da0073e9SAndroid Build Coastguard Worker            cudnn_input_grad = input.grad.data.clone()
4925*da0073e9SAndroid Build Coastguard Worker            self.assertEqualTypeString(cudnn_output, input)
4926*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cudnn_output, thnn_output)
4927*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cudnn_input_grad, thnn_input_grad, atol=1e-3, rtol=0)
4928*da0073e9SAndroid Build Coastguard Worker
4929*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
4930*da0073e9SAndroid Build Coastguard Worker    def test_batchnorm_nonaffine_cuda_half_input(self):
4931*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(16, 3, 24, 24, dtype=torch.half, device="cuda")
4932*da0073e9SAndroid Build Coastguard Worker        m = nn.BatchNorm2d(3, affine=False).cuda().float()  # keep running stats in FP32
4933*da0073e9SAndroid Build Coastguard Worker        output = m(input)
4934*da0073e9SAndroid Build Coastguard Worker        self.assertEqualTypeString(output, input)
4935*da0073e9SAndroid Build Coastguard Worker        m.eval()
4936*da0073e9SAndroid Build Coastguard Worker        output = m(input)
4937*da0073e9SAndroid Build Coastguard Worker        self.assertEqualTypeString(output, input)
4938*da0073e9SAndroid Build Coastguard Worker
4939*da0073e9SAndroid Build Coastguard Worker    def test_batchnorm_raises_error_if_less_than_one_value_per_channel(self):
4940*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(10)[None, :, None]
4941*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
4942*da0073e9SAndroid Build Coastguard Worker            torch.nn.BatchNorm1d(10)(x)
4943*da0073e9SAndroid Build Coastguard Worker
4944*da0073e9SAndroid Build Coastguard Worker    def test_batchnorm_raises_error_if_running_mean_is_not_same_size_as_input(self):
4945*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(2, 10)
4946*da0073e9SAndroid Build Coastguard Worker        running_var = torch.rand(10)
4947*da0073e9SAndroid Build Coastguard Worker        wrong_sizes = [9, 11]
4948*da0073e9SAndroid Build Coastguard Worker        for size in wrong_sizes:
4949*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(RuntimeError):
4950*da0073e9SAndroid Build Coastguard Worker                F.batch_norm(input, torch.rand(size), running_var)
4951*da0073e9SAndroid Build Coastguard Worker
4952*da0073e9SAndroid Build Coastguard Worker    def test_batchnorm_raises_error_if_running_var_is_not_same_size_as_input(self):
4953*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(2, 10)
4954*da0073e9SAndroid Build Coastguard Worker        running_mean = torch.rand(10)
4955*da0073e9SAndroid Build Coastguard Worker        wrong_sizes = [9, 11]
4956*da0073e9SAndroid Build Coastguard Worker        for size in wrong_sizes:
4957*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(RuntimeError):
4958*da0073e9SAndroid Build Coastguard Worker                F.batch_norm(input, running_mean, torch.rand(size))
4959*da0073e9SAndroid Build Coastguard Worker
4960*da0073e9SAndroid Build Coastguard Worker    def test_batchnorm_raises_error_if_weight_is_not_same_size_as_input(self):
4961*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(2, 10)
4962*da0073e9SAndroid Build Coastguard Worker        running_mean = torch.rand(10)
4963*da0073e9SAndroid Build Coastguard Worker        running_var = torch.rand(10)
4964*da0073e9SAndroid Build Coastguard Worker        wrong_sizes = [9, 11]
4965*da0073e9SAndroid Build Coastguard Worker        for size in wrong_sizes:
4966*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(RuntimeError):
4967*da0073e9SAndroid Build Coastguard Worker                F.batch_norm(input, running_mean, running_var, weight=Parameter(torch.rand(size)))
4968*da0073e9SAndroid Build Coastguard Worker
4969*da0073e9SAndroid Build Coastguard Worker    def test_batchnorm_raises_error_if_bias_is_not_same_size_as_input(self):
4970*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(2, 10)
4971*da0073e9SAndroid Build Coastguard Worker        running_mean = torch.rand(10)
4972*da0073e9SAndroid Build Coastguard Worker        running_var = torch.rand(10)
4973*da0073e9SAndroid Build Coastguard Worker        wrong_sizes = [9, 11]
4974*da0073e9SAndroid Build Coastguard Worker        for size in wrong_sizes:
4975*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(RuntimeError):
4976*da0073e9SAndroid Build Coastguard Worker                F.batch_norm(input, running_mean, running_var, bias=Parameter(torch.rand(size)))
4977*da0073e9SAndroid Build Coastguard Worker
4978*da0073e9SAndroid Build Coastguard Worker    def test_batchnorm_raises_error_if_running_var_or_running_mean_have_forward_grad(self):
4979*da0073e9SAndroid Build Coastguard Worker        args = (
4980*da0073e9SAndroid Build Coastguard Worker            torch.randn(3, 2, 5),  # input
4981*da0073e9SAndroid Build Coastguard Worker            torch.randn(2),  # running_mean
4982*da0073e9SAndroid Build Coastguard Worker            torch.randn(2),  # running_var
4983*da0073e9SAndroid Build Coastguard Worker        )
4984*da0073e9SAndroid Build Coastguard Worker        kwargs = {'training': False, 'momentum': -1.2}
4985*da0073e9SAndroid Build Coastguard Worker        fn = partial(F.batch_norm, **kwargs)
4986*da0073e9SAndroid Build Coastguard Worker
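        # dual_indices selects which of (input, running_mean, running_var) get a forward-mode tangent attached.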
4987*da0073e9SAndroid Build Coastguard Worker        for dual_indices in ((0,), (1,), (1, 2), (0, 1), (0, 1, 2),):
4988*da0073e9SAndroid Build Coastguard Worker            tangents = tuple(torch.rand_like(x) for x in args)
4989*da0073e9SAndroid Build Coastguard Worker
4990*da0073e9SAndroid Build Coastguard Worker            with fwAD.dual_level():
4991*da0073e9SAndroid Build Coastguard Worker                duals = [fwAD.make_dual(primal, tangent) if i in dual_indices else primal
4992*da0073e9SAndroid Build Coastguard Worker                         for i, (primal, tangent) in enumerate(zip(args, tangents))]
4993*da0073e9SAndroid Build Coastguard Worker                msg = "batch_norm is not differentiable wrt running_mean and running_var"
4994*da0073e9SAndroid Build Coastguard Worker                # Index 0 (the input) needs a forward grad; otherwise batch_norm_jvp is never run at all
4995*da0073e9SAndroid Build Coastguard Worker                if (1 in dual_indices or 2 in dual_indices) and 0 in dual_indices:
4996*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(RuntimeError, msg):
4997*da0073e9SAndroid Build Coastguard Worker                        fn(*duals)
4998*da0073e9SAndroid Build Coastguard Worker                else:
4999*da0073e9SAndroid Build Coastguard Worker                    fn(*duals)
5000*da0073e9SAndroid Build Coastguard Worker
5001*da0073e9SAndroid Build Coastguard Worker    def test_batchnorm_buffer_update_when_stats_are_not_tracked(self):
5002*da0073e9SAndroid Build Coastguard Worker        input_size = (32, 4)
5003*da0073e9SAndroid Build Coastguard Worker        # Instantiate BN with buffers that are not None
5004*da0073e9SAndroid Build Coastguard Worker        bn = nn.BatchNorm1d(input_size[1], track_running_stats=True)
5005*da0073e9SAndroid Build Coastguard Worker        # Use buffers for normalization but don't update them
5006*da0073e9SAndroid Build Coastguard Worker        bn.track_running_stats = False
5007*da0073e9SAndroid Build Coastguard Worker        # Store initial values
5008*da0073e9SAndroid Build Coastguard Worker        num_batches = bn.num_batches_tracked.clone()
5009*da0073e9SAndroid Build Coastguard Worker        running_mean = bn.running_mean.clone()
5010*da0073e9SAndroid Build Coastguard Worker        running_var = bn.running_var.clone()
5011*da0073e9SAndroid Build Coastguard Worker        # Forward random tensor
5012*da0073e9SAndroid Build Coastguard Worker        _ = bn(torch.rand(input_size))
5013*da0073e9SAndroid Build Coastguard Worker        # Ensure none of the buffers has been updated
5014*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.equal(num_batches, bn.num_batches_tracked))
5015*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.equal(running_mean, bn.running_mean))
5016*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.equal(running_var, bn.running_var))
5017*da0073e9SAndroid Build Coastguard Worker
5018*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
5019*da0073e9SAndroid Build Coastguard Worker    def test_batchnorm_nhwc_cuda(self):
5020*da0073e9SAndroid Build Coastguard Worker        for dtype in (torch.half, torch.float):
5021*da0073e9SAndroid Build Coastguard Worker            (N, C, H, W) = 2, 64, 50, 50
5022*da0073e9SAndroid Build Coastguard Worker            model = torch.nn.BatchNorm2d(C, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
5023*da0073e9SAndroid Build Coastguard Worker            model = model.eval().cuda().to(dtype)
5024*da0073e9SAndroid Build Coastguard Worker            inp1 = torch.randn(N, C, H, W, device=torch.device('cuda'), dtype=dtype)
5025*da0073e9SAndroid Build Coastguard Worker            inp2 = inp1.contiguous(memory_format=torch.channels_last)
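            # In eval mode the contiguous and channels_last inputs should give exactly equal outputs.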
5026*da0073e9SAndroid Build Coastguard Worker            out1 = model(inp1)
5027*da0073e9SAndroid Build Coastguard Worker            out2 = model(inp2)
5028*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.equal(out1, out2))
5029*da0073e9SAndroid Build Coastguard Worker
5030*da0073e9SAndroid Build Coastguard Worker    def test_batchnorm_load_state_dict(self):
5031*da0073e9SAndroid Build Coastguard Worker        bn = torch.nn.BatchNorm2d(3)
5032*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bn.state_dict()["num_batches_tracked"], torch.tensor(0))
5033*da0073e9SAndroid Build Coastguard Worker
5034*da0073e9SAndroid Build Coastguard Worker        bn.num_batches_tracked = torch.tensor(10)
5035*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bn.state_dict()["num_batches_tracked"], torch.tensor(10))
5036*da0073e9SAndroid Build Coastguard Worker
5037*da0073e9SAndroid Build Coastguard Worker        empty_dict = OrderedDict()
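        # strict=False tolerates the missing key, so num_batches_tracked should keep its current value of 10.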
5038*da0073e9SAndroid Build Coastguard Worker        bn.load_state_dict(empty_dict, strict=False)
5039*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bn.state_dict()["num_batches_tracked"], torch.tensor(10))
5040*da0073e9SAndroid Build Coastguard Worker
5041*da0073e9SAndroid Build Coastguard Worker        # test that when `num_batches_tracked` is not in loaded state_dict,
5042*da0073e9SAndroid Build Coastguard Worker        # Test that when `num_batches_tracked` is not in the loaded state_dict,
5043*da0073e9SAndroid Build Coastguard Worker        # the meta num_batches_tracked buffer is still replaced with a singleton 0 tensor
5044*da0073e9SAndroid Build Coastguard Worker            meta_bn = torch.nn.BatchNorm2d(3)
5045*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(meta_bn.num_batches_tracked.device == torch.device('meta'))
5046*da0073e9SAndroid Build Coastguard Worker        meta_bn.load_state_dict(empty_dict, assign=True, strict=False)
5047*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(meta_bn.state_dict()["num_batches_tracked"], torch.tensor(0))
5048*da0073e9SAndroid Build Coastguard Worker
5049*da0073e9SAndroid Build Coastguard Worker    def test_batch_norm_update_stats(self):
5050*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(0, 1)
5051*da0073e9SAndroid Build Coastguard Worker        running_mean = torch.rand(1)
5052*da0073e9SAndroid Build Coastguard Worker        running_var = torch.rand(1)
5053*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError,
5054*da0073e9SAndroid Build Coastguard Worker                                    re.escape("input tensor must have at least one element, but got input_sizes = [0, 1]")):
5055*da0073e9SAndroid Build Coastguard Worker            torch.batch_norm_update_stats(input=input, momentum=0.0, running_mean=running_mean, running_var=running_var)
5056*da0073e9SAndroid Build Coastguard Worker
5057*da0073e9SAndroid Build Coastguard Worker    def test_pairwise_distance(self):
5058*da0073e9SAndroid Build Coastguard Worker        input1 = torch.randn(4, 4, requires_grad=True, dtype=torch.double)
5059*da0073e9SAndroid Build Coastguard Worker        input2 = torch.randn(4, 4, requires_grad=True, dtype=torch.double)
5060*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(gradcheck(lambda x, y: F.pairwise_distance(x, y), (input1, input2)))
5061*da0073e9SAndroid Build Coastguard Worker
5062*da0073e9SAndroid Build Coastguard Worker    # TODO: Create an OpInfo for pdist
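    # F.pdist(x, p) returns the strict upper triangle of the pairwise p-norm distance matrix between rows of x,
    # flattened into a vector of length n * (n - 1) / 2.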
5063*da0073e9SAndroid Build Coastguard Worker    def test_pdist(self):
5064*da0073e9SAndroid Build Coastguard Worker        for device, trans in itertools.product(device_(), [False, True]):
5065*da0073e9SAndroid Build Coastguard Worker            inp = torch.randn(4, 5, dtype=torch.double, device=device, requires_grad=True)
5066*da0073e9SAndroid Build Coastguard Worker            if trans:
5067*da0073e9SAndroid Build Coastguard Worker                inp = inp.transpose(0, 1)
5068*da0073e9SAndroid Build Coastguard Worker            for p in [0, 1, 2, 0.5, 1.5, 2.5, float('inf')]:
5069*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(gradcheck(lambda x: F.pdist(x, p), (inp,)))
5070*da0073e9SAndroid Build Coastguard Worker
5071*da0073e9SAndroid Build Coastguard Worker    def test_pdist_zeros(self):
5072*da0073e9SAndroid Build Coastguard Worker        """Test that grad is still valid when dist is 0"""
5073*da0073e9SAndroid Build Coastguard Worker        for device in device_():
5074*da0073e9SAndroid Build Coastguard Worker            inp = torch.randn(1, 3, dtype=torch.double, device=device, requires_grad=True).repeat([2, 1])
5075*da0073e9SAndroid Build Coastguard Worker            for p in [0, 1, 2, 0.5, 1.5, 2.5, float('inf')]:
5076*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(gradcheck(lambda x: F.pdist(x, p), (inp,)))
5077*da0073e9SAndroid Build Coastguard Worker
5078*da0073e9SAndroid Build Coastguard Worker    def test_pdist_empty_row(self):
5079*da0073e9SAndroid Build Coastguard Worker        for device in device_():
5080*da0073e9SAndroid Build Coastguard Worker            inp = torch.randn(1, 3, dtype=torch.double, device=device, requires_grad=True)
5081*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(gradcheck(F.pdist, (inp,)))
5082*da0073e9SAndroid Build Coastguard Worker
5083*da0073e9SAndroid Build Coastguard Worker    def test_pdist_empty_col(self):
5084*da0073e9SAndroid Build Coastguard Worker        for device in device_():
5085*da0073e9SAndroid Build Coastguard Worker            inp = torch.randn(4, 0, dtype=torch.double, device=device, requires_grad=True)
5086*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(gradcheck(F.pdist, (inp,)))
5087*da0073e9SAndroid Build Coastguard Worker
5088*da0073e9SAndroid Build Coastguard Worker    @unittest.expectedFailure
5089*da0073e9SAndroid Build Coastguard Worker    def test_pdist_cpu_gradgrad_unimplemented(self):
5090*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn(4, 5, requires_grad=True)
5091*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(F.pdist, (inp,))
5092*da0073e9SAndroid Build Coastguard Worker
5093*da0073e9SAndroid Build Coastguard Worker    @unittest.expectedFailure
5094*da0073e9SAndroid Build Coastguard Worker    def test_pdist_cuda_gradgrad_unimplemented(self):
5095*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn(4, 5, device='cuda', requires_grad=True)
5096*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(F.pdist, (inp,))
5097*da0073e9SAndroid Build Coastguard Worker
5098*da0073e9SAndroid Build Coastguard Worker    # Merge into OpInfo?
5099*da0073e9SAndroid Build Coastguard Worker    # test for backward in https://github.com/pytorch/pytorch/issues/15511
5100*da0073e9SAndroid Build Coastguard Worker    def test_pdist_large(self):
5101*da0073e9SAndroid Build Coastguard Worker        for device in device_():
5102*da0073e9SAndroid Build Coastguard Worker            def func(x):
5103*da0073e9SAndroid Build Coastguard Worker                return torch.pdist(x, p=2)
5104*da0073e9SAndroid Build Coastguard Worker
5105*da0073e9SAndroid Build Coastguard Worker            # shape[0] should be able to be (roughly) arbitrarily large, but the kernel
5106*da0073e9SAndroid Build Coastguard Worker            # is currently limited to smaller sizes (see issue above); this is just testing
5107*da0073e9SAndroid Build Coastguard Worker            # a floor.
5108*da0073e9SAndroid Build Coastguard Worker            shape = (1000, 1)
5109*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(shape, device=device).requires_grad_()
5110*da0073e9SAndroid Build Coastguard Worker            output = torch.pdist(x, p=2)
5111*da0073e9SAndroid Build Coastguard Worker            # just run a single backward, as gradcheck/gradgradcheck is expensive here
5112*da0073e9SAndroid Build Coastguard Worker            output.sum().backward()
5113*da0073e9SAndroid Build Coastguard Worker
5114*da0073e9SAndroid Build Coastguard Worker    def test_cosine_embedding_loss_with_diff_type(self):
5115*da0073e9SAndroid Build Coastguard Worker        for device in device_():
5116*da0073e9SAndroid Build Coastguard Worker            input1 = torch.tensor([[2, 3, 4], [6, 2, 4]], dtype=torch.double, device=device)
5117*da0073e9SAndroid Build Coastguard Worker            input2 = torch.tensor([[2, 3, 5], [3, 2, 1]], dtype=torch.double, device=device)
5118*da0073e9SAndroid Build Coastguard Worker            target = torch.tensor([1, -1], dtype=torch.int, device=device)
5119*da0073e9SAndroid Build Coastguard Worker            expected = torch.nn.functional.cosine_embedding_loss(input1, input2, target)
5120*da0073e9SAndroid Build Coastguard Worker            for dt1 in get_all_math_dtypes(device):
5121*da0073e9SAndroid Build Coastguard Worker                for dt2 in get_all_math_dtypes(device):
5122*da0073e9SAndroid Build Coastguard Worker                    for dt3 in get_all_math_dtypes(device):
5123*da0073e9SAndroid Build Coastguard Worker                        # dt3 is used as dtype for target = [1, -1], so let's skip unsigned type
5124*da0073e9SAndroid Build Coastguard Worker                        if dt3 == torch.uint8:
5125*da0073e9SAndroid Build Coastguard Worker                            continue
5126*da0073e9SAndroid Build Coastguard Worker                        if dt1.is_complex or dt2.is_complex or dt3.is_complex:
5127*da0073e9SAndroid Build Coastguard Worker                            continue
5128*da0073e9SAndroid Build Coastguard Worker                        input1 = input1.to(dt1)
5129*da0073e9SAndroid Build Coastguard Worker                        input2 = input2.to(dt2)
5130*da0073e9SAndroid Build Coastguard Worker                        target = target.to(dt3)
5131*da0073e9SAndroid Build Coastguard Worker                        result = torch.nn.functional.cosine_embedding_loss(input1, input2, target)
5132*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(result.item(), expected.item(), atol=0.001, rtol=0)
5133*da0073e9SAndroid Build Coastguard Worker
5134*da0073e9SAndroid Build Coastguard Worker    def test_cosine_embedding_loss_error_on_diff_shapes(self):
5135*da0073e9SAndroid Build Coastguard Worker        for device in device_():
5136*da0073e9SAndroid Build Coastguard Worker            input1 = torch.empty((0, 0), dtype=torch.double, device=device)
5137*da0073e9SAndroid Build Coastguard Worker            input2 = torch.empty((0,), dtype=torch.double, device=device)
5138*da0073e9SAndroid Build Coastguard Worker            target = torch.empty((0,), dtype=torch.int, device=device)
5139*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, ".*expects 2D.*"):
5140*da0073e9SAndroid Build Coastguard Worker                torch.nn.functional.cosine_embedding_loss(input1, input2, target)
5141*da0073e9SAndroid Build Coastguard Worker
5142*da0073e9SAndroid Build Coastguard Worker    def test_cosine_embedding_loss_error_on_nonexpandable_shapes(self):
5143*da0073e9SAndroid Build Coastguard Worker        for device in device_():
5144*da0073e9SAndroid Build Coastguard Worker            input1 = torch.empty((1, 5), dtype=torch.double, device=device)
5145*da0073e9SAndroid Build Coastguard Worker            input2 = torch.empty((1, 6), dtype=torch.double, device=device)
5146*da0073e9SAndroid Build Coastguard Worker            target = torch.ones((1,), dtype=torch.int, device=device)
5147*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, ".*must match the size.*"):
5148*da0073e9SAndroid Build Coastguard Worker                torch.nn.functional.cosine_embedding_loss(input1, input2, target)
5149*da0073e9SAndroid Build Coastguard Worker
5150*da0073e9SAndroid Build Coastguard Worker    def test_kl_div_with_diff_type(self):
5151*da0073e9SAndroid Build Coastguard Worker        for device in device_():
5152*da0073e9SAndroid Build Coastguard Worker            input = torch.tensor([[2, 3, 5], [3, 2, 1]], dtype=torch.double, device=device)
5153*da0073e9SAndroid Build Coastguard Worker            target = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.double, device=device)
5154*da0073e9SAndroid Build Coastguard Worker            expected = torch.nn.functional.kl_div(input, target)
5155*da0073e9SAndroid Build Coastguard Worker            real_dtypes = (torch.float32, torch.float64, torch.float16)
5156*da0073e9SAndroid Build Coastguard Worker            for input_dtype, target_dtype in product(real_dtypes, repeat=2):
5157*da0073e9SAndroid Build Coastguard Worker                if (torch.device(device).type == 'cpu' and target_dtype == torch.float16):
5158*da0073e9SAndroid Build Coastguard Worker                    continue
5159*da0073e9SAndroid Build Coastguard Worker                input = input.to(input_dtype)
5160*da0073e9SAndroid Build Coastguard Worker                target = target.to(target_dtype)
5161*da0073e9SAndroid Build Coastguard Worker                result = torch.nn.functional.kl_div(input, target)
5162*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(result.item(), expected.item(), atol=0.001, rtol=0)
5163*da0073e9SAndroid Build Coastguard Worker
5164*da0073e9SAndroid Build Coastguard Worker    def test_kl_div_with_diff_type_log_target(self):
5165*da0073e9SAndroid Build Coastguard Worker        for device in device_():
5166*da0073e9SAndroid Build Coastguard Worker            input = torch.tensor([[2, 3, 5], [3, 2, 1]], dtype=torch.double, device=device)
5167*da0073e9SAndroid Build Coastguard Worker            target = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.double, device=device).log()
5168*da0073e9SAndroid Build Coastguard Worker            expected = torch.nn.functional.kl_div(input, target, log_target=True)
5169*da0073e9SAndroid Build Coastguard Worker            real_dtypes = (torch.float32, torch.float64, torch.float16)
5170*da0073e9SAndroid Build Coastguard Worker            for input_dtype, target_dtype in product(real_dtypes, repeat=2):
5171*da0073e9SAndroid Build Coastguard Worker                if (torch.device(device).type == 'cpu' and target_dtype == torch.float16):
5172*da0073e9SAndroid Build Coastguard Worker                    continue
5173*da0073e9SAndroid Build Coastguard Worker                input = input.to(input_dtype)
5174*da0073e9SAndroid Build Coastguard Worker                target = target.to(target_dtype)
5175*da0073e9SAndroid Build Coastguard Worker                result = torch.nn.functional.kl_div(input, target, log_target=True)
5176*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(result.item(), expected.item(), atol=0.001, rtol=0)
5177*da0073e9SAndroid Build Coastguard Worker
5178*da0073e9SAndroid Build Coastguard Worker    def test_kl_div_log_softmax_target(self):
5179*da0073e9SAndroid Build Coastguard Worker        for device in device_():
5180*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor([[1.0, 2, 3], [5.0, 5, 5]], device=device)
5181*da0073e9SAndroid Build Coastguard Worker            b = torch.tensor([[1.0, 2, 3], [5.0, 5, 5]], device=device)
5182*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
5183*da0073e9SAndroid Build Coastguard Worker                F.kl_div(F.log_softmax(a, 1), F.log_softmax(b, 1), reduction='none', log_target=True),
5184*da0073e9SAndroid Build Coastguard Worker                torch.zeros_like(a)
5185*da0073e9SAndroid Build Coastguard Worker            )
5186*da0073e9SAndroid Build Coastguard Worker
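    # Editor's illustrative sketch (not part of the original suite): F.kl_div expects
    # `input` as log-probabilities and, unless log_target=True, `target` as probabilities;
    # the documented pointwise value is target * (log(target) - input). A minimal check of
    # that convention; the leading underscore keeps it out of automatic test collection:
    def _kl_div_pointwise_convention_sketch(self):
        input = torch.randn(2, 3, dtype=torch.double).log_softmax(dim=1)
        target = torch.randn(2, 3, dtype=torch.double).softmax(dim=1)
        expected = target * (target.log() - input)
        self.assertEqual(F.kl_div(input, target, reduction='none'), expected)
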
5187*da0073e9SAndroid Build Coastguard Worker    def test_cosine_embedding_loss_no_reduce(self):
5188*da0073e9SAndroid Build Coastguard Worker        input1 = torch.randn(15, 10, requires_grad=True, dtype=torch.double)
5189*da0073e9SAndroid Build Coastguard Worker        input2 = torch.randn(15, 10, requires_grad=True, dtype=torch.double)
5190*da0073e9SAndroid Build Coastguard Worker        target = torch.randn(15, dtype=torch.double).sign()
5191*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(gradcheck(lambda x, y, z: F.cosine_embedding_loss(
5192*da0073e9SAndroid Build Coastguard Worker            x, y, z, reduction='none'), (input1, input2, target)))
5193*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(F.cosine_embedding_loss(input1, input2, target, reduction='none'),
5194*da0073e9SAndroid Build Coastguard Worker                         loss_reference_fns['CosineEmbeddingLoss'](input1, input2, target, reduction='none'))
5195*da0073e9SAndroid Build Coastguard Worker
5196*da0073e9SAndroid Build Coastguard Worker    def test_cosine_embedding_loss_margin_no_reduce(self):
5197*da0073e9SAndroid Build Coastguard Worker        input1 = torch.randn(15, 10, requires_grad=True, dtype=torch.double)
5198*da0073e9SAndroid Build Coastguard Worker        input2 = torch.randn(15, 10, requires_grad=True, dtype=torch.double)
5199*da0073e9SAndroid Build Coastguard Worker        target = torch.randn(15, dtype=torch.double).sign()
5200*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(gradcheck(lambda x, y, z: F.cosine_embedding_loss(
5201*da0073e9SAndroid Build Coastguard Worker            x, y, z, margin=0.5, reduction='none'), (input1, input2, target)))
5202*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(F.cosine_embedding_loss(input1, input2, target, margin=0.5, reduction='none'),
5203*da0073e9SAndroid Build Coastguard Worker                         loss_reference_fns['CosineEmbeddingLoss'](input1, input2, target,
5204*da0073e9SAndroid Build Coastguard Worker                                                                   margin=0.5, reduction='none'))
5205*da0073e9SAndroid Build Coastguard Worker
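    # Editor's illustrative sketch (not part of the original suite): the documented cosine
    # embedding loss is 1 - cos(x1, x2) for target 1 and max(0, cos(x1, x2) - margin) for
    # target -1. A minimal elementwise check of that formula; the leading underscore keeps
    # it out of automatic test collection:
    def _cosine_embedding_loss_formula_sketch(self):
        x1 = torch.randn(8, 5, dtype=torch.double)
        x2 = torch.randn(8, 5, dtype=torch.double)
        target = torch.randn(8, dtype=torch.double).sign()
        cos = F.cosine_similarity(x1, x2, dim=1)
        expected = torch.where(target == 1, 1 - cos, (cos - 0.5).clamp_min(0))
        self.assertEqual(F.cosine_embedding_loss(x1, x2, target, margin=0.5, reduction='none'),
                         expected)
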
5206*da0073e9SAndroid Build Coastguard Worker    def test_cosine_embedding_loss_invalid_shape(self):
5207*da0073e9SAndroid Build Coastguard Worker        input1 = torch.randn(15, 10)
5208*da0073e9SAndroid Build Coastguard Worker        input2 = torch.randn(15, 10)
5209*da0073e9SAndroid Build Coastguard Worker        target = torch.randn(15, 1).sign()
5210*da0073e9SAndroid Build Coastguard Worker
5211*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "1D target tensor expected"):
5212*da0073e9SAndroid Build Coastguard Worker            F.cosine_embedding_loss(input1, input2, target)
5213*da0073e9SAndroid Build Coastguard Worker
5214*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "1D target tensor expects 2D input tensors"):
5215*da0073e9SAndroid Build Coastguard Worker            F.cosine_embedding_loss(torch.randn(10), torch.randn(10), torch.randn(10))
5216*da0073e9SAndroid Build Coastguard Worker
5217*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "0D target tensor expects 1D input tensors"):
5218*da0073e9SAndroid Build Coastguard Worker            F.cosine_embedding_loss(torch.randn(2, 5), torch.randn(2, 5), torch.randn(()))
5219*da0073e9SAndroid Build Coastguard Worker
5220*da0073e9SAndroid Build Coastguard Worker    def test_margin_ranking_loss_no_reduce(self):
5221*da0073e9SAndroid Build Coastguard Worker        input1 = torch.randn(15, dtype=torch.double).mul_(10).requires_grad_()
5222*da0073e9SAndroid Build Coastguard Worker        input2 = torch.randn(15, dtype=torch.double).mul_(10).requires_grad_()
5223*da0073e9SAndroid Build Coastguard Worker        target = torch.randn(15, dtype=torch.double).sign()
5224*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(gradcheck(lambda x, y, z: F.margin_ranking_loss(
5225*da0073e9SAndroid Build Coastguard Worker            x, y, z, reduction='none'), (input1, input2, target)))
5226*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(F.margin_ranking_loss(input1, input2, target, reduction='none'),
5227*da0073e9SAndroid Build Coastguard Worker                         loss_reference_fns['MarginRankingLoss'](input1, input2, target, reduction='none'))
5228*da0073e9SAndroid Build Coastguard Worker
5229*da0073e9SAndroid Build Coastguard Worker    def test_margin_ranking_loss_margin_no_reduce(self):
5230*da0073e9SAndroid Build Coastguard Worker        input1 = torch.randn(15, dtype=torch.double).mul_(10).requires_grad_()
5231*da0073e9SAndroid Build Coastguard Worker        input2 = torch.randn(15, dtype=torch.double).mul_(10).requires_grad_()
5232*da0073e9SAndroid Build Coastguard Worker        target = torch.randn(15, dtype=torch.double).sign()
5233*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(gradcheck(lambda x, y, z: F.margin_ranking_loss(
5234*da0073e9SAndroid Build Coastguard Worker            x, y, z, margin=0.5, reduction='none'), (input1, input2, target)))
5235*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(F.margin_ranking_loss(input1, input2, target, margin=0.5, reduction='none'),
5236*da0073e9SAndroid Build Coastguard Worker                         loss_reference_fns['MarginRankingLoss'](input1, input2, target, margin=0.5, reduction='none'))
5237*da0073e9SAndroid Build Coastguard Worker
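    # Editor's illustrative sketch (not part of the original suite): the documented margin
    # ranking loss is max(0, -target * (input1 - input2) + margin). A minimal elementwise
    # check; the leading underscore keeps it out of automatic test collection:
    def _margin_ranking_loss_formula_sketch(self):
        input1 = torch.randn(15, dtype=torch.double)
        input2 = torch.randn(15, dtype=torch.double)
        target = torch.randn(15, dtype=torch.double).sign()
        expected = (-target * (input1 - input2) + 0.5).clamp_min(0)
        self.assertEqual(F.margin_ranking_loss(input1, input2, target, margin=0.5, reduction='none'),
                         expected)
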
5238*da0073e9SAndroid Build Coastguard Worker    def test_triplet_margin_loss(self):
5239*da0073e9SAndroid Build Coastguard Worker        input1 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
5240*da0073e9SAndroid Build Coastguard Worker        input2 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
5241*da0073e9SAndroid Build Coastguard Worker        input3 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
5242*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
5243*da0073e9SAndroid Build Coastguard Worker            x1, x2, x3), (input1, input2, input3)))
5244*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(F.triplet_margin_loss(input1, input2, input3),
5245*da0073e9SAndroid Build Coastguard Worker                         loss_reference_fns['TripletMarginLoss'](input1, input2, input3))
5246*da0073e9SAndroid Build Coastguard Worker
5247*da0073e9SAndroid Build Coastguard Worker    def test_triplet_margin_loss_swap(self):
5248*da0073e9SAndroid Build Coastguard Worker        input1 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
5249*da0073e9SAndroid Build Coastguard Worker        input2 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
5250*da0073e9SAndroid Build Coastguard Worker        input3 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
5251*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
5252*da0073e9SAndroid Build Coastguard Worker            x1, x2, x3, swap=True), (input1, input2, input3)))
5253*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(F.triplet_margin_loss(input1, input2, input3, swap=True),
5254*da0073e9SAndroid Build Coastguard Worker                         loss_reference_fns['TripletMarginLoss'](input1, input2, input3, swap=True))
5255*da0073e9SAndroid Build Coastguard Worker
5256*da0073e9SAndroid Build Coastguard Worker    def test_triplet_margin_loss_no_reduce(self):
5257*da0073e9SAndroid Build Coastguard Worker        input1 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
5258*da0073e9SAndroid Build Coastguard Worker        input2 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
5259*da0073e9SAndroid Build Coastguard Worker        input3 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
5260*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
5261*da0073e9SAndroid Build Coastguard Worker            x1, x2, x3, reduction='none'), (input1, input2, input3)))
5262*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(F.triplet_margin_loss(input1, input2, input3, reduction='none'),
5263*da0073e9SAndroid Build Coastguard Worker                         loss_reference_fns['TripletMarginLoss'](input1, input2, input3, reduction='none'))
5264*da0073e9SAndroid Build Coastguard Worker
5265*da0073e9SAndroid Build Coastguard Worker    def test_triplet_margin_loss_swap_no_reduce(self):
5266*da0073e9SAndroid Build Coastguard Worker        input1 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
5267*da0073e9SAndroid Build Coastguard Worker        input2 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
5268*da0073e9SAndroid Build Coastguard Worker        input3 = torch.randn(5, 10, requires_grad=True, dtype=torch.double)
5269*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
5270*da0073e9SAndroid Build Coastguard Worker            x1, x2, x3, swap=True, reduction='none'), (input1, input2, input3)))
5271*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(F.triplet_margin_loss(input1, input2, input3, swap=True, reduction='none'),
5272*da0073e9SAndroid Build Coastguard Worker                         loss_reference_fns['TripletMarginLoss'](input1, input2, input3, swap=True, reduction='none'))
5273*da0073e9SAndroid Build Coastguard Worker
5274*da0073e9SAndroid Build Coastguard Worker    def test_pointwise_loss_target_grad_none_reduction(self):
5275*da0073e9SAndroid Build Coastguard Worker        i = torch.randn(5, 10)
5276*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(5, 10, requires_grad=True)
5277*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(F.mse_loss(i, t, reduction='none').size(), t.size())
5278*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(F.l1_loss(i, t, reduction='none').size(), t.size())
5279*da0073e9SAndroid Build Coastguard Worker
5280*da0073e9SAndroid Build Coastguard Worker    def test_pointwise_loss_broadcast(self):
5281*da0073e9SAndroid Build Coastguard Worker        losses = {
5282*da0073e9SAndroid Build Coastguard Worker            'mse_loss': lambda x, y, r: F.mse_loss(x, y, reduction=r),
5283*da0073e9SAndroid Build Coastguard Worker            'l1_loss': lambda x, y, r: F.l1_loss(x, y, reduction=r),
5284*da0073e9SAndroid Build Coastguard Worker            'smooth_l1_loss': lambda x, y, r: F.smooth_l1_loss(x, y, reduction=r),
5285*da0073e9SAndroid Build Coastguard Worker            'huber_loss': lambda x, y, r: F.huber_loss(x, y, reduction=r),
5286*da0073e9SAndroid Build Coastguard Worker        }
5287*da0073e9SAndroid Build Coastguard Worker
5288*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(2, 1, requires_grad=True, dtype=torch.double)
5289*da0073e9SAndroid Build Coastguard Worker        for fn in losses.values():
5290*da0073e9SAndroid Build Coastguard Worker            for requires_grad in [True, False]:
5291*da0073e9SAndroid Build Coastguard Worker                # When target.requires_grad=True, its impl is in Python, while the other is in TH.
5292*da0073e9SAndroid Build Coastguard Worker                target = torch.randn(2, 10, requires_grad=requires_grad, dtype=torch.double)
5293*da0073e9SAndroid Build Coastguard Worker                for reduction in ['none', 'mean', 'sum']:
5294*da0073e9SAndroid Build Coastguard Worker                    l = fn(input, target, reduction)
5295*da0073e9SAndroid Build Coastguard Worker                    if reduction == 'none':
5296*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(l.size(), target.size())
5297*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(gradcheck(fn, (input, target, reduction)))
5298*da0073e9SAndroid Build Coastguard Worker
5299*da0073e9SAndroid Build Coastguard Worker    # https://github.com/pytorch/pytorch/issues/27692 reports
5300*da0073e9SAndroid Build Coastguard Worker    # that l1_loss returns an incorrect result for large batch sizes
5301*da0073e9SAndroid Build Coastguard Worker    def test_l1_loss_correct(self):
5302*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.float, torch.cfloat]:
5303*da0073e9SAndroid Build Coastguard Worker            for N in range(1, 50, 10):
5304*da0073e9SAndroid Build Coastguard Worker                input = torch.rand(N, 3, 1024, 1024, dtype=dtype)
5305*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
5306*da0073e9SAndroid Build Coastguard Worker                    torch.nn.L1Loss()(input, torch.zeros_like(input)),
5307*da0073e9SAndroid Build Coastguard Worker                    input.abs().mean())
5308*da0073e9SAndroid Build Coastguard Worker
5309*da0073e9SAndroid Build Coastguard Worker    def test_smoothl1loss_integral_target(self):
5310*da0073e9SAndroid Build Coastguard Worker        def _input_grad(input, target, reduction):
5311*da0073e9SAndroid Build Coastguard Worker            output = F.smooth_l1_loss(input, target, reduction=reduction, beta=0.5)
5312*da0073e9SAndroid Build Coastguard Worker            output.sum().backward()
5313*da0073e9SAndroid Build Coastguard Worker            return input.grad
5314*da0073e9SAndroid Build Coastguard Worker
5315*da0073e9SAndroid Build Coastguard Worker        for device, dtype, reduction in product(device_(),
5316*da0073e9SAndroid Build Coastguard Worker                                                integral_types(),
5317*da0073e9SAndroid Build Coastguard Worker                                                ('none', 'sum', 'mean')):
5318*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(2, 2, device=device, requires_grad=True)
5319*da0073e9SAndroid Build Coastguard Worker            target = torch.randint(0, 9, (2, 2), device=device, dtype=dtype)
5320*da0073e9SAndroid Build Coastguard Worker
5321*da0073e9SAndroid Build Coastguard Worker            input_grad_with_float_target = _input_grad(input, target.float(), reduction)
5322*da0073e9SAndroid Build Coastguard Worker
5323*da0073e9SAndroid Build Coastguard Worker            input_grad = _input_grad(input.detach().clone().requires_grad_(True),
5324*da0073e9SAndroid Build Coastguard Worker                                     target,
5325*da0073e9SAndroid Build Coastguard Worker                                     reduction)
5326*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input_grad, input_grad_with_float_target)
5327*da0073e9SAndroid Build Coastguard Worker
5328*da0073e9SAndroid Build Coastguard Worker    def test_smoothl1loss_negative_beta_not_supported(self):
5329*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
5330*da0073e9SAndroid Build Coastguard Worker            F.smooth_l1_loss(torch.randn(2, 2), torch.randn(2, 2), beta=-1.0)
5331*da0073e9SAndroid Build Coastguard Worker
5332*da0073e9SAndroid Build Coastguard Worker    def test_huber_loss_invalid_delta(self):
5333*da0073e9SAndroid Build Coastguard Worker        def _test_huber_loss_delta_error_helper(delta):
5334*da0073e9SAndroid Build Coastguard Worker            input, target = torch.randn(2, 2), torch.randn(2, 2)
5335*da0073e9SAndroid Build Coastguard Worker            loss = torch.nn.HuberLoss(delta=delta)
5336*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(RuntimeError):
5337*da0073e9SAndroid Build Coastguard Worker                loss(input, target)
5338*da0073e9SAndroid Build Coastguard Worker
5339*da0073e9SAndroid Build Coastguard Worker        def test_huber_loss_negative_delta():
5340*da0073e9SAndroid Build Coastguard Worker            _test_huber_loss_delta_error_helper(delta=-0.5)
5341*da0073e9SAndroid Build Coastguard Worker
5342*da0073e9SAndroid Build Coastguard Worker        def test_huber_loss_zero_delta():
5343*da0073e9SAndroid Build Coastguard Worker            _test_huber_loss_delta_error_helper(delta=0.0)
5344*da0073e9SAndroid Build Coastguard Worker
5345*da0073e9SAndroid Build Coastguard Worker        test_huber_loss_negative_delta()
5346*da0073e9SAndroid Build Coastguard Worker        test_huber_loss_zero_delta()
5347*da0073e9SAndroid Build Coastguard Worker
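    # Editor's illustrative sketch (not part of the original suite): per the HuberLoss
    # documentation, huber_loss with delta=d equals d times smooth_l1_loss with beta=d.
    # A minimal cross-check of that relation; the leading underscore keeps it out of
    # automatic test collection:
    def _huber_smooth_l1_relation_sketch(self):
        input = torch.randn(3, 4, dtype=torch.double)
        target = torch.randn(3, 4, dtype=torch.double)
        delta = 0.7
        self.assertEqual(F.huber_loss(input, target, delta=delta, reduction='none'),
                         delta * F.smooth_l1_loss(input, target, beta=delta, reduction='none'))
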
5348*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
5349*da0073e9SAndroid Build Coastguard Worker    def test_cosine_similarity(self):
5350*da0073e9SAndroid Build Coastguard Worker        # Check cosine_similarity input/output shapes
5351*da0073e9SAndroid Build Coastguard Worker        input_size = (1, 3, 2, 1)
5352*da0073e9SAndroid Build Coastguard Worker        expected_size = (1, 2, 1)
5353*da0073e9SAndroid Build Coastguard Worker        input1 = torch.randn(input_size, requires_grad=True)
5354*da0073e9SAndroid Build Coastguard Worker        input2 = torch.randn(input_size, requires_grad=True)
5355*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(F.cosine_similarity(input1, input2, dim=1).size(), expected_size)
5356*da0073e9SAndroid Build Coastguard Worker
5357*da0073e9SAndroid Build Coastguard Worker        # Check numerical precision, issue #18057
5358*da0073e9SAndroid Build Coastguard Worker        vv1 = torch.tensor([float(i) for i in range(84)]).unsqueeze(0)
5359*da0073e9SAndroid Build Coastguard Worker        vv2 = torch.tensor([float(i) for i in range(84)]).unsqueeze(0)
5360*da0073e9SAndroid Build Coastguard Worker        out = F.cosine_similarity(vv1, vv2)
5361*da0073e9SAndroid Build Coastguard Worker        self.assertLessEqual(out, 1.0)
5362*da0073e9SAndroid Build Coastguard Worker
5363*da0073e9SAndroid Build Coastguard Worker        # Check dividing by 0.
5364*da0073e9SAndroid Build Coastguard Worker        # previous behavior: <x,y>/max(eps, ||x|| * ||y||)
5365*da0073e9SAndroid Build Coastguard Worker        # current: <x/max(eps, ||x||), y/max(eps,||y||)>
5366*da0073e9SAndroid Build Coastguard Worker        # if f(x,y) is the cosine similarity, then
5367*da0073e9SAndroid Build Coastguard Worker        # df/dx = y/(||x|| * ||y||) - (x * <x,y> * ||y||/||x||)/(||x|| * ||y||)^2
5368*da0073e9SAndroid Build Coastguard Worker        # the tests below check division by zero in the backward formula when
5369*da0073e9SAndroid Build Coastguard Worker        # x := input2 = 0, y := input1 != 0.
5370*da0073e9SAndroid Build Coastguard Worker        # For these inputs the gradient wrt x simplifies to g(x,y) := y/(||x|| * ||y||)
5371*da0073e9SAndroid Build Coastguard Worker        # Previous test checks g(x,y) == y/eps,
5372*da0073e9SAndroid Build Coastguard Worker        # Current test checks g(x,y) == (y/||y||)/eps.
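        # With cosine_similarity's default eps = 1e-8, (y/||y||)/eps is (y/||y||) * 1e8,
        # which is what the check on input2.grad below expects.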
5373*da0073e9SAndroid Build Coastguard Worker        input1 = torch.randn(10).requires_grad_()
5374*da0073e9SAndroid Build Coastguard Worker        input2 = torch.zeros_like(input1).requires_grad_()
5375*da0073e9SAndroid Build Coastguard Worker        torch.cosine_similarity(input1, input2, 0).sum().backward()
5376*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(input1.grad, torch.zeros_like(input1))
5377*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(input2.grad, input1 / input1.norm() * 1e8)
5378*da0073e9SAndroid Build Coastguard Worker
5379*da0073e9SAndroid Build Coastguard Worker        # Check type promotion, issue #61454
5380*da0073e9SAndroid Build Coastguard Worker        input = torch.tensor(12.)
5381*da0073e9SAndroid Build Coastguard Worker        out = F.cosine_similarity(input.to(torch.int8), input, dim=-1)
5382*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, 1.)
5383*da0073e9SAndroid Build Coastguard Worker
5384*da0073e9SAndroid Build Coastguard Worker        # Check broadcasting #109333
5385*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(2, 3, dtype=torch.float)
5386*da0073e9SAndroid Build Coastguard Worker        b = torch.ones(1, 1, dtype=torch.float)
5387*da0073e9SAndroid Build Coastguard Worker        out = F.cosine_similarity(a, b)
5388*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, torch.ones(2, dtype=torch.float))
5389*da0073e9SAndroid Build Coastguard Worker
5390*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(2, 3, dtype=torch.float)
5391*da0073e9SAndroid Build Coastguard Worker        b = torch.ones(1, dtype=torch.float)
5392*da0073e9SAndroid Build Coastguard Worker        out = F.cosine_similarity(a, b)
5393*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, torch.ones(2, dtype=torch.float))
5394*da0073e9SAndroid Build Coastguard Worker
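    # Editor's illustrative sketch (not part of the original suite): the comment in
    # test_cosine_similarity describes the current formulation as
    # <x / max(eps, ||x||), y / max(eps, ||y||)>; a minimal numerical check of that
    # description for well-conditioned inputs. The leading underscore keeps it out of
    # automatic test collection:
    def _cosine_similarity_formula_sketch(self):
        eps = 1e-8
        x = torch.randn(4, 7, dtype=torch.double)
        y = torch.randn(4, 7, dtype=torch.double)
        x_hat = x / x.norm(dim=1, keepdim=True).clamp_min(eps)
        y_hat = y / y.norm(dim=1, keepdim=True).clamp_min(eps)
        self.assertEqual(F.cosine_similarity(x, y, dim=1, eps=eps), (x_hat * y_hat).sum(dim=1))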
5395*da0073e9SAndroid Build Coastguard Worker
5396*da0073e9SAndroid Build Coastguard Worker    def test_grid_sample_error_checking(self):
5397*da0073e9SAndroid Build Coastguard Worker        input = torch.empty(1, 1, 2, 2)
5398*da0073e9SAndroid Build Coastguard Worker        grid = torch.empty(1, 1, 1, 2)
5399*da0073e9SAndroid Build Coastguard Worker
5400*da0073e9SAndroid Build Coastguard Worker        # assert no error
5401*da0073e9SAndroid Build Coastguard Worker        F.grid_sample(input, grid, align_corners=False)
5402*da0073e9SAndroid Build Coastguard Worker
5403*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "but got: 'garbage'"):
5404*da0073e9SAndroid Build Coastguard Worker            F.grid_sample(input, grid, mode='garbage', align_corners=False)
5405*da0073e9SAndroid Build Coastguard Worker
5406*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "but got: 'garbage'"):
5407*da0073e9SAndroid Build Coastguard Worker            F.grid_sample(input, grid, padding_mode='garbage', align_corners=False)
5408*da0073e9SAndroid Build Coastguard Worker
5409*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "expected grid to have size 1 in last dimension"):
5410*da0073e9SAndroid Build Coastguard Worker            F.grid_sample(input[0], grid, align_corners=False)
5411*da0073e9SAndroid Build Coastguard Worker
5412*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "expected grid to have size 2 in last dimension"):
5413*da0073e9SAndroid Build Coastguard Worker            F.grid_sample(input, torch.empty(1, 1, 1, 1, 3), align_corners=False)
5414*da0073e9SAndroid Build Coastguard Worker
5415*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "expected grid and input to have same batch size"):
5416*da0073e9SAndroid Build Coastguard Worker            F.grid_sample(input, torch.empty(2, 1, 1, 2), align_corners=False)
5417*da0073e9SAndroid Build Coastguard Worker
5418*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "expected grid to have size 2 in last dimension"):
5419*da0073e9SAndroid Build Coastguard Worker            F.grid_sample(input, torch.empty(1, 1, 1, 3), align_corners=False)
5420*da0073e9SAndroid Build Coastguard Worker
5421*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "expected input to have non-empty spatial dimensions"):
5422*da0073e9SAndroid Build Coastguard Worker            F.grid_sample(torch.empty(1, 1, 0, 2), grid, align_corners=False)
5423*da0073e9SAndroid Build Coastguard Worker
5424*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "bicubic interpolation only supports 4D input"):
5425*da0073e9SAndroid Build Coastguard Worker            F.grid_sample(torch.empty(1, 1, 2, 2, 2), torch.empty(1, 1, 1, 1, 3), mode='bicubic')
5426*da0073e9SAndroid Build Coastguard Worker
5427*da0073e9SAndroid Build Coastguard Worker        if TEST_CUDA:
5428*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
5429*da0073e9SAndroid Build Coastguard Worker                F.grid_sample(input.cuda(), grid, align_corners=False)
5430*da0073e9SAndroid Build Coastguard Worker
5431*da0073e9SAndroid Build Coastguard Worker    def test_affine_grid_error_checking(self):
5432*da0073e9SAndroid Build Coastguard Worker        # 2D affine
5433*da0073e9SAndroid Build Coastguard Worker        theta = torch.empty(1, 2, 3, dtype=torch.double)
5434*da0073e9SAndroid Build Coastguard Worker        size = torch.Size([1, 1, 2, 2])
5435*da0073e9SAndroid Build Coastguard Worker
5436*da0073e9SAndroid Build Coastguard Worker        # assert no error
5437*da0073e9SAndroid Build Coastguard Worker        F.affine_grid(theta, size, align_corners=False)
5438*da0073e9SAndroid Build Coastguard Worker
5439*da0073e9SAndroid Build Coastguard Worker        # check for warning for empty span along dimension
5440*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
5441*da0073e9SAndroid Build Coastguard Worker            # Ensure warnings are being shown
5442*da0073e9SAndroid Build Coastguard Worker            warnings.simplefilter("always")
5443*da0073e9SAndroid Build Coastguard Worker            # Should not trigger warning
5444*da0073e9SAndroid Build Coastguard Worker            F.affine_grid(theta, torch.Size([1, 1, 2, 1]), align_corners=False)
5445*da0073e9SAndroid Build Coastguard Worker            # Check no warning occurs
5446*da0073e9SAndroid Build Coastguard Worker            self.assertNotIn('See the documentation of affine_grid for details.', ' '.join(map(str, w)))
5447*da0073e9SAndroid Build Coastguard Worker            # Should trigger warning
5448*da0073e9SAndroid Build Coastguard Worker            F.affine_grid(theta, torch.Size([1, 1, 2, 1]), align_corners=True)
5449*da0073e9SAndroid Build Coastguard Worker            # Check warning occurs
5450*da0073e9SAndroid Build Coastguard Worker            self.assertIn('See the documentation of affine_grid for details.', ' '.join(map(str, w)))
5451*da0073e9SAndroid Build Coastguard Worker
5452*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "Expected theta to have floating point type"):
5453*da0073e9SAndroid Build Coastguard Worker            F.affine_grid(theta.int(), size, align_corners=False)
5454*da0073e9SAndroid Build Coastguard Worker
5455*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "Expected a batch of 2D affine matrices of shape Nx2x3"):
5456*da0073e9SAndroid Build Coastguard Worker            F.affine_grid(theta[0], size, align_corners=False)
5457*da0073e9SAndroid Build Coastguard Worker
5458*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "Expected a batch of 2D affine matrices of shape Nx2x3"):
5459*da0073e9SAndroid Build Coastguard Worker            F.affine_grid(theta.unsqueeze(0), size, align_corners=False)
5460*da0073e9SAndroid Build Coastguard Worker
5461*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "Expected a batch of 2D affine matrices of shape Nx2x3"):
5462*da0073e9SAndroid Build Coastguard Worker            F.affine_grid(theta.repeat(1, 2, 1), size, align_corners=False)
5463*da0073e9SAndroid Build Coastguard Worker
5464*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "Expected a batch of 2D affine matrices of shape Nx2x3"):
5465*da0073e9SAndroid Build Coastguard Worker            F.affine_grid(theta.repeat(1, 1, 2), size, align_corners=False)
5466*da0073e9SAndroid Build Coastguard Worker
5467*da0073e9SAndroid Build Coastguard Worker        # 3D affine
5468*da0073e9SAndroid Build Coastguard Worker        theta = torch.empty(1, 3, 4, dtype=torch.double)
5469*da0073e9SAndroid Build Coastguard Worker        size = torch.Size([1, 1, 2, 2, 2])
5470*da0073e9SAndroid Build Coastguard Worker
5471*da0073e9SAndroid Build Coastguard Worker        # assert no error
5472*da0073e9SAndroid Build Coastguard Worker        F.affine_grid(theta, size, align_corners=False)
5473*da0073e9SAndroid Build Coastguard Worker
5474*da0073e9SAndroid Build Coastguard Worker        # check for warning for empty span along dimension
5475*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
5476*da0073e9SAndroid Build Coastguard Worker            # Ensure warnings are being shown
5477*da0073e9SAndroid Build Coastguard Worker            warnings.simplefilter("always")
5478*da0073e9SAndroid Build Coastguard Worker            # Should not trigger warning
5479*da0073e9SAndroid Build Coastguard Worker            F.affine_grid(theta, torch.Size([1, 1, 3, 2, 1]), align_corners=False)
5480*da0073e9SAndroid Build Coastguard Worker            # Check no warning occurs
5481*da0073e9SAndroid Build Coastguard Worker            self.assertNotIn('See the documentation of affine_grid for details.', ' '.join(map(str, w)))
5482*da0073e9SAndroid Build Coastguard Worker            # Should trigger warning
5483*da0073e9SAndroid Build Coastguard Worker            F.affine_grid(theta, torch.Size([1, 1, 3, 2, 1]), align_corners=True)
5484*da0073e9SAndroid Build Coastguard Worker            # Check warning occurs
5485*da0073e9SAndroid Build Coastguard Worker            self.assertIn('See the documentation of affine_grid for details.', ' '.join(map(str, w)))
5486*da0073e9SAndroid Build Coastguard Worker
5487*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "Expected a batch of 3D affine matrices of shape Nx3x4"):
5488*da0073e9SAndroid Build Coastguard Worker            F.affine_grid(theta[0], size, align_corners=False)
5489*da0073e9SAndroid Build Coastguard Worker
5490*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "Expected a batch of 3D affine matrices of shape Nx3x4"):
5491*da0073e9SAndroid Build Coastguard Worker            F.affine_grid(theta.unsqueeze(0), size, align_corners=False)
5492*da0073e9SAndroid Build Coastguard Worker
5493*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "Expected a batch of 3D affine matrices of shape Nx3x4"):
5494*da0073e9SAndroid Build Coastguard Worker            F.affine_grid(theta.repeat(1, 2, 1), size, align_corners=False)
5495*da0073e9SAndroid Build Coastguard Worker
5496*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "Expected a batch of 3D affine matrices of shape Nx3x4"):
5497*da0073e9SAndroid Build Coastguard Worker            F.affine_grid(theta.repeat(1, 1, 2), size, align_corners=False)
5498*da0073e9SAndroid Build Coastguard Worker
5499*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(NotImplementedError, "affine_grid only supports 4D and 5D sizes"):
5500*da0073e9SAndroid Build Coastguard Worker            F.affine_grid(theta, torch.Size([1, 2, 2]), align_corners=False)
5501*da0073e9SAndroid Build Coastguard Worker
5502*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(NotImplementedError, "affine_grid only supports 4D and 5D sizes"):
5503*da0073e9SAndroid Build Coastguard Worker            F.affine_grid(theta, torch.Size([1, 1, 2, 2, 2, 2]), align_corners=False)
5504*da0073e9SAndroid Build Coastguard Worker
5505*da0073e9SAndroid Build Coastguard Worker    @parametrize_test('device', ['cpu'] + (['cuda'] if TEST_CUDA else []))
5506*da0073e9SAndroid Build Coastguard Worker    @parametrize_test('nd', [2, 3])
5507*da0073e9SAndroid Build Coastguard Worker    def test_affine_grid_backward_cl_cf_consistency(self, device, nd):
5508*da0073e9SAndroid Build Coastguard Worker        # Test based on reported issue: https://github.com/pytorch/pytorch/issues/124154
5509*da0073e9SAndroid Build Coastguard Worker
5510*da0073e9SAndroid Build Coastguard Worker        theta = torch.rand([6, nd, nd + 1], requires_grad=True, device=device)
5511*da0073e9SAndroid Build Coastguard Worker        size = [6, 3, 4, 5] if nd == 2 else [6, 3, 4, 5, 5]
5512*da0073e9SAndroid Build Coastguard Worker        grid = torch.nn.functional.affine_grid(theta, size, align_corners=False)
5513*da0073e9SAndroid Build Coastguard Worker
5514*da0073e9SAndroid Build Coastguard Worker        grad_tensor = torch.rand(grid.shape, device=device)
5515*da0073e9SAndroid Build Coastguard Worker
5516*da0073e9SAndroid Build Coastguard Worker        memory_format_cl = torch.channels_last if nd == 2 else torch.channels_last_3d
5517*da0073e9SAndroid Build Coastguard Worker        grad_tensor_cl = grad_tensor.contiguous(memory_format=memory_format_cl)
5518*da0073e9SAndroid Build Coastguard Worker
5519*da0073e9SAndroid Build Coastguard Worker        assert theta.grad is None
5520*da0073e9SAndroid Build Coastguard Worker        grid.backward(grad_tensor_cl)
5521*da0073e9SAndroid Build Coastguard Worker        theta_grad_cl = theta.grad.clone().contiguous()
5522*da0073e9SAndroid Build Coastguard Worker
5523*da0073e9SAndroid Build Coastguard Worker        theta.grad.zero_()
5524*da0073e9SAndroid Build Coastguard Worker        grid.backward(grad_tensor)
5525*da0073e9SAndroid Build Coastguard Worker        theta_grad_cf = theta.grad
5526*da0073e9SAndroid Build Coastguard Worker
5527*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(theta_grad_cf, theta_grad_cl)
5528*da0073e9SAndroid Build Coastguard Worker
5529*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
5530*da0073e9SAndroid Build Coastguard Worker    def test_grid_sample(self):
5531*da0073e9SAndroid Build Coastguard Worker        # The backward pass of the native C++ and CUDA kernels branches depending on whether the input
5532*da0073e9SAndroid Build Coastguard Worker        # requires gradient, so we test both cases.
5533*da0073e9SAndroid Build Coastguard Worker        def test(N, C, H, W, mode, padding_mode, align_corners, input_requires_grad):
5534*da0073e9SAndroid Build Coastguard Worker            def test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners):
5535*da0073e9SAndroid Build Coastguard Worker                for grid_dim_contig_order in [(0, 1, 2, 3), (0, 3, 1, 2), (3, 0, 1, 2), (0, 2, 1, 3)]:
5536*da0073e9SAndroid Build Coastguard Worker                    # grid_dim_contig_order specifies the dimension order that
5537*da0073e9SAndroid Build Coastguard Worker                    # makes the grid contiguous,
5538*da0073e9SAndroid Build Coastguard Worker                    # i.e., grid.permute(grid_dim_contig_order) is contiguous.
5539*da0073e9SAndroid Build Coastguard Worker                    # e.g., with grid_dim_contig_order=[0, 3, 1, 2], grid should be
5540*da0073e9SAndroid Build Coastguard Worker                    #       initialized with contiguous tensor of shape [N, 2, H, W]
5541*da0073e9SAndroid Build Coastguard Worker                    #       and permuted to [N, H, W, 2] afterwards.
5542*da0073e9SAndroid Build Coastguard Worker                    grid_shape = [N, H, W, 2]
5543*da0073e9SAndroid Build Coastguard Worker                    grid_init_shape = [grid_shape[d] for d in grid_dim_contig_order]
5544*da0073e9SAndroid Build Coastguard Worker                    grid_fwd_permute = [None, None, None, None]
5545*da0073e9SAndroid Build Coastguard Worker                    for i, d in enumerate(grid_dim_contig_order):
5546*da0073e9SAndroid Build Coastguard Worker                        grid_fwd_permute[d] = i
5547*da0073e9SAndroid Build Coastguard Worker
5548*da0073e9SAndroid Build Coastguard Worker                    def get_grid(device='cpu', data=None):
5549*da0073e9SAndroid Build Coastguard Worker                        if data is not None:
5550*da0073e9SAndroid Build Coastguard Worker                            assert list(data.shape) == grid_shape
5551*da0073e9SAndroid Build Coastguard Worker                            data = data.permute(grid_dim_contig_order).to(device)
5552*da0073e9SAndroid Build Coastguard Worker                        else:
5553*da0073e9SAndroid Build Coastguard Worker                            data = torch.randn(grid_init_shape, device=device)
5554*da0073e9SAndroid Build Coastguard Worker                        grid = data.permute(grid_fwd_permute)
5555*da0073e9SAndroid Build Coastguard Worker                        assert grid.permute(grid_dim_contig_order).is_contiguous()
5556*da0073e9SAndroid Build Coastguard Worker                        return grid
5557*da0073e9SAndroid Build Coastguard Worker
5558*da0073e9SAndroid Build Coastguard Worker                    input_cpu = torch.randn(C, N, IH, IW).transpose(0, 1).requires_grad_(input_requires_grad)
5559*da0073e9SAndroid Build Coastguard Worker                    grid_cpu = get_grid().requires_grad_()
5560*da0073e9SAndroid Build Coastguard Worker                    out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode,
5561*da0073e9SAndroid Build Coastguard Worker                                            align_corners=align_corners)
5562*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(out_cpu.size() == torch.Size([N, C, H, W]))
5563*da0073e9SAndroid Build Coastguard Worker
5564*da0073e9SAndroid Build Coastguard Worker                    gradients = torch.randn_like(out_cpu)
5565*da0073e9SAndroid Build Coastguard Worker                    out_cpu.backward(gradients)
5566*da0073e9SAndroid Build Coastguard Worker
5567*da0073e9SAndroid Build Coastguard Worker
5568*da0073e9SAndroid Build Coastguard Worker                    # Compare against unvectorized CPU fallback
5569*da0073e9SAndroid Build Coastguard Worker
5570*da0073e9SAndroid Build Coastguard Worker                    # NOTE [ grid_sample CPU fallback ]
5571*da0073e9SAndroid Build Coastguard Worker                    # grid_sample uses AVX for 2d images, but that requires 32-bit indexing for
5572*da0073e9SAndroid Build Coastguard Worker                    # 32-bit floats. So we also have a fallback that is used only for float tensors
5573*da0073e9SAndroid Build Coastguard Worker                    # requiring 64-bit indexing. That requires too much memory to run on CI, so we
5574*da0073e9SAndroid Build Coastguard Worker                    # also export the fallback and test it here to ensure feature parity with
5575*da0073e9SAndroid Build Coastguard Worker                    # the vectorized version.
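                    # The fallback is exercised directly through the private op
                    # torch._grid_sampler_2d_cpu_fallback below, using the integer enum
                    # values from F.GRID_SAMPLE_INTERPOLATION_MODES and
                    # F.GRID_SAMPLE_PADDING_MODES rather than the string arguments
                    # accepted by F.grid_sample.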
5576*da0073e9SAndroid Build Coastguard Worker                    input_fallback = input_cpu.float().detach_().requires_grad_()
5577*da0073e9SAndroid Build Coastguard Worker                    grid_fallback = grid_cpu.float().detach_().requires_grad_()
5578*da0073e9SAndroid Build Coastguard Worker                    out_fallback = torch._grid_sampler_2d_cpu_fallback(
5579*da0073e9SAndroid Build Coastguard Worker                        input_fallback, grid_fallback,
5580*da0073e9SAndroid Build Coastguard Worker                        F.GRID_SAMPLE_INTERPOLATION_MODES[mode],
5581*da0073e9SAndroid Build Coastguard Worker                        F.GRID_SAMPLE_PADDING_MODES[padding_mode],
5582*da0073e9SAndroid Build Coastguard Worker                        align_corners)
5583*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(out_fallback, out_cpu.float(), atol=1e-5, rtol=5e-5)
5584*da0073e9SAndroid Build Coastguard Worker
5585*da0073e9SAndroid Build Coastguard Worker                    out_fallback.backward(gradients.float())
5586*da0073e9SAndroid Build Coastguard Worker                    if input_requires_grad:
5587*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(input_fallback.grad, input_cpu.grad.float(), atol=1e-4, rtol=5e-5)
5588*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(grid_fallback.grad, grid_cpu.grad.float(), atol=1e-4, rtol=5e-5)
5589*da0073e9SAndroid Build Coastguard Worker
5590*da0073e9SAndroid Build Coastguard Worker                    if TEST_CUDA:
5591*da0073e9SAndroid Build Coastguard Worker                        input_cuda = input_cpu.detach().transpose(0, 1).cuda().transpose(0, 1).requires_grad_(input_requires_grad)
5592*da0073e9SAndroid Build Coastguard Worker                        grid_cuda = get_grid('cuda', grid_cpu.detach()).requires_grad_()
5593*da0073e9SAndroid Build Coastguard Worker                        out_cuda = F.grid_sample(input_cuda, grid_cuda, mode=mode, padding_mode=padding_mode,
5594*da0073e9SAndroid Build Coastguard Worker                                                 align_corners=align_corners)
5595*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(out_cpu, out_cuda)
5596*da0073e9SAndroid Build Coastguard Worker
5597*da0073e9SAndroid Build Coastguard Worker                        out_cuda.backward(gradients.cuda())
5598*da0073e9SAndroid Build Coastguard Worker                        if input_requires_grad:
5599*da0073e9SAndroid Build Coastguard Worker                            self.assertEqual(input_cpu.grad, input_cuda.grad)
5600*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(grid_cpu.grad, grid_cuda.grad, atol=5e-5, rtol=0)
5601*da0073e9SAndroid Build Coastguard Worker
5602*da0073e9SAndroid Build Coastguard Worker                        # check that inputs with zero strides (from expand) don't error out
5603*da0073e9SAndroid Build Coastguard Worker                        base_input = torch.randn(N, C, 1, IW)
5604*da0073e9SAndroid Build Coastguard Worker                        input_cpu = base_input.expand_as(input_cuda).requires_grad_(input_requires_grad)
5605*da0073e9SAndroid Build Coastguard Worker                        out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode,
5606*da0073e9SAndroid Build Coastguard Worker                                                align_corners=align_corners)
5607*da0073e9SAndroid Build Coastguard Worker
5608*da0073e9SAndroid Build Coastguard Worker                        input_cuda = base_input.cuda().expand_as(input_cuda).requires_grad_(input_requires_grad)
5609*da0073e9SAndroid Build Coastguard Worker                        out_cuda = F.grid_sample(input_cuda, grid_cuda, mode=mode, padding_mode=padding_mode,
5610*da0073e9SAndroid Build Coastguard Worker                                                 align_corners=align_corners)
5611*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(out_cpu, out_cuda)
5612*da0073e9SAndroid Build Coastguard Worker
5613*da0073e9SAndroid Build Coastguard Worker            # test same size output
5614*da0073e9SAndroid Build Coastguard Worker            test_shape(N, C, H, W, H, W, mode, padding_mode, align_corners)
5615*da0073e9SAndroid Build Coastguard Worker
5616*da0073e9SAndroid Build Coastguard Worker            # test larger output
5617*da0073e9SAndroid Build Coastguard Worker            N = random.randint(2, 8)
5618*da0073e9SAndroid Build Coastguard Worker            C = random.randint(2, 8)
5619*da0073e9SAndroid Build Coastguard Worker            IH = random.randint(2, 8)
5620*da0073e9SAndroid Build Coastguard Worker            IW = random.randint(2, 8)
5621*da0073e9SAndroid Build Coastguard Worker            H = random.randint(IH + 1, 12)
5622*da0073e9SAndroid Build Coastguard Worker            W = random.randint(IW + 1, 12)
5623*da0073e9SAndroid Build Coastguard Worker            test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)
5624*da0073e9SAndroid Build Coastguard Worker
5625*da0073e9SAndroid Build Coastguard Worker            # test smaller output
5626*da0073e9SAndroid Build Coastguard Worker            N = random.randint(2, 8)
5627*da0073e9SAndroid Build Coastguard Worker            C = random.randint(2, 8)
5628*da0073e9SAndroid Build Coastguard Worker            IH = random.randint(2, 8)
5629*da0073e9SAndroid Build Coastguard Worker            IW = random.randint(2, 8)
5630*da0073e9SAndroid Build Coastguard Worker            H = random.randint(2, IH)
5631*da0073e9SAndroid Build Coastguard Worker            W = random.randint(2, IW)
5632*da0073e9SAndroid Build Coastguard Worker            test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)
5633*da0073e9SAndroid Build Coastguard Worker
5634*da0073e9SAndroid Build Coastguard Worker            # test 1x1 input
5635*da0073e9SAndroid Build Coastguard Worker            N = random.randint(2, 8)
5636*da0073e9SAndroid Build Coastguard Worker            C = random.randint(2, 8)
5637*da0073e9SAndroid Build Coastguard Worker            IH = 1
5638*da0073e9SAndroid Build Coastguard Worker            IW = 1
5639*da0073e9SAndroid Build Coastguard Worker            H = random.randint(2, 5)
5640*da0073e9SAndroid Build Coastguard Worker            W = random.randint(2, 5)
5641*da0073e9SAndroid Build Coastguard Worker            test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)
5642*da0073e9SAndroid Build Coastguard Worker
5643*da0073e9SAndroid Build Coastguard Worker            # testing empty grid
5644*da0073e9SAndroid Build Coastguard Worker            N = random.randint(2, 8)
5645*da0073e9SAndroid Build Coastguard Worker            C = random.randint(2, 8)
5646*da0073e9SAndroid Build Coastguard Worker            IH = random.randint(2, 8)
5647*da0073e9SAndroid Build Coastguard Worker            IW = random.randint(2, 8)
5648*da0073e9SAndroid Build Coastguard Worker            W = random.randint(3, IW + 2)
5649*da0073e9SAndroid Build Coastguard Worker            test_shape(N, C, IH, IW, 0, W, mode, padding_mode, align_corners)
5650*da0073e9SAndroid Build Coastguard Worker
5651*da0073e9SAndroid Build Coastguard Worker            # testing empty channel
5652*da0073e9SAndroid Build Coastguard Worker            N = random.randint(2, 8)
5653*da0073e9SAndroid Build Coastguard Worker            IH = random.randint(2, 8)
5654*da0073e9SAndroid Build Coastguard Worker            IW = random.randint(2, 8)
5655*da0073e9SAndroid Build Coastguard Worker            H = random.randint(3, IH + 2)
5656*da0073e9SAndroid Build Coastguard Worker            W = random.randint(3, IW + 2)
5657*da0073e9SAndroid Build Coastguard Worker            test_shape(N, 0, IH, IW, H, W, mode, padding_mode, align_corners)
5658*da0073e9SAndroid Build Coastguard Worker
5659*da0073e9SAndroid Build Coastguard Worker            # testing empty batch
5660*da0073e9SAndroid Build Coastguard Worker            C = random.randint(2, 8)
5661*da0073e9SAndroid Build Coastguard Worker            IH = random.randint(2, 8)
5662*da0073e9SAndroid Build Coastguard Worker            IW = random.randint(2, 8)
5663*da0073e9SAndroid Build Coastguard Worker            H = random.randint(3, IH + 2)
5664*da0073e9SAndroid Build Coastguard Worker            W = random.randint(3, IW + 2)
5665*da0073e9SAndroid Build Coastguard Worker            test_shape(0, C, IH, IW, H, W, mode, padding_mode, align_corners)
5666*da0073e9SAndroid Build Coastguard Worker
5667*da0073e9SAndroid Build Coastguard Worker        for mode in ('bilinear', 'nearest', 'bicubic'):
5668*da0073e9SAndroid Build Coastguard Worker            for padding_mode in ('zeros', 'border', 'reflection'):
5669*da0073e9SAndroid Build Coastguard Worker                for align_corners in (True, False):
5670*da0073e9SAndroid Build Coastguard Worker                    # test known input on CPU
5671*da0073e9SAndroid Build Coastguard Worker                    input = torch.arange(1., 11).view(1, 1, 2, 5)
5672*da0073e9SAndroid Build Coastguard Worker                    grid = torch.tensor(
5673*da0073e9SAndroid Build Coastguard Worker                        [[[-0.9, -4.1], [0, 0.2000], [1, -1], [-0.333, 1e-6], [0.5, 1.0]],
5674*da0073e9SAndroid Build Coastguard Worker                         [[-1.0, -0.5], [0, 0.3333], [1, -1], [-0.200, 1e-6], [1.5, 0.5]]]).view(1, 2, 5, 2)
5675*da0073e9SAndroid Build Coastguard Worker                    if mode == 'bilinear':
5676*da0073e9SAndroid Build Coastguard Worker                        if padding_mode == 'zeros':
5677*da0073e9SAndroid Build Coastguard Worker                            if align_corners:
5678*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5679*da0073e9SAndroid Build Coastguard Worker                                    [[0.0000, 6.0000000000, 5.0000, 4.8340, 9.0000],
5680*da0073e9SAndroid Build Coastguard Worker                                     [2.2500, 6.3332500450, 5.0000, 5.1000, 0.0000]]).view(1, 1, 2, 5)
5681*da0073e9SAndroid Build Coastguard Worker                            else:
5682*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5683*da0073e9SAndroid Build Coastguard Worker                                    [[0.0000, 6.5000000000, 1.2500, 4.6675000191, 4.6250],
5684*da0073e9SAndroid Build Coastguard Worker                                     [0.5000, 7.1665000916, 1.2500, 5.0000000000, 0.0000]]).view(1, 1, 2, 5)
5685*da0073e9SAndroid Build Coastguard Worker                        elif padding_mode == 'border':
5686*da0073e9SAndroid Build Coastguard Worker                            if align_corners:
5687*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5688*da0073e9SAndroid Build Coastguard Worker                                    [[1.2000, 6.0000000000, 5.0000, 4.8340, 9.0000],
5689*da0073e9SAndroid Build Coastguard Worker                                     [2.2500, 6.3332500450, 5.0000, 5.1000, 8.7500]]).view(1, 1, 2, 5)
5690*da0073e9SAndroid Build Coastguard Worker                            else:
5691*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5692*da0073e9SAndroid Build Coastguard Worker                                    [[1.0000, 6.5000000000, 5.0000, 4.6675000191, 9.2500],
5693*da0073e9SAndroid Build Coastguard Worker                                     [1.0000, 7.1665000916, 5.0000, 5.0000000000, 10.0000]]).view(1, 1, 2, 5)
5694*da0073e9SAndroid Build Coastguard Worker                        elif padding_mode == 'reflection':
5695*da0073e9SAndroid Build Coastguard Worker                            if align_corners:
5696*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5697*da0073e9SAndroid Build Coastguard Worker                                    [[3.4500, 6.0000000000, 5.0000, 4.8340, 9.0000],
5698*da0073e9SAndroid Build Coastguard Worker                                     [2.2500, 6.3332500450, 5.0000, 5.1000, 7.7500]]).view(1, 1, 2, 5)
5699*da0073e9SAndroid Build Coastguard Worker                            else:
5700*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5701*da0073e9SAndroid Build Coastguard Worker                                    [[3.0000004768, 6.5000000000, 5.0000, 4.6675000191, 9.2500],
5702*da0073e9SAndroid Build Coastguard Worker                                     [1.0000000000, 7.1665000916, 5.0000, 5.0000000000, 9.2500]]).view(1, 1, 2, 5)
5703*da0073e9SAndroid Build Coastguard Worker                        else:
5704*da0073e9SAndroid Build Coastguard Worker                            raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")
5705*da0073e9SAndroid Build Coastguard Worker                    elif mode == 'nearest':
5706*da0073e9SAndroid Build Coastguard Worker                        if padding_mode == 'zeros':
5707*da0073e9SAndroid Build Coastguard Worker                            if align_corners:
5708*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5709*da0073e9SAndroid Build Coastguard Worker                                    [[0., 8., 5., 7., 9.],
5710*da0073e9SAndroid Build Coastguard Worker                                     [1., 8., 5., 8., 0.]]).view(1, 1, 2, 5)
5711*da0073e9SAndroid Build Coastguard Worker                            else:
5712*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5713*da0073e9SAndroid Build Coastguard Worker                                    [[0., 8., 5., 7., 0.],
5714*da0073e9SAndroid Build Coastguard Worker                                     [1., 8., 5., 8., 0.]]).view(1, 1, 2, 5)
5715*da0073e9SAndroid Build Coastguard Worker                        elif padding_mode == 'border':
5716*da0073e9SAndroid Build Coastguard Worker                            if align_corners:
5717*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5718*da0073e9SAndroid Build Coastguard Worker                                    [[1., 8., 5., 7., 9.],
5719*da0073e9SAndroid Build Coastguard Worker                                     [1., 8., 5., 8., 10.]]).view(1, 1, 2, 5)
5720*da0073e9SAndroid Build Coastguard Worker                            else:
5721*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5722*da0073e9SAndroid Build Coastguard Worker                                    [[1., 8., 5., 7., 9.],
5723*da0073e9SAndroid Build Coastguard Worker                                     [1., 8., 5., 8., 10.]]).view(1, 1, 2, 5)
5724*da0073e9SAndroid Build Coastguard Worker                        elif padding_mode == 'reflection':
5725*da0073e9SAndroid Build Coastguard Worker                            if align_corners:
5726*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5727*da0073e9SAndroid Build Coastguard Worker                                    [[1., 8., 5., 7., 9.],
5728*da0073e9SAndroid Build Coastguard Worker                                     [1., 8., 5., 8., 9.]]).view(1, 1, 2, 5)
5729*da0073e9SAndroid Build Coastguard Worker                            else:
5730*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5731*da0073e9SAndroid Build Coastguard Worker                                    [[1., 8., 5., 7., 9.],
5732*da0073e9SAndroid Build Coastguard Worker                                     [1., 8., 5., 8., 9.]]).view(1, 1, 2, 5)
5733*da0073e9SAndroid Build Coastguard Worker                        else:
5734*da0073e9SAndroid Build Coastguard Worker                            raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")
5735*da0073e9SAndroid Build Coastguard Worker                    elif mode == 'bicubic':
5736*da0073e9SAndroid Build Coastguard Worker                        if padding_mode == 'zeros':
5737*da0073e9SAndroid Build Coastguard Worker                            if align_corners:
5738*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5739*da0073e9SAndroid Build Coastguard Worker                                    [[-0.10424726, 7.1400003, 5.0000, 5.7842274, 9.0000],
5740*da0073e9SAndroid Build Coastguard Worker                                     [2.4492188, 7.4814040, 5.0000, 6.0277520, 0.0000]]).view(1, 1, 2, 5)
5741*da0073e9SAndroid Build Coastguard Worker                            else:
5742*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5743*da0073e9SAndroid Build Coastguard Worker                                    [[0.00000, 7.6287503, 1.0625, 5.5977230, 5.3270264],
5744*da0073e9SAndroid Build Coastguard Worker                                     [0.40625, 8.0288770, 1.0625, 5.9375067, -0.3515625]]).view(1, 1, 2, 5)
5745*da0073e9SAndroid Build Coastguard Worker                        elif padding_mode == 'border':
5746*da0073e9SAndroid Build Coastguard Worker                            if align_corners:
5747*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5748*da0073e9SAndroid Build Coastguard Worker                                    [[1.1520010, 6.0599990, 5.0000, 4.870930, 9.0000000],
5749*da0073e9SAndroid Build Coastguard Worker                                     [2.1328125, 6.4258375, 5.0000, 5.076003, 8.8671875]]).view(1, 1, 2, 5)
5750*da0073e9SAndroid Build Coastguard Worker                            else:
5751*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5752*da0073e9SAndroid Build Coastguard Worker                                    [[0.894531, 6.6050020, 4.625, 4.7138715, 9.800781],
5753*da0073e9SAndroid Build Coastguard Worker                                     [0.906250, 7.2822485, 4.625, 5.0000052, 10.00000]]).view(1, 1, 2, 5)
5754*da0073e9SAndroid Build Coastguard Worker                        elif padding_mode == 'reflection':
5755*da0073e9SAndroid Build Coastguard Worker                            if align_corners:
5756*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5757*da0073e9SAndroid Build Coastguard Worker                                    [[3.1822524, 6.239998, 5.0000, 4.8709273, 9.00000],
5758*da0073e9SAndroid Build Coastguard Worker                                     [1.7812500, 6.703594, 5.0000, 5.0760007, 8.21875]]).view(1, 1, 2, 5)
5759*da0073e9SAndroid Build Coastguard Worker                            else:
5760*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5761*da0073e9SAndroid Build Coastguard Worker                                    [[2.7993753, 6.6050020, 4.25, 4.7138715, 10.269531],
5762*da0073e9SAndroid Build Coastguard Worker                                     [0.8125000, 7.2822485, 4.25, 5.0000052, 9.332031]]).view(1, 1, 2, 5)
5763*da0073e9SAndroid Build Coastguard Worker                        else:
5764*da0073e9SAndroid Build Coastguard Worker                            raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")
5765*da0073e9SAndroid Build Coastguard Worker
5766*da0073e9SAndroid Build Coastguard Worker                    else:
5767*da0073e9SAndroid Build Coastguard Worker                        raise AssertionError(f"missing groundtruth test for interpolation mode '{mode}'")
5768*da0073e9SAndroid Build Coastguard Worker                    output = F.grid_sample(input, grid, mode=mode, padding_mode=padding_mode,
5769*da0073e9SAndroid Build Coastguard Worker                                           align_corners=align_corners)
5770*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(output, groundtruth, atol=1e-5, rtol=0,
5771*da0073e9SAndroid Build Coastguard Worker                                     msg=f"groundtruth comparison failed for mode={mode}, "
5772*da0073e9SAndroid Build Coastguard Worker                                     f"padding_mode={padding_mode}")
5773*da0073e9SAndroid Build Coastguard Worker
5774*da0073e9SAndroid Build Coastguard Worker                    # See NOTE [ grid_sample CPU fallback ]
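                    # (editor's note, hedged: the private op below is understood to exercise the
                    # non-vectorized CPU fallback kernel described in that NOTE directly, and is
                    # expected to reproduce the same groundtruth as F.grid_sample above)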
5775*da0073e9SAndroid Build Coastguard Worker                    output = torch._grid_sampler_2d_cpu_fallback(
5776*da0073e9SAndroid Build Coastguard Worker                        input.float(), grid.float(),
5777*da0073e9SAndroid Build Coastguard Worker                        F.GRID_SAMPLE_INTERPOLATION_MODES[mode],
5778*da0073e9SAndroid Build Coastguard Worker                        F.GRID_SAMPLE_PADDING_MODES[padding_mode],
5779*da0073e9SAndroid Build Coastguard Worker                        align_corners)
5780*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(output, groundtruth.float(), atol=1e-5, rtol=0)
5781*da0073e9SAndroid Build Coastguard Worker
5782*da0073e9SAndroid Build Coastguard Worker                    # explicit check for gradient edge cases
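                    # (the grid coordinates below sit on or near the +/-1 image border and,
                    # depending on align_corners, land exactly on pixel centers, where the
                    # gradient w.r.t. the grid has kinks; the groundtruth tensors below encode
                    # the sub-gradient each padding mode is expected to pick)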
5783*da0073e9SAndroid Build Coastguard Worker                    input = torch.arange(0., 5).expand((1, 1, 5, 5))
5784*da0073e9SAndroid Build Coastguard Worker                    grid = torch.tensor(
5785*da0073e9SAndroid Build Coastguard Worker                        [[[1.0, 1.0], [1.0, -1.0], [0.8, 0.8], [0.8, -0.8]],
5786*da0073e9SAndroid Build Coastguard Worker                         [[-1.0, -1.0], [-1.0, 1.0], [-0.8, -0.8], [-0.8, 0.8]]]).view(1, 2, 4, 2).requires_grad_()
5787*da0073e9SAndroid Build Coastguard Worker                    if mode == 'bilinear':
5788*da0073e9SAndroid Build Coastguard Worker                        if padding_mode == 'zeros':
5789*da0073e9SAndroid Build Coastguard Worker                            if align_corners:
5790*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5791*da0073e9SAndroid Build Coastguard Worker                                    [[[[-8., -8.], [-8., 0.], [2., 0.], [2., 0.]],
5792*da0073e9SAndroid Build Coastguard Worker                                      [[2., 0.], [2., 0.], [2., 0.], [2., 0.]]]]).view(1, 2, 4, 2)
5793*da0073e9SAndroid Build Coastguard Worker                            else:
5794*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5795*da0073e9SAndroid Build Coastguard Worker                                    [[[[-5., -5.], [-5., 5.], [-10., -10.], [-10., 10.]],
5796*da0073e9SAndroid Build Coastguard Worker                                      [[0., 0.], [0., 0.], [0., 0.], [0., 0.]]]]).view(1, 2, 4, 2)
5797*da0073e9SAndroid Build Coastguard Worker                        elif padding_mode == 'border':
5798*da0073e9SAndroid Build Coastguard Worker                            if align_corners:
5799*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5800*da0073e9SAndroid Build Coastguard Worker                                    [[[[-0., -0.], [-0., 0.], [2., 0.], [2., 0.]],
5801*da0073e9SAndroid Build Coastguard Worker                                      [[0., 0.], [0., 0.], [2., 0.], [2., 0.]]]]).view(1, 2, 4, 2)
5802*da0073e9SAndroid Build Coastguard Worker                            else:
5803*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5804*da0073e9SAndroid Build Coastguard Worker                                    [[[[-0., -0.], [-0., 0.], [-0., -0.], [-0., 0.]],
5805*da0073e9SAndroid Build Coastguard Worker                                      [[0., 0.], [0., 0.], [0., 0.], [0., 0.]]]]).view(1, 2, 4, 2)
5806*da0073e9SAndroid Build Coastguard Worker                        elif padding_mode == 'reflection':
5807*da0073e9SAndroid Build Coastguard Worker                            if align_corners:
5808*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5809*da0073e9SAndroid Build Coastguard Worker                                    [[[[-0., -0.], [-0., 0.], [2., 0.], [2., 0.]],
5810*da0073e9SAndroid Build Coastguard Worker                                      [[0., 0.], [0., 0.], [2., 0.], [2., 0.]]]]).view(1, 2, 4, 2)
5811*da0073e9SAndroid Build Coastguard Worker                            else:
5812*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5813*da0073e9SAndroid Build Coastguard Worker                                    [[[[-0., -0.], [-0., 0.], [-0., -0.], [-0., 0.]],
5814*da0073e9SAndroid Build Coastguard Worker                                      [[0., 0.], [0., 0.], [0., 0.], [0., 0.]]]]).view(1, 2, 4, 2)
5815*da0073e9SAndroid Build Coastguard Worker                        else:
5816*da0073e9SAndroid Build Coastguard Worker                            raise AssertionError(f"missing gradient groundtruth test for padding mode '{padding_mode}'")
5817*da0073e9SAndroid Build Coastguard Worker                    elif mode == 'nearest':
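                        # nearest interpolation is piecewise constant in the grid
                        # coordinates, so the expected grid gradient is zero for
                        # every padding mode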
5818*da0073e9SAndroid Build Coastguard Worker                        groundtruth = torch.tensor(
5819*da0073e9SAndroid Build Coastguard Worker                            [[[[-0., -0.], [-0., 0.], [-0., -0.], [-0., 0.]],
5820*da0073e9SAndroid Build Coastguard Worker                              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]]]]).view(1, 2, 4, 2)
5821*da0073e9SAndroid Build Coastguard Worker                    elif mode == 'bicubic':
5822*da0073e9SAndroid Build Coastguard Worker                        if padding_mode == 'zeros':
5823*da0073e9SAndroid Build Coastguard Worker                            if align_corners:
5824*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5825*da0073e9SAndroid Build Coastguard Worker                                    [[[[-4.5, -6.], [-4.5, 6.], [2.725679, 0.740878], [2.725679, -0.740878]],
5826*da0073e9SAndroid Build Coastguard Worker                                      [[1.5, 0.], [1.5, 0.], [1.927921, -0.05688], [1.927921, 0.05688]]]]).view(1, 2, 4, 2)
5827*da0073e9SAndroid Build Coastguard Worker                            else:
5828*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5829*da0073e9SAndroid Build Coastguard Worker                                    [[[[-5.859375, -5.888672], [-5.859375, 5.888672], [-5.6250, -7.5000], [-5.6250, 7.5000]],
5830*da0073e9SAndroid Build Coastguard Worker                                      [[-0.234375, -0.263672], [-0.234375, 0.263672], [1.8750, 0.], [1.8750, 0.]]]]
5831*da0073e9SAndroid Build Coastguard Worker                                ).view(1, 2, 4, 2)
5832*da0073e9SAndroid Build Coastguard Worker                        elif padding_mode == 'border':
5833*da0073e9SAndroid Build Coastguard Worker                            if align_corners:
5834*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5835*da0073e9SAndroid Build Coastguard Worker                                    [[[[1.5, 0.], [1.5, 0.], [1.74, 0.], [1.74, 0.]],
5836*da0073e9SAndroid Build Coastguard Worker                                      [[1.5, 0.], [1.5, 0.], [1.74, 0.], [1.74, 0.]]]]).view(1, 2, 4, 2)
5837*da0073e9SAndroid Build Coastguard Worker                            else:
5838*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5839*da0073e9SAndroid Build Coastguard Worker                                    [[[[-0.46875, 0.], [-0.46875, 0.], [1.8750, 0.], [1.8750, 0.]],
5840*da0073e9SAndroid Build Coastguard Worker                                      [[-0.46875, 0.], [-0.46875, 0.], [1.8750, 0.], [1.8750, 0.]]]]).view(1, 2, 4, 2)
5841*da0073e9SAndroid Build Coastguard Worker                        elif padding_mode == 'reflection':
5842*da0073e9SAndroid Build Coastguard Worker                            if align_corners:
5843*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5844*da0073e9SAndroid Build Coastguard Worker                                    [[[[0., 0.], [0., 0.], [1.92, 0.], [1.92, 0.]],
5845*da0073e9SAndroid Build Coastguard Worker                                      [[0., 0.], [0., 0.], [1.92, 0.], [1.92, 0.]]]]).view(1, 2, 4, 2)
5846*da0073e9SAndroid Build Coastguard Worker                            else:
5847*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
5848*da0073e9SAndroid Build Coastguard Worker                                    [[[[0., 0.], [0., 0.], [1.875, 0.], [1.875, 0.]],
5849*da0073e9SAndroid Build Coastguard Worker                                      [[0., 0.], [0., 0.], [1.875, 0.], [1.875, 0.]]]]).view(1, 2, 4, 2)
5850*da0073e9SAndroid Build Coastguard Worker                        else:
5851*da0073e9SAndroid Build Coastguard Worker                            raise AssertionError(f"missing gradient groundtruth test for padding mode '{padding_mode}'")
5852*da0073e9SAndroid Build Coastguard Worker                    else:
5853*da0073e9SAndroid Build Coastguard Worker                        raise AssertionError(f"missing gradient groundtruth test for interpolation mode '{mode}'")
5854*da0073e9SAndroid Build Coastguard Worker                    for input_requires_grad in [False, True]:
5855*da0073e9SAndroid Build Coastguard Worker                        input = input.requires_grad_(input_requires_grad)
5856*da0073e9SAndroid Build Coastguard Worker                        F.grid_sample(input, grid, mode=mode, padding_mode=padding_mode,
5857*da0073e9SAndroid Build Coastguard Worker                                      align_corners=align_corners).sum().backward()
5858*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(grid.grad, groundtruth, atol=1e-5, rtol=0,
5859*da0073e9SAndroid Build Coastguard Worker                                         msg=f"gradient groundtruth comparison failed for mode={mode}, "
5860*da0073e9SAndroid Build Coastguard Worker                                         f"padding_mode={padding_mode}, input_requires_grad={input_requires_grad}")
5861*da0073e9SAndroid Build Coastguard Worker                        grid.grad.zero_()
5862*da0073e9SAndroid Build Coastguard Worker
5863*da0073e9SAndroid Build Coastguard Worker                    # See NOTE [ grid_sample CPU fallback ]
5864*da0073e9SAndroid Build Coastguard Worker                    torch._grid_sampler_2d_cpu_fallback(
5865*da0073e9SAndroid Build Coastguard Worker                        input.float(), grid.float(),
5866*da0073e9SAndroid Build Coastguard Worker                        F.GRID_SAMPLE_INTERPOLATION_MODES[mode],
5867*da0073e9SAndroid Build Coastguard Worker                        F.GRID_SAMPLE_PADDING_MODES[padding_mode],
5868*da0073e9SAndroid Build Coastguard Worker                        align_corners).sum().backward()
5869*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(grid.grad, groundtruth, atol=1e-5, rtol=0)
5870*da0073e9SAndroid Build Coastguard Worker
5871*da0073e9SAndroid Build Coastguard Worker                    # do gradcheck
5872*da0073e9SAndroid Build Coastguard Worker                    N = random.randint(2, 8)
5873*da0073e9SAndroid Build Coastguard Worker                    C = random.randint(2, 6)
5874*da0073e9SAndroid Build Coastguard Worker                    H = random.randint(2, 8)
5875*da0073e9SAndroid Build Coastguard Worker                    W = random.randint(2, 8)
5876*da0073e9SAndroid Build Coastguard Worker                    input = torch.randn(N, C, H, W, requires_grad=True)
5877*da0073e9SAndroid Build Coastguard Worker                    grid = torch.randn(N, H, W, 2, requires_grad=True)
5878*da0073e9SAndroid Build Coastguard Worker
5879*da0073e9SAndroid Build Coastguard Worker                    for input_requires_grad in [False, True]:
5880*da0073e9SAndroid Build Coastguard Worker                        input.requires_grad_(input_requires_grad)
5881*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue(gradcheck(
5882*da0073e9SAndroid Build Coastguard Worker                            lambda inp, grd: F.grid_sample(inp, grd, mode=mode, padding_mode=padding_mode,
5883*da0073e9SAndroid Build Coastguard Worker                                                           align_corners=align_corners),
5884*da0073e9SAndroid Build Coastguard Worker                            (input, grid)))
5885*da0073e9SAndroid Build Coastguard Worker                        test(N, C, H, W, mode, padding_mode, align_corners, input_requires_grad)
5886*da0073e9SAndroid Build Coastguard Worker                        if TEST_CUDNN:
5887*da0073e9SAndroid Build Coastguard Worker                            with cudnn.flags(enabled=False):
5888*da0073e9SAndroid Build Coastguard Worker                                test(N, C, H, W, mode, padding_mode, align_corners, input_requires_grad)
5889*da0073e9SAndroid Build Coastguard Worker
5890*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
5891*da0073e9SAndroid Build Coastguard Worker    def test_grid_sample_3d(self):
5892*da0073e9SAndroid Build Coastguard Worker        # Backward pass of native C++ and CUDA kernels branch depending on whether input requires gradient,
5893*da0073e9SAndroid Build Coastguard Worker        # so we test both cases.
5894*da0073e9SAndroid Build Coastguard Worker        def test(N, C, D, H, W, mode, padding_mode, align_corners, input_requires_grad):
5895*da0073e9SAndroid Build Coastguard Worker            def test_shape(N, C, ID, IH, IW, D, H, W, mode, padding_mode, align_corners):
5896*da0073e9SAndroid Build Coastguard Worker                input_cpu = torch.randn(C, N, ID, IH, IW).transpose(0, 1).requires_grad_(input_requires_grad)
5897*da0073e9SAndroid Build Coastguard Worker                grid_cpu = torch.randn(D, N, H, W, 3).transpose(0, 1).requires_grad_()
5898*da0073e9SAndroid Build Coastguard Worker                out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode,
5899*da0073e9SAndroid Build Coastguard Worker                                        align_corners=align_corners)
5900*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(out_cpu.size() == torch.Size([N, C, D, H, W]))
5901*da0073e9SAndroid Build Coastguard Worker
5902*da0073e9SAndroid Build Coastguard Worker                gradients = torch.randn_like(out_cpu)
5903*da0073e9SAndroid Build Coastguard Worker                out_cpu.backward(gradients)
5904*da0073e9SAndroid Build Coastguard Worker
5905*da0073e9SAndroid Build Coastguard Worker                if TEST_CUDA:
5906*da0073e9SAndroid Build Coastguard Worker                    input_cuda = input_cpu.detach().transpose(0, 1).cuda().transpose(0, 1).requires_grad_(input_requires_grad)
5907*da0073e9SAndroid Build Coastguard Worker                    grid_cuda = grid_cpu.detach().transpose(0, 1).cuda().transpose(0, 1).requires_grad_()
5908*da0073e9SAndroid Build Coastguard Worker                    out_cuda = F.grid_sample(input_cuda, grid_cuda, mode=mode, padding_mode=padding_mode,
5909*da0073e9SAndroid Build Coastguard Worker                                             align_corners=align_corners)
5910*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(out_cpu, out_cuda)
5911*da0073e9SAndroid Build Coastguard Worker
5912*da0073e9SAndroid Build Coastguard Worker                    out_cuda.backward(gradients.cuda())
5913*da0073e9SAndroid Build Coastguard Worker                    if input_requires_grad:
5914*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(input_cpu.grad, input_cuda.grad)
5915*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(grid_cpu.grad, grid_cuda.grad, atol=5e-5, rtol=0)
5916*da0073e9SAndroid Build Coastguard Worker
5917*da0073e9SAndroid Build Coastguard Worker                    # check that zero strides in the input (e.g. from expand) don't error out
5918*da0073e9SAndroid Build Coastguard Worker                    base_input = torch.randn(N, C, 1, IH, IW)
5919*da0073e9SAndroid Build Coastguard Worker                    input_cpu = base_input.expand_as(input_cuda).requires_grad_(input_requires_grad)
5920*da0073e9SAndroid Build Coastguard Worker                    grid_cpu = torch.randn(N, D, H, W, 3, requires_grad=True)
5921*da0073e9SAndroid Build Coastguard Worker                    out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode,
5922*da0073e9SAndroid Build Coastguard Worker                                            align_corners=align_corners)
5923*da0073e9SAndroid Build Coastguard Worker
5924*da0073e9SAndroid Build Coastguard Worker                    input_cuda = base_input.cuda().expand_as(input_cuda).requires_grad_(input_requires_grad)
5925*da0073e9SAndroid Build Coastguard Worker                    grid_cuda = grid_cpu.detach().cuda().requires_grad_()
5926*da0073e9SAndroid Build Coastguard Worker                    out_cuda = F.grid_sample(input_cuda, grid_cuda, mode=mode, padding_mode=padding_mode,
5927*da0073e9SAndroid Build Coastguard Worker                                             align_corners=align_corners)
5928*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(out_cpu, out_cuda)
5929*da0073e9SAndroid Build Coastguard Worker
5930*da0073e9SAndroid Build Coastguard Worker            # test same size output
5931*da0073e9SAndroid Build Coastguard Worker            test_shape(N, C, D, H, W, D, H, W, mode, padding_mode, align_corners)
5932*da0073e9SAndroid Build Coastguard Worker
5933*da0073e9SAndroid Build Coastguard Worker            # test larger output
5934*da0073e9SAndroid Build Coastguard Worker            N = random.randint(2, 7)
5935*da0073e9SAndroid Build Coastguard Worker            C = random.randint(2, 5)
5936*da0073e9SAndroid Build Coastguard Worker            ID = random.randint(2, 7)
5937*da0073e9SAndroid Build Coastguard Worker            IH = random.randint(2, 7)
5938*da0073e9SAndroid Build Coastguard Worker            IW = random.randint(2, 7)
5939*da0073e9SAndroid Build Coastguard Worker            D = random.randint(ID + 1, 10)
5940*da0073e9SAndroid Build Coastguard Worker            H = random.randint(IH + 1, 10)
5941*da0073e9SAndroid Build Coastguard Worker            W = random.randint(IW + 1, 10)
5942*da0073e9SAndroid Build Coastguard Worker            test_shape(N, C, ID, IH, IW, D, H, W, mode, padding_mode, align_corners)
5943*da0073e9SAndroid Build Coastguard Worker
5944*da0073e9SAndroid Build Coastguard Worker            # test smaller output
5945*da0073e9SAndroid Build Coastguard Worker            N = random.randint(2, 7)
5946*da0073e9SAndroid Build Coastguard Worker            C = random.randint(2, 5)
5947*da0073e9SAndroid Build Coastguard Worker            ID = random.randint(2, 7)
5948*da0073e9SAndroid Build Coastguard Worker            IH = random.randint(2, 7)
5949*da0073e9SAndroid Build Coastguard Worker            IW = random.randint(2, 7)
5950*da0073e9SAndroid Build Coastguard Worker            D = random.randint(2, ID)
5951*da0073e9SAndroid Build Coastguard Worker            H = random.randint(2, IH)
5952*da0073e9SAndroid Build Coastguard Worker            W = random.randint(2, IW)
5953*da0073e9SAndroid Build Coastguard Worker            test_shape(N, C, ID, IH, IW, D, H, W, mode, padding_mode, align_corners)
5954*da0073e9SAndroid Build Coastguard Worker
5955*da0073e9SAndroid Build Coastguard Worker            # test 1x1 input
5956*da0073e9SAndroid Build Coastguard Worker            N = random.randint(2, 7)
5957*da0073e9SAndroid Build Coastguard Worker            C = random.randint(2, 7)
5958*da0073e9SAndroid Build Coastguard Worker            ID = 1
5959*da0073e9SAndroid Build Coastguard Worker            IH = 1
5960*da0073e9SAndroid Build Coastguard Worker            IW = 1
5961*da0073e9SAndroid Build Coastguard Worker            H = random.randint(2, 5)
5962*da0073e9SAndroid Build Coastguard Worker            W = random.randint(2, 5)
5963*da0073e9SAndroid Build Coastguard Worker            test_shape(N, C, ID, IH, IW, D, H, W, mode, padding_mode, align_corners)
5964*da0073e9SAndroid Build Coastguard Worker
5965*da0073e9SAndroid Build Coastguard Worker            # testing empty grid
5966*da0073e9SAndroid Build Coastguard Worker            N = random.randint(2, 7)
5967*da0073e9SAndroid Build Coastguard Worker            C = random.randint(2, 5)
5968*da0073e9SAndroid Build Coastguard Worker            ID = random.randint(2, 7)
5969*da0073e9SAndroid Build Coastguard Worker            IH = random.randint(2, 7)
5970*da0073e9SAndroid Build Coastguard Worker            IW = random.randint(2, 7)
5971*da0073e9SAndroid Build Coastguard Worker            D = random.randint(3, ID + 2)
5972*da0073e9SAndroid Build Coastguard Worker            W = random.randint(3, IW + 2)
5973*da0073e9SAndroid Build Coastguard Worker            test_shape(N, C, ID, IH, IW, D, 0, W, mode, padding_mode, align_corners)
5974*da0073e9SAndroid Build Coastguard Worker
5975*da0073e9SAndroid Build Coastguard Worker            # testing empty channel
5976*da0073e9SAndroid Build Coastguard Worker            N = random.randint(2, 7)
5977*da0073e9SAndroid Build Coastguard Worker            ID = random.randint(2, 5)
5978*da0073e9SAndroid Build Coastguard Worker            IH = random.randint(2, 7)
5979*da0073e9SAndroid Build Coastguard Worker            IW = random.randint(2, 7)
5980*da0073e9SAndroid Build Coastguard Worker            D = random.randint(3, ID + 2)
5981*da0073e9SAndroid Build Coastguard Worker            H = random.randint(3, IH + 2)
5982*da0073e9SAndroid Build Coastguard Worker            W = random.randint(3, IW + 2)
5983*da0073e9SAndroid Build Coastguard Worker            test_shape(N, 0, ID, IH, IW, D, H, W, mode, padding_mode, align_corners)
5984*da0073e9SAndroid Build Coastguard Worker
5985*da0073e9SAndroid Build Coastguard Worker            # testing empty batch
5986*da0073e9SAndroid Build Coastguard Worker            C = random.randint(2, 5)
5987*da0073e9SAndroid Build Coastguard Worker            ID = random.randint(2, 7)
5988*da0073e9SAndroid Build Coastguard Worker            IH = random.randint(2, 7)
5989*da0073e9SAndroid Build Coastguard Worker            IW = random.randint(2, 7)
5990*da0073e9SAndroid Build Coastguard Worker            D = random.randint(3, ID + 2)
5991*da0073e9SAndroid Build Coastguard Worker            H = random.randint(3, IH + 2)
5992*da0073e9SAndroid Build Coastguard Worker            W = random.randint(3, IW + 2)
5993*da0073e9SAndroid Build Coastguard Worker            test_shape(0, C, ID, IH, IW, D, H, W, mode, padding_mode, align_corners)
5994*da0073e9SAndroid Build Coastguard Worker
5995*da0073e9SAndroid Build Coastguard Worker        for mode in ('bilinear', 'nearest'):
5996*da0073e9SAndroid Build Coastguard Worker            for padding_mode in ('zeros', 'border', 'reflection'):
5997*da0073e9SAndroid Build Coastguard Worker                for align_corners in (True, False):
5998*da0073e9SAndroid Build Coastguard Worker                    # do gradcheck
5999*da0073e9SAndroid Build Coastguard Worker                    N = random.randint(2, 5)
6000*da0073e9SAndroid Build Coastguard Worker                    C = random.randint(2, 4)
6001*da0073e9SAndroid Build Coastguard Worker                    D = random.randint(2, 5)
6002*da0073e9SAndroid Build Coastguard Worker                    H = random.randint(2, 5)
6003*da0073e9SAndroid Build Coastguard Worker                    W = random.randint(2, 5)
6004*da0073e9SAndroid Build Coastguard Worker                    input = torch.randn(N, C, D, H, W, requires_grad=True)
6005*da0073e9SAndroid Build Coastguard Worker                    grid = torch.randn(N, D, H, W, 3, requires_grad=True)
6006*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(gradcheck(
6007*da0073e9SAndroid Build Coastguard Worker                        lambda inp, grid: F.grid_sample(inp, grid, mode=mode, padding_mode=padding_mode,
6008*da0073e9SAndroid Build Coastguard Worker                                                        align_corners=align_corners),
6009*da0073e9SAndroid Build Coastguard Worker                        (input, grid)))
6010*da0073e9SAndroid Build Coastguard Worker                    input = input.requires_grad_(False)
6011*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(gradcheck(
6012*da0073e9SAndroid Build Coastguard Worker                        lambda grid: F.grid_sample(input, grid, mode=mode, padding_mode=padding_mode,
6013*da0073e9SAndroid Build Coastguard Worker                                                   align_corners=align_corners),
6014*da0073e9SAndroid Build Coastguard Worker                        (grid,)))
6015*da0073e9SAndroid Build Coastguard Worker
6016*da0073e9SAndroid Build Coastguard Worker                    for input_requires_grad in [False, True]:
6017*da0073e9SAndroid Build Coastguard Worker                        test(N, C, D, H, W, mode, padding_mode, align_corners, input_requires_grad)
6018*da0073e9SAndroid Build Coastguard Worker
6019*da0073e9SAndroid Build Coastguard Worker    def test_grid_sample_nearest_neighbor_rounding_mode_consistency(self):
6020*da0073e9SAndroid Build Coastguard Worker
6021*da0073e9SAndroid Build Coastguard Worker        device_list = ['cpu']
6022*da0073e9SAndroid Build Coastguard Worker        if TEST_CUDA:
6023*da0073e9SAndroid Build Coastguard Worker            device_list.append('cuda')
6024*da0073e9SAndroid Build Coastguard Worker
6025*da0073e9SAndroid Build Coastguard Worker        def normalize_indices(indices_unnormalized: torch.Tensor, dim_size: int, align_corners: bool):
6026*da0073e9SAndroid Build Coastguard Worker            if align_corners:
6027*da0073e9SAndroid Build Coastguard Worker                indices_normalized = 2 * indices_unnormalized / (dim_size - 1) - 1
6028*da0073e9SAndroid Build Coastguard Worker            else:
6029*da0073e9SAndroid Build Coastguard Worker                indices_normalized = (indices_unnormalized * 2 + 1) / dim_size - 1
6030*da0073e9SAndroid Build Coastguard Worker            return indices_normalized
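
        # Editor's sketch: a cheap sanity check of the mapping above (values worked
        # out by hand for dim_size = 10; align_corners=True maps the first/last
        # pixel centers onto -1/+1, align_corners=False keeps them half a pixel
        # inside the [-1, 1] range).
        assert torch.allclose(normalize_indices(torch.tensor([0., 9.]), 10, True),
                              torch.tensor([-1., 1.]))
        assert torch.allclose(normalize_indices(torch.tensor([0., 9.]), 10, False),
                              torch.tensor([-0.9, 0.9]))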
6031*da0073e9SAndroid Build Coastguard Worker
6032*da0073e9SAndroid Build Coastguard Worker        test_dim_size = 10
6033*da0073e9SAndroid Build Coastguard Worker        non_test_dim_size = 9
6034*da0073e9SAndroid Build Coastguard Worker        step_size = 0.1
6035*da0073e9SAndroid Build Coastguard Worker
6036*da0073e9SAndroid Build Coastguard Worker        batch_size = 1
6037*da0073e9SAndroid Build Coastguard Worker        channel_size = 1
6038*da0073e9SAndroid Build Coastguard Worker
6039*da0073e9SAndroid Build Coastguard Worker        mode = 'nearest'
6040*da0073e9SAndroid Build Coastguard Worker        for device in device_list:
6041*da0073e9SAndroid Build Coastguard Worker            for padding_mode in ('zeros', 'border', 'reflection'):
6042*da0073e9SAndroid Build Coastguard Worker                for align_corners in (True, False):
6043*da0073e9SAndroid Build Coastguard Worker                    # Unnormalized inquiry indices
6044*da0073e9SAndroid Build Coastguard Worker                    inquiry_indices_unnormalized = torch.arange(
6045*da0073e9SAndroid Build Coastguard Worker                        0,
6046*da0073e9SAndroid Build Coastguard Worker                        test_dim_size - 1 + step_size, step_size,
6047*da0073e9SAndroid Build Coastguard Worker                        dtype=torch.float32,
6048*da0073e9SAndroid Build Coastguard Worker                        device=device
6049*da0073e9SAndroid Build Coastguard Worker                    )
6050*da0073e9SAndroid Build Coastguard Worker                    # Note that even though we are trying to create normalized indices
6051*da0073e9SAndroid Build Coastguard Worker                    # that map back to x.0 and x.5 positions after unnormalization,
6052*da0073e9SAndroid Build Coastguard Worker                    # numerical error means the rounding direction is not guaranteed
6053*da0073e9SAndroid Build Coastguard Worker                    # to match the intended design.
6054*da0073e9SAndroid Build Coastguard Worker                    # The best we can do is to ensure that the rounding behavior is
6055*da0073e9SAndroid Build Coastguard Worker                    # exactly the same across the different implementations and
6056*da0073e9SAndroid Build Coastguard Worker                    # across the different dimensions.
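                    # (editor's illustration: an index that is nominally 2.5 may come
                    # back as 2.4999998 or 2.5000002 after the kernel's own
                    # unnormalization and so may round to either 2 or 3; the checks
                    # below only require that the x-, y- and z-dimension paths agree)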
6057*da0073e9SAndroid Build Coastguard Worker                    inquiry_indices = normalize_indices(
6058*da0073e9SAndroid Build Coastguard Worker                        indices_unnormalized=inquiry_indices_unnormalized,
6059*da0073e9SAndroid Build Coastguard Worker                        dim_size=test_dim_size,
6060*da0073e9SAndroid Build Coastguard Worker                        align_corners=align_corners
6061*da0073e9SAndroid Build Coastguard Worker                    )
6062*da0073e9SAndroid Build Coastguard Worker                    num_inqueries = inquiry_indices.shape[0]
6063*da0073e9SAndroid Build Coastguard Worker                    inquiry_fixed_indices = torch.full((num_inqueries,), 0.5, dtype=torch.float32, device=device)
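                    # (0.5 is just an arbitrary in-range coordinate for the dimensions
                    # not under test; the input is constant along those dimensions, so
                    # its rounding cannot change the sampled values)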
6064*da0073e9SAndroid Build Coastguard Worker                    array_data = torch.rand(test_dim_size, dtype=torch.float32, device=device)
6065*da0073e9SAndroid Build Coastguard Worker                    # 2D grid sample x-dim interpolation
6066*da0073e9SAndroid Build Coastguard Worker                    # The input_tensor_2d_x is of shape
6067*da0073e9SAndroid Build Coastguard Worker                    # [batch_size, channel_size, non_test_dim_size, test_dim_size]
6068*da0073e9SAndroid Build Coastguard Worker                    input_tensor_2d_x = array_data.reshape(1, test_dim_size).repeat(
6069*da0073e9SAndroid Build Coastguard Worker                        batch_size,
6070*da0073e9SAndroid Build Coastguard Worker                        channel_size,
6071*da0073e9SAndroid Build Coastguard Worker                        non_test_dim_size,
6072*da0073e9SAndroid Build Coastguard Worker                        1
6073*da0073e9SAndroid Build Coastguard Worker                    )
6074*da0073e9SAndroid Build Coastguard Worker                    # The grid_tensor_2d_x is of shape
6075*da0073e9SAndroid Build Coastguard Worker                    # [batch_size, 1, num_inqueries, 2]
6076*da0073e9SAndroid Build Coastguard Worker                    grid_tensor_2d_x = torch.cat(
6077*da0073e9SAndroid Build Coastguard Worker                        tensors=(
6078*da0073e9SAndroid Build Coastguard Worker                            inquiry_indices.reshape(num_inqueries, 1),
6079*da0073e9SAndroid Build Coastguard Worker                            inquiry_fixed_indices.reshape(num_inqueries, 1),
6080*da0073e9SAndroid Build Coastguard Worker                        ),
6081*da0073e9SAndroid Build Coastguard Worker                        dim=1
6082*da0073e9SAndroid Build Coastguard Worker                    ).repeat(batch_size, 1, 1, 1)
6083*da0073e9SAndroid Build Coastguard Worker                    # The output_tensor_2d_x is of shape
6084*da0073e9SAndroid Build Coastguard Worker                    # [batch_size, channel_size, 1, num_inqueries]
6085*da0073e9SAndroid Build Coastguard Worker                    output_tensor_2d_x = F.grid_sample(
6086*da0073e9SAndroid Build Coastguard Worker                        input=input_tensor_2d_x,
6087*da0073e9SAndroid Build Coastguard Worker                        grid=grid_tensor_2d_x,
6088*da0073e9SAndroid Build Coastguard Worker                        mode=mode,
6089*da0073e9SAndroid Build Coastguard Worker                        padding_mode=padding_mode,
6090*da0073e9SAndroid Build Coastguard Worker                        align_corners=align_corners,
6091*da0073e9SAndroid Build Coastguard Worker                    )
6092*da0073e9SAndroid Build Coastguard Worker                    # 2D grid sample y-dim interpolation
6093*da0073e9SAndroid Build Coastguard Worker                    # The input_tensor_2d_y is of shape
6094*da0073e9SAndroid Build Coastguard Worker                    # [batch_size, channel_size, test_dim_size, non_test_dim_size]
6095*da0073e9SAndroid Build Coastguard Worker                    input_tensor_2d_y = torch.transpose(input_tensor_2d_x, 3, 2)
6096*da0073e9SAndroid Build Coastguard Worker                    # The grid_tensor_2d_y is of shape
6097*da0073e9SAndroid Build Coastguard Worker                    # [batch_size, 1, num_inqueries, 2]
6098*da0073e9SAndroid Build Coastguard Worker                    grid_tensor_2d_y = torch.index_select(
6099*da0073e9SAndroid Build Coastguard Worker                        grid_tensor_2d_x,
6100*da0073e9SAndroid Build Coastguard Worker                        -1,
6101*da0073e9SAndroid Build Coastguard Worker                        torch.tensor([1, 0], dtype=torch.int64, device=device)
6102*da0073e9SAndroid Build Coastguard Worker                    )
6103*da0073e9SAndroid Build Coastguard Worker                    # The output_tensor_2d_y is of shape
6104*da0073e9SAndroid Build Coastguard Worker                    # [batch_size, channel_size, 1, num_inqueries]
6105*da0073e9SAndroid Build Coastguard Worker                    output_tensor_2d_y = F.grid_sample(
6106*da0073e9SAndroid Build Coastguard Worker                        input=input_tensor_2d_y,
6107*da0073e9SAndroid Build Coastguard Worker                        grid=grid_tensor_2d_y,
6108*da0073e9SAndroid Build Coastguard Worker                        mode=mode,
6109*da0073e9SAndroid Build Coastguard Worker                        padding_mode=padding_mode,
6110*da0073e9SAndroid Build Coastguard Worker                        align_corners=align_corners,
6111*da0073e9SAndroid Build Coastguard Worker                    )
6112*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(output_tensor_2d_x[0, 0, 0, :], output_tensor_2d_y[0, 0, 0, :], atol=0, rtol=0)
6113*da0073e9SAndroid Build Coastguard Worker                    # 3D grid sample x-dim interpolation
6114*da0073e9SAndroid Build Coastguard Worker                    # The input_tensor_3d_x is of shape
6115*da0073e9SAndroid Build Coastguard Worker                    # [batch_size, channel_size, non_test_dim_size, non_test_dim_size, test_dim_size]
6116*da0073e9SAndroid Build Coastguard Worker                    input_tensor_3d_x = array_data.reshape(1, test_dim_size).repeat(
6117*da0073e9SAndroid Build Coastguard Worker                        batch_size, channel_size, non_test_dim_size, non_test_dim_size, 1)
6118*da0073e9SAndroid Build Coastguard Worker                    # The grid_tensor_3d_x is of shape
6119*da0073e9SAndroid Build Coastguard Worker                    # [batch_size, 1, 1, num_inqueries, 3]
6120*da0073e9SAndroid Build Coastguard Worker                    grid_tensor_3d_x = torch.cat(
6121*da0073e9SAndroid Build Coastguard Worker                        tensors=(
6122*da0073e9SAndroid Build Coastguard Worker                            inquiry_indices.reshape(num_inqueries, 1),
6123*da0073e9SAndroid Build Coastguard Worker                            inquiry_fixed_indices.reshape(num_inqueries, 1),
6124*da0073e9SAndroid Build Coastguard Worker                            inquiry_fixed_indices.reshape(num_inqueries, 1),
6125*da0073e9SAndroid Build Coastguard Worker                        ),
6126*da0073e9SAndroid Build Coastguard Worker                        dim=1
6127*da0073e9SAndroid Build Coastguard Worker                    ).repeat(batch_size, 1, 1, 1, 1)
6128*da0073e9SAndroid Build Coastguard Worker                    # The output_tensor_3d_x is of shape
6129*da0073e9SAndroid Build Coastguard Worker                    # [batch_size, channel_size, 1, 1, num_inqueries]
6130*da0073e9SAndroid Build Coastguard Worker                    output_tensor_3d_x = F.grid_sample(
6131*da0073e9SAndroid Build Coastguard Worker                        input=input_tensor_3d_x,
6132*da0073e9SAndroid Build Coastguard Worker                        grid=grid_tensor_3d_x,
6133*da0073e9SAndroid Build Coastguard Worker                        mode=mode,
6134*da0073e9SAndroid Build Coastguard Worker                        padding_mode=padding_mode,
6135*da0073e9SAndroid Build Coastguard Worker                        align_corners=align_corners,
6136*da0073e9SAndroid Build Coastguard Worker                    )
6137*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(output_tensor_2d_x[0, 0, 0, :], output_tensor_3d_x[0, 0, 0, 0, :], atol=0, rtol=0)
6138*da0073e9SAndroid Build Coastguard Worker                    # 3D grid sample y-dim interpolation
6139*da0073e9SAndroid Build Coastguard Worker                    # The input_tensor_3d_y is of shape
6140*da0073e9SAndroid Build Coastguard Worker                    # [batch_size, channel_size, non_test_dim_size, test_dim_size, non_test_dim_size]
6141*da0073e9SAndroid Build Coastguard Worker                    input_tensor_3d_y = torch.transpose(input_tensor_3d_x, 4, 3)
6142*da0073e9SAndroid Build Coastguard Worker                    # The grid_tensor_3d_y is of shape
6143*da0073e9SAndroid Build Coastguard Worker                    # [batch_size, 1, 1, num_inqueries, 3]
6144*da0073e9SAndroid Build Coastguard Worker                    grid_tensor_3d_y = torch.index_select(
6145*da0073e9SAndroid Build Coastguard Worker                        grid_tensor_3d_x,
6146*da0073e9SAndroid Build Coastguard Worker                        -1,
6147*da0073e9SAndroid Build Coastguard Worker                        torch.tensor([1, 0, 2], dtype=torch.int64, device=device)
6148*da0073e9SAndroid Build Coastguard Worker                    )
6149*da0073e9SAndroid Build Coastguard Worker                    # The output_tensor_3d_y is of shape
6150*da0073e9SAndroid Build Coastguard Worker                    # [batch_size, channel_size, 1, 1, num_inqueries]
6151*da0073e9SAndroid Build Coastguard Worker                    output_tensor_3d_y = F.grid_sample(
6152*da0073e9SAndroid Build Coastguard Worker                        input=input_tensor_3d_y,
6153*da0073e9SAndroid Build Coastguard Worker                        grid=grid_tensor_3d_y,
6154*da0073e9SAndroid Build Coastguard Worker                        mode=mode,
6155*da0073e9SAndroid Build Coastguard Worker                        padding_mode=padding_mode,
6156*da0073e9SAndroid Build Coastguard Worker                        align_corners=align_corners,
6157*da0073e9SAndroid Build Coastguard Worker                    )
6158*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(output_tensor_2d_x[0, 0, 0, :], output_tensor_3d_y[0, 0, 0, 0, :], atol=0, rtol=0)
6159*da0073e9SAndroid Build Coastguard Worker                    # 3D grid sample z-dim interpolation
6160*da0073e9SAndroid Build Coastguard Worker                    # The input_tensor_3d_z is of shape
6161*da0073e9SAndroid Build Coastguard Worker                    # [batch_size, channel_size, test_dim_size, non_test_dim_size, non_test_dim_size]
6162*da0073e9SAndroid Build Coastguard Worker                    input_tensor_3d_z = torch.transpose(input_tensor_3d_x, 4, 2)
6163*da0073e9SAndroid Build Coastguard Worker                    # The grid_tensor_3d_z is of shape
6164*da0073e9SAndroid Build Coastguard Worker                    # [batch_size, 1, 1, num_inqueries, 3]
6165*da0073e9SAndroid Build Coastguard Worker                    grid_tensor_3d_z = torch.index_select(
6166*da0073e9SAndroid Build Coastguard Worker                        grid_tensor_3d_x,
6167*da0073e9SAndroid Build Coastguard Worker                        -1,
6168*da0073e9SAndroid Build Coastguard Worker                        torch.tensor([1, 2, 0], dtype=torch.int64, device=device)
6169*da0073e9SAndroid Build Coastguard Worker                    )
6170*da0073e9SAndroid Build Coastguard Worker                    # The output_tensor_3d_z is of shape
6171*da0073e9SAndroid Build Coastguard Worker                    # [batch_size, channel_size, 1, 1, num_inqueries]
6172*da0073e9SAndroid Build Coastguard Worker                    output_tensor_3d_z = F.grid_sample(
6173*da0073e9SAndroid Build Coastguard Worker                        input=input_tensor_3d_z,
6174*da0073e9SAndroid Build Coastguard Worker                        grid=grid_tensor_3d_z,
6175*da0073e9SAndroid Build Coastguard Worker                        mode=mode,
6176*da0073e9SAndroid Build Coastguard Worker                        padding_mode=padding_mode,
6177*da0073e9SAndroid Build Coastguard Worker                        align_corners=align_corners,
6178*da0073e9SAndroid Build Coastguard Worker                    )
6179*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(output_tensor_2d_x[0, 0, 0, :], output_tensor_3d_z[0, 0, 0, 0, :], atol=0, rtol=0)
6180*da0073e9SAndroid Build Coastguard Worker
6181*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
6182*da0073e9SAndroid Build Coastguard Worker    def test_affine_grid(self):
6183*da0073e9SAndroid Build Coastguard Worker        # test known input on CPU
6184*da0073e9SAndroid Build Coastguard Worker        input = torch.arange(1., 7).view(1, 2, 3)
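        # (editor's note) each output grid point equals theta @ [x, y, 1]^T with
        # theta = [[1, 2, 3], [4, 5, 6]].  With align_corners=True the 2x2 base grid
        # uses the corner coordinates x, y in {-1, 1}, so the first point is
        # (1*-1 + 2*-1 + 3, 4*-1 + 5*-1 + 6) = (0, -3); with align_corners=False the
        # base grid uses the half-pixel centers +/-0.5, giving (1.5, 1.5).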
6185*da0073e9SAndroid Build Coastguard Worker        output = F.affine_grid(input, torch.Size([1, 1, 2, 2]), align_corners=True)
6186*da0073e9SAndroid Build Coastguard Worker        groundtruth = torch.tensor(
6187*da0073e9SAndroid Build Coastguard Worker            [[[0., -3.], [2., 5.]], [[4., 7.], [6., 15.]]]).view(1, 2, 2, 2)
6188*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output, groundtruth)
6189*da0073e9SAndroid Build Coastguard Worker        output = F.affine_grid(input, torch.Size([1, 1, 2, 2]), align_corners=False)
6190*da0073e9SAndroid Build Coastguard Worker        groundtruth = torch.tensor(
6191*da0073e9SAndroid Build Coastguard Worker            [[[1.5, 1.5], [2.5, 5.5]], [[3.5, 6.5], [4.5, 10.5]]]).view(1, 2, 2, 2)
6192*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output, groundtruth)
6193*da0073e9SAndroid Build Coastguard Worker
6194*da0073e9SAndroid Build Coastguard Worker        for align_corners in (True, False):
6195*da0073e9SAndroid Build Coastguard Worker            # do gradcheck
6196*da0073e9SAndroid Build Coastguard Worker            N = random.randint(1, 8)
6197*da0073e9SAndroid Build Coastguard Worker            C = random.randint(1, 8)
6198*da0073e9SAndroid Build Coastguard Worker            H = random.randint(1, 8)
6199*da0073e9SAndroid Build Coastguard Worker            W = random.randint(1, 8)
6200*da0073e9SAndroid Build Coastguard Worker            sz = torch.Size([N, C, H, W])
6201*da0073e9SAndroid Build Coastguard Worker            inp = torch.randn(N, 2, 3, requires_grad=True)
6202*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True):
6203*da0073e9SAndroid Build Coastguard Worker                warnings.simplefilter("always")  # python2 requires this so other tests can trigger
6204*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(gradcheck(
6205*da0073e9SAndroid Build Coastguard Worker                    lambda inp: F.affine_grid(inp, sz, align_corners=align_corners),
6206*da0073e9SAndroid Build Coastguard Worker                    (inp,)))
6207*da0073e9SAndroid Build Coastguard Worker
6208*da0073e9SAndroid Build Coastguard Worker        # test CPU against CUDA
6209*da0073e9SAndroid Build Coastguard Worker        if TEST_CUDA:
6210*da0073e9SAndroid Build Coastguard Worker            N = random.randint(1, 8)
6211*da0073e9SAndroid Build Coastguard Worker            C = random.randint(1, 8)
6212*da0073e9SAndroid Build Coastguard Worker            H = random.randint(1, 8)
6213*da0073e9SAndroid Build Coastguard Worker            W = random.randint(1, 8)
6214*da0073e9SAndroid Build Coastguard Worker            sz = torch.Size([N, C, H, W])
6215*da0073e9SAndroid Build Coastguard Worker            for align_corners in (True, False):
6216*da0073e9SAndroid Build Coastguard Worker                input_cpu = torch.randn(N, 2, 3, requires_grad=True)
6217*da0073e9SAndroid Build Coastguard Worker                with warnings.catch_warnings(record=True):
6218*da0073e9SAndroid Build Coastguard Worker                    warnings.simplefilter("always")  # python2 requires this so other tests can trigger
6219*da0073e9SAndroid Build Coastguard Worker                    out_cpu = F.affine_grid(input_cpu, sz, align_corners=align_corners)
6220*da0073e9SAndroid Build Coastguard Worker                gradients = torch.randn(out_cpu.size())
6221*da0073e9SAndroid Build Coastguard Worker                out_cpu.backward(gradients)
6222*da0073e9SAndroid Build Coastguard Worker                input_gpu = input_cpu.detach().cuda().requires_grad_()
6223*da0073e9SAndroid Build Coastguard Worker                with warnings.catch_warnings(record=True):
6224*da0073e9SAndroid Build Coastguard Worker                    warnings.simplefilter("always")  # python2 requires this so other tests can trigger
6225*da0073e9SAndroid Build Coastguard Worker                    out_cuda = F.affine_grid(input_gpu, sz, align_corners=align_corners)
6226*da0073e9SAndroid Build Coastguard Worker                out_cuda.backward(gradients.cuda())
6227*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out_cpu, out_cuda)
6228*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(input_cpu.grad, input_gpu.grad)
6229*da0073e9SAndroid Build Coastguard Worker
6230*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
6231*da0073e9SAndroid Build Coastguard Worker    def test_affine_grid_3d(self):
6232*da0073e9SAndroid Build Coastguard Worker        # test known input on CPU
6233*da0073e9SAndroid Build Coastguard Worker        input = torch.arange(1., 13).view(1, 3, 4)
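        # (editor's note) as in the 2D test, each grid point equals
        # theta @ [x, y, z, 1]^T with theta = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]];
        # for align_corners=True the first corner (-1, -1, -1) gives
        # (1*-1 + 2*-1 + 3*-1 + 4, 5*-1 + 6*-1 + 7*-1 + 8, 9*-1 + 10*-1 + 11*-1 + 12)
        # = (-2, -10, -18).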
6234*da0073e9SAndroid Build Coastguard Worker        output = F.affine_grid(input, torch.Size([1, 1, 2, 2, 2]), align_corners=True)
6235*da0073e9SAndroid Build Coastguard Worker        groundtruth = torch.tensor(
6236*da0073e9SAndroid Build Coastguard Worker            [[[[[-2., -10., -18.], [0., 0., 0.]], [[2., 2., 2.], [4., 12., 20.]]],
6237*da0073e9SAndroid Build Coastguard Worker              [[[4., 4., 4.], [6., 14., 22.]], [[8., 16., 24.], [10., 26., 42.]]]]]).view(1, 2, 2, 2, 3)
6238*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output, groundtruth)
6239*da0073e9SAndroid Build Coastguard Worker        output = F.affine_grid(input, torch.Size([1, 1, 2, 2, 2]), align_corners=False)
6240*da0073e9SAndroid Build Coastguard Worker        groundtruth = torch.tensor(
6241*da0073e9SAndroid Build Coastguard Worker            [[[[[1., -1., -3.], [2., 4., 6.]], [[3., 5., 7.], [4., 10., 16.]]],
6242*da0073e9SAndroid Build Coastguard Worker              [[[4., 6., 8.], [5., 11., 17.]], [[6., 12., 18.], [7., 17., 27.]]]]]).view(1, 2, 2, 2, 3)
6243*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output, groundtruth)
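
        # For a 5-D output size (N, C, D, H, W), affine_grid expects theta of shape
        # (N, 3, 4) and returns an (N, D, H, W, 3) grid of (x, y, z) sampling
        # coordinates in [-1, 1]. A minimal sketch with an identity transform, which
        # should reproduce the base grid (corners at +/-1 when align_corners=True):
        #
        #   theta = torch.tensor([[[1., 0., 0., 0.],
        #                          [0., 1., 0., 0.],
        #                          [0., 0., 1., 0.]]])
        #   grid = F.affine_grid(theta, torch.Size([1, 1, 2, 2, 2]), align_corners=True)
        #   # grid[0, 0, 0, 0] -> tensor([-1., -1., -1.]), grid[0, 1, 1, 1] -> tensor([1., 1., 1.])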
6244*da0073e9SAndroid Build Coastguard Worker
6245*da0073e9SAndroid Build Coastguard Worker        for align_corners in (True, False):
6246*da0073e9SAndroid Build Coastguard Worker            # do gradcheck
6247*da0073e9SAndroid Build Coastguard Worker            N = random.randint(1, 8)
6248*da0073e9SAndroid Build Coastguard Worker            C = random.randint(1, 8)
6249*da0073e9SAndroid Build Coastguard Worker            D = random.randint(1, 8)
6250*da0073e9SAndroid Build Coastguard Worker            H = random.randint(1, 8)
6251*da0073e9SAndroid Build Coastguard Worker            W = random.randint(1, 8)
6252*da0073e9SAndroid Build Coastguard Worker            sz = torch.Size([N, C, D, H, W])
6253*da0073e9SAndroid Build Coastguard Worker            inp = torch.randn(N, 3, 4, requires_grad=True)
6254*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True):
6255*da0073e9SAndroid Build Coastguard Worker                warnings.simplefilter("always")  # python2 requires this so other tests can trigger the warning
6256*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(gradcheck(
6257*da0073e9SAndroid Build Coastguard Worker                    lambda inp: F.affine_grid(inp, sz, align_corners=align_corners),
6258*da0073e9SAndroid Build Coastguard Worker                    (inp,)))
6259*da0073e9SAndroid Build Coastguard Worker
6260*da0073e9SAndroid Build Coastguard Worker        # test CPU against CUDA
6261*da0073e9SAndroid Build Coastguard Worker        if TEST_CUDA:
6262*da0073e9SAndroid Build Coastguard Worker            N = random.randint(1, 8)
6263*da0073e9SAndroid Build Coastguard Worker            C = random.randint(1, 8)
6264*da0073e9SAndroid Build Coastguard Worker            D = random.randint(1, 8)
6265*da0073e9SAndroid Build Coastguard Worker            H = random.randint(1, 8)
6266*da0073e9SAndroid Build Coastguard Worker            W = random.randint(1, 8)
6267*da0073e9SAndroid Build Coastguard Worker            sz = torch.Size([N, C, D, H, W])
6268*da0073e9SAndroid Build Coastguard Worker            for align_corners in (True, False):
6269*da0073e9SAndroid Build Coastguard Worker                input_cpu = torch.randn(N, 3, 4, requires_grad=True)
6270*da0073e9SAndroid Build Coastguard Worker                with warnings.catch_warnings(record=True):
6271*da0073e9SAndroid Build Coastguard Worker                    warnings.simplefilter("always")  # python2 requires this so other tests can trigger the warning
6272*da0073e9SAndroid Build Coastguard Worker                    out_cpu = F.affine_grid(input_cpu, sz, align_corners=align_corners)
6273*da0073e9SAndroid Build Coastguard Worker                gradients = torch.randn(out_cpu.size())
6274*da0073e9SAndroid Build Coastguard Worker                out_cpu.backward(gradients)
6275*da0073e9SAndroid Build Coastguard Worker                input_gpu = input_cpu.detach().cuda().requires_grad_()
6276*da0073e9SAndroid Build Coastguard Worker                with warnings.catch_warnings(record=True):
6277*da0073e9SAndroid Build Coastguard Worker                    warnings.simplefilter("always")  # python2 requires this so other tests can trigger the warning
6278*da0073e9SAndroid Build Coastguard Worker                    out_cuda = F.affine_grid(input_gpu, sz, align_corners=align_corners)
6279*da0073e9SAndroid Build Coastguard Worker                out_cuda.backward(gradients.cuda())
6280*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out_cpu, out_cuda)
6281*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(input_cpu.grad, input_gpu.grad)
6282*da0073e9SAndroid Build Coastguard Worker
6283*da0073e9SAndroid Build Coastguard Worker    def test_channel_shuffle_return_alias_of_self(self):
6284*da0073e9SAndroid Build Coastguard Worker        # gh-76616: nn.ChannelShuffle should return an alias of the input when given an empty input tensor
6285*da0073e9SAndroid Build Coastguard Worker        groups = 3
6286*da0073e9SAndroid Build Coastguard Worker        input_tensor = torch.rand([0, 9, 4, 4])
6287*da0073e9SAndroid Build Coastguard Worker        output = torch.nn.ChannelShuffle(groups)(input_tensor)
6288*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(output, input_tensor)
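
        # For a non-empty input, ChannelShuffle(groups) views the channel dimension
        # as (groups, C // groups), transposes, and flattens it back, so with
        # groups=3 and 9 channels the channel order becomes [0, 3, 6, 1, 4, 7, 2, 5, 8].
        # A minimal sketch of that behavior:
        #
        #   x = torch.arange(9.).view(1, 9, 1, 1)
        #   y = torch.nn.ChannelShuffle(3)(x)
        #   # y.flatten() -> tensor([0., 3., 6., 1., 4., 7., 2., 5., 8.])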
6289*da0073e9SAndroid Build Coastguard Worker
6290*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")
6291*da0073e9SAndroid Build Coastguard Worker    def test_native_channel_shuffle_return_alias_of_self(self):
6292*da0073e9SAndroid Build Coastguard Worker        groups = 3
6293*da0073e9SAndroid Build Coastguard Worker        input_tensor = torch.rand([0, 9, 4, 4])
6294*da0073e9SAndroid Build Coastguard Worker        output = torch.native_channel_shuffle(input_tensor, groups)
6295*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(output, input_tensor)
6296*da0073e9SAndroid Build Coastguard Worker
6297*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
6298*da0073e9SAndroid Build Coastguard Worker    def test_upsamplingLinear1d(self):
6299*da0073e9SAndroid Build Coastguard Worker        for align_corners in [True, False]:
6300*da0073e9SAndroid Build Coastguard Worker            for recompute_scale_factor in [True, False]:
6301*da0073e9SAndroid Build Coastguard Worker                kwargs = dict(
6302*da0073e9SAndroid Build Coastguard Worker                    mode='linear', align_corners=align_corners, recompute_scale_factor=recompute_scale_factor
6303*da0073e9SAndroid Build Coastguard Worker                )
6304*da0073e9SAndroid Build Coastguard Worker                # test float scale factor up & downsampling
6305*da0073e9SAndroid Build Coastguard Worker                for scale_factor in [0.5, 1.5, 2]:
6306*da0073e9SAndroid Build Coastguard Worker                    m = nn.Upsample(scale_factor=scale_factor, **kwargs)
6307*da0073e9SAndroid Build Coastguard Worker                    in_t = torch.ones(1, 1, 2)
6308*da0073e9SAndroid Build Coastguard Worker                    out_size = int(math.floor(in_t.shape[-1] * scale_factor))
6309*da0073e9SAndroid Build Coastguard Worker                    with warnings.catch_warnings(record=True) as w:
6310*da0073e9SAndroid Build Coastguard Worker                        out_t = m(in_t)
6311*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(torch.ones(1, 1, out_size), out_t.data)
6312*da0073e9SAndroid Build Coastguard Worker
6313*da0073e9SAndroid Build Coastguard Worker                    input = torch.randn(1, 1, 2, requires_grad=True)
6314*da0073e9SAndroid Build Coastguard Worker                    if not recompute_scale_factor:
6315*da0073e9SAndroid Build Coastguard Worker                        gradcheck(lambda x: F.interpolate(x, out_size, **kwargs), (input,))
6316*da0073e9SAndroid Build Coastguard Worker                    else:
6317*da0073e9SAndroid Build Coastguard Worker                        gradcheck(lambda x: F.interpolate(x, scale_factor=scale_factor, **kwargs), (input,))
6318*da0073e9SAndroid Build Coastguard Worker
6319*da0073e9SAndroid Build Coastguard Worker    def test_upsamplingLinear1d_spatial_invariance(self):
6320*da0073e9SAndroid Build Coastguard Worker        m = nn.Upsample(scale_factor=3, mode='linear', align_corners=False)
6321*da0073e9SAndroid Build Coastguard Worker        in_t_9 = torch.zeros(1, 1, 9)
6322*da0073e9SAndroid Build Coastguard Worker        in_t_9[:, :, :4].normal_()
6323*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
6324*da0073e9SAndroid Build Coastguard Worker            out_t_9 = m(in_t_9)
6325*da0073e9SAndroid Build Coastguard Worker            out_t_5 = m(in_t_9[:, :, :5])
6326*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_t_9[:, :, :15], out_t_5)
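
        # With align_corners=False and scale factor s, output element i is sampled at
        # source coordinate (i + 0.5) / s - 0.5, so for s=3 the first 15 outputs only
        # read input elements up to index ~4; everything past the first four elements
        # of in_t_9 is zero, which is why cropping the input to 5 elements leaves
        # those outputs unchanged.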
6327*da0073e9SAndroid Build Coastguard Worker
6328*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
6329*da0073e9SAndroid Build Coastguard Worker    def test_upsampling_not_recompute_scale_factor(self):
6330*da0073e9SAndroid Build Coastguard Worker        # test output against known input: result must match OpenCV
6331*da0073e9SAndroid Build Coastguard Worker        in_t = torch.arange(8.).view(1, 2, 2, 2)
6332*da0073e9SAndroid Build Coastguard Worker        expected_out_t = torch.tensor(
6333*da0073e9SAndroid Build Coastguard Worker            [[[[-0.32725, -0.08843, 0.37933, 0.79744],
6334*da0073e9SAndroid Build Coastguard Worker              [0.15039, 0.38921, 0.85697, 1.27508],
6335*da0073e9SAndroid Build Coastguard Worker              [1.08591, 1.32473, 1.79249, 2.21060],
6336*da0073e9SAndroid Build Coastguard Worker              [1.92213, 2.16095, 2.62871, 3.04682]],
6337*da0073e9SAndroid Build Coastguard Worker
6338*da0073e9SAndroid Build Coastguard Worker             [[3.67275, 3.91157, 4.37933, 4.79744],
6339*da0073e9SAndroid Build Coastguard Worker              [4.15039, 4.38921, 4.85697, 5.27508],
6340*da0073e9SAndroid Build Coastguard Worker              [5.08591, 5.32473, 5.79249, 6.21060],
6341*da0073e9SAndroid Build Coastguard Worker              [5.92213, 6.16095, 6.62871, 7.04682]]]])
6342*da0073e9SAndroid Build Coastguard Worker        if IS_PPC:
6343*da0073e9SAndroid Build Coastguard Worker            # Both OpenCV and PyTorch give a slightly different result on PPC
6344*da0073e9SAndroid Build Coastguard Worker            expected_out_t = torch.tensor(
6345*da0073e9SAndroid Build Coastguard Worker                [[[[-0.32725, -0.08843, 0.37933, 0.79744],
6346*da0073e9SAndroid Build Coastguard Worker                  [0.15039, 0.38921, 0.85697, 1.27508],
6347*da0073e9SAndroid Build Coastguard Worker                  [1.08591, 1.32473, 1.79249, 2.21060],
6348*da0073e9SAndroid Build Coastguard Worker                  [1.92212, 2.16094, 2.62870, 3.04681]],
6349*da0073e9SAndroid Build Coastguard Worker
6350*da0073e9SAndroid Build Coastguard Worker                 [[3.67275, 3.91157, 4.37933, 4.79743],
6351*da0073e9SAndroid Build Coastguard Worker                  [4.15039, 4.38921, 4.85697, 5.27508],
6352*da0073e9SAndroid Build Coastguard Worker                  [5.08591, 5.32473, 5.79249, 6.21059],
6353*da0073e9SAndroid Build Coastguard Worker                  [5.92212, 6.16094, 6.62870, 7.04680]]]])
6354*da0073e9SAndroid Build Coastguard Worker        out_t = F.interpolate(in_t, scale_factor=2.3, mode='bicubic', align_corners=False, recompute_scale_factor=False)
6355*da0073e9SAndroid Build Coastguard Worker        torch.set_printoptions(precision=5)
6356*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_t, expected_out_t, atol=1e-4, rtol=0)
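
        # With recompute_scale_factor=False the provided scale_factor is used directly
        # in the source-coordinate computation, src = (dst + 0.5) / 2.3 - 0.5, rather
        # than being recomputed from the sizes as 4 / 2 = 2.0; the output size itself
        # is still floor(2 * 2.3) = 4 per dimension, which is why the reference above
        # is a 4x4 grid that differs from a plain 2x bicubic upsampling.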
6357*da0073e9SAndroid Build Coastguard Worker
6358*da0073e9SAndroid Build Coastguard Worker        device_list = ['cpu']
6359*da0073e9SAndroid Build Coastguard Worker        if TEST_CUDA:
6360*da0073e9SAndroid Build Coastguard Worker            device_list.append('cuda')
6361*da0073e9SAndroid Build Coastguard Worker
6362*da0073e9SAndroid Build Coastguard Worker        for align_corners in [True, False]:
6363*da0073e9SAndroid Build Coastguard Worker            kwargs = dict(mode='bicubic', align_corners=align_corners)
6364*da0073e9SAndroid Build Coastguard Worker            # test float scale factor up & downsampling
6365*da0073e9SAndroid Build Coastguard Worker            for device in device_list:
6366*da0073e9SAndroid Build Coastguard Worker                for scale_factor in [0.6, 1.6, 2.3]:
6367*da0073e9SAndroid Build Coastguard Worker                    in_t = torch.ones(2, 2, 2, 2).to(device)
6368*da0073e9SAndroid Build Coastguard Worker                    out_t = F.interpolate(in_t, scale_factor=scale_factor, **kwargs)
6369*da0073e9SAndroid Build Coastguard Worker                    out_size = int(math.floor(in_t.shape[-1] * scale_factor))
6370*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(torch.ones(2, 2, out_size, out_size), out_t.data, atol=1e-5, rtol=0)
6371*da0073e9SAndroid Build Coastguard Worker
6372*da0073e9SAndroid Build Coastguard Worker                    input = torch.randn(2, 2, 2, 2, requires_grad=True)
6373*da0073e9SAndroid Build Coastguard Worker                    gradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [input])
6374*da0073e9SAndroid Build Coastguard Worker
6375*da0073e9SAndroid Build Coastguard Worker    def test_upsamplingBilinear2d_spatial_invariance(self):
6376*da0073e9SAndroid Build Coastguard Worker        m = nn.Upsample(scale_factor=3, mode='bilinear', align_corners=False)
6377*da0073e9SAndroid Build Coastguard Worker        in_t_9 = torch.zeros(1, 1, 9, 9)
6378*da0073e9SAndroid Build Coastguard Worker        in_t_9[:, :, :4, :4].normal_()
6379*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
6380*da0073e9SAndroid Build Coastguard Worker            out_t_9 = m(in_t_9)
6381*da0073e9SAndroid Build Coastguard Worker            out_t_5 = m(in_t_9[:, :, :5, :5])
6382*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_t_9[:, :, :15, :15], out_t_5)
6383*da0073e9SAndroid Build Coastguard Worker
6384*da0073e9SAndroid Build Coastguard Worker    def test_upsamplingTrilinear3d_spatial_invariance(self):
6385*da0073e9SAndroid Build Coastguard Worker        m = nn.Upsample(scale_factor=3, mode='trilinear', align_corners=False)
6386*da0073e9SAndroid Build Coastguard Worker        in_t_9 = torch.zeros(1, 1, 9, 9, 9)
6387*da0073e9SAndroid Build Coastguard Worker        in_t_9[:, :, :4, :4, :4].normal_()
6388*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
6389*da0073e9SAndroid Build Coastguard Worker            out_t_9 = m(in_t_9)
6390*da0073e9SAndroid Build Coastguard Worker            out_t_5 = m(in_t_9[:, :, :5, :5, :5])
6391*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_t_9[:, :, :15, :15, :15], out_t_5)
6392*da0073e9SAndroid Build Coastguard Worker
6393*da0073e9SAndroid Build Coastguard Worker    def test_upsampling_small_scale(self):
6394*da0073e9SAndroid Build Coastguard Worker        m = torch.nn.Upsample(scale_factor=0.5, mode="bilinear")
6395*da0073e9SAndroid Build Coastguard Worker        in_t = torch.arange(1, 5, dtype=torch.get_default_dtype()).reshape(1, 1, 2, 2)
6396*da0073e9SAndroid Build Coastguard Worker        out_t = m(in_t)
6397*da0073e9SAndroid Build Coastguard Worker        expected_out_t = torch.tensor([[[[2.5]]]])
6398*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_out_t, out_t)
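
        # Downsampling the 2x2 input [[1, 2], [3, 4]] by a factor of 0.5 with bilinear
        # interpolation produces a single element sampled at the center of the input,
        # i.e. the average (1 + 2 + 3 + 4) / 4 = 2.5.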
6399*da0073e9SAndroid Build Coastguard Worker
6400*da0073e9SAndroid Build Coastguard Worker    def test_upsampling_bfloat16(self, dtype=torch.bfloat16):
6401*da0073e9SAndroid Build Coastguard Worker        def helper(size, scale_factor, mode, device, memory_format=torch.contiguous_format):
6402*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(size, device=device, dtype=dtype).to(memory_format=memory_format).detach().requires_grad_(True)
6403*da0073e9SAndroid Build Coastguard Worker            inputf = input.to(torch.float32).to(memory_format=torch.contiguous_format).detach().requires_grad_(True)
6404*da0073e9SAndroid Build Coastguard Worker            m = nn.Upsample(scale_factor=scale_factor, mode=mode)
6405*da0073e9SAndroid Build Coastguard Worker
6406*da0073e9SAndroid Build Coastguard Worker            outf = m(inputf)
6407*da0073e9SAndroid Build Coastguard Worker            out = m(input)
6408*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out.to(torch.float32), outf, atol=0.05, rtol=0)
6409*da0073e9SAndroid Build Coastguard Worker
6410*da0073e9SAndroid Build Coastguard Worker            ginput = torch.randn(out.shape, device=device, dtype=dtype).to(memory_format=memory_format)
6411*da0073e9SAndroid Build Coastguard Worker            ginputf = ginput.to(torch.float32).to(memory_format=torch.contiguous_format)
6412*da0073e9SAndroid Build Coastguard Worker            out.backward(ginput)
6413*da0073e9SAndroid Build Coastguard Worker            outf.backward(ginputf)
6414*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input.grad.to(torch.float32), inputf.grad, atol=0.01, rtol=0.01)
6415*da0073e9SAndroid Build Coastguard Worker
6416*da0073e9SAndroid Build Coastguard Worker        for device in ['cpu']:
6417*da0073e9SAndroid Build Coastguard Worker            helper([3, 20, 11, 7], 2, 'nearest', device)
6418*da0073e9SAndroid Build Coastguard Worker            helper([3, 20, 11, 7], 2, 'nearest', device, torch.channels_last)
6419*da0073e9SAndroid Build Coastguard Worker            helper([3, 20, 11, 7, 3], 2, 'nearest', device)
6420*da0073e9SAndroid Build Coastguard Worker            helper([3, 20, 30], 2, 'linear', device)
6421*da0073e9SAndroid Build Coastguard Worker            helper([3, 20, 11, 7], 2, 'bilinear', device)
6422*da0073e9SAndroid Build Coastguard Worker            helper([3, 20, 11, 7], 2, 'bilinear', device, torch.channels_last)
6423*da0073e9SAndroid Build Coastguard Worker            helper([1, 3, 11, 7], 2, 'bicubic', device)
6424*da0073e9SAndroid Build Coastguard Worker            helper([1, 3, 11, 7], 2, 'bicubic', device, torch.channels_last)
6425*da0073e9SAndroid Build Coastguard Worker            helper([3, 20, 11, 7, 3], 2, 'trilinear', device)
6426*da0073e9SAndroid Build Coastguard Worker
6427*da0073e9SAndroid Build Coastguard Worker            helper([3, 5, 5], 257., 'nearest', device)
6428*da0073e9SAndroid Build Coastguard Worker            helper([3, 20, 11, 7], 20, 'nearest', device)
6429*da0073e9SAndroid Build Coastguard Worker            helper([3, 20, 11, 7, 3], 20, 'nearest', device)
6430*da0073e9SAndroid Build Coastguard Worker            helper([1, 2, 11, 7], 257, 'nearest', device, torch.channels_last)
6431*da0073e9SAndroid Build Coastguard Worker            helper([1, 2, 2000, 2000], 1 / 377., 'nearest', device)
6432*da0073e9SAndroid Build Coastguard Worker            helper([1, 2, 2000, 2000], 1 / 257., 'nearest', device, torch.channels_last)
6433*da0073e9SAndroid Build Coastguard Worker            helper([3, 2, 11, 7, 3], 20, 'nearest', device, torch.channels_last_3d)
6434*da0073e9SAndroid Build Coastguard Worker            helper([3, 5, 5], 10, 'linear', device)
6435*da0073e9SAndroid Build Coastguard Worker            helper([3, 5, 5], 257, 'linear', device)
6436*da0073e9SAndroid Build Coastguard Worker            helper([1, 2, 11, 7], 257, 'bilinear', device)
6437*da0073e9SAndroid Build Coastguard Worker            helper([1, 2, 11, 7], 257, 'bilinear', device, torch.channels_last)
6438*da0073e9SAndroid Build Coastguard Worker            helper([1, 3, 11, 7], 10, 'bicubic', device)
6439*da0073e9SAndroid Build Coastguard Worker            helper([1, 3, 11, 7], 10, 'bicubic', device, torch.channels_last)
6440*da0073e9SAndroid Build Coastguard Worker            helper([1, 1, 11, 7], 257, 'bicubic', device)
6441*da0073e9SAndroid Build Coastguard Worker            helper([3, 2, 11, 7, 3], 20, 'trilinear', device)
6442*da0073e9SAndroid Build Coastguard Worker            helper([3, 2, 11, 7, 3], 20, 'trilinear', device, torch.channels_last_3d)
6443*da0073e9SAndroid Build Coastguard Worker
6444*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
6445*da0073e9SAndroid Build Coastguard Worker    def test_interpolate_illegal_memory_access(self):
6446*da0073e9SAndroid Build Coastguard Worker        in_s = 45
6447*da0073e9SAndroid Build Coastguard Worker        out_s = 14
6448*da0073e9SAndroid Build Coastguard Worker
6449*da0073e9SAndroid Build Coastguard Worker        input = torch.ones((1, 1, in_s), device='cuda', requires_grad=True)
6450*da0073e9SAndroid Build Coastguard Worker        # note: we allocate grad_output larger than needed so that an out-of-bounds
6451*da0073e9SAndroid Build Coastguard Worker        # access would be visible in grad_input
6452*da0073e9SAndroid Build Coastguard Worker        grad = torch.ones((1, 1, out_s * 2), device='cuda', requires_grad=True)
6453*da0073e9SAndroid Build Coastguard Worker        grad = grad[:, :, :out_s]
6454*da0073e9SAndroid Build Coastguard Worker
6455*da0073e9SAndroid Build Coastguard Worker        input_ref = input.detach().cpu().requires_grad_()
6456*da0073e9SAndroid Build Coastguard Worker        grad_ref = grad.cpu()
6457*da0073e9SAndroid Build Coastguard Worker
6458*da0073e9SAndroid Build Coastguard Worker        out = F.interpolate(input, size=(out_s,), mode='nearest')
6459*da0073e9SAndroid Build Coastguard Worker        out.backward(grad)
6460*da0073e9SAndroid Build Coastguard Worker
6461*da0073e9SAndroid Build Coastguard Worker        out_ref = F.interpolate(input_ref, size=(out_s,), mode='nearest')
6462*da0073e9SAndroid Build Coastguard Worker        out_ref.backward(grad_ref)
6463*da0073e9SAndroid Build Coastguard Worker
6464*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref, out)
6465*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(input_ref.grad, input.grad)
6466*da0073e9SAndroid Build Coastguard Worker
6467*da0073e9SAndroid Build Coastguard Worker    def test_interpolate_undefined_behavior_casting(self):
6468*da0073e9SAndroid Build Coastguard Worker        x = torch.ones([1, 1, 16, 16])
6469*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: F.interpolate(x, scale_factor=-1e20, mode="bilinear"))
6470*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: F.interpolate(x, scale_factor=1e20, mode="bilinear"))
6471*da0073e9SAndroid Build Coastguard Worker
6472*da0073e9SAndroid Build Coastguard Worker    def test_interpolate_buffer_overflow(self):
6473*da0073e9SAndroid Build Coastguard Worker        # Test buffer overflow issue due to inaccurate floating point
6474*da0073e9SAndroid Build Coastguard Worker        # representation for integer values. See issue below for details.
6475*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/88939
6476*da0073e9SAndroid Build Coastguard Worker
6477*da0073e9SAndroid Build Coastguard Worker        def helper(size, dtype, mode, device, is_channels_last):
6478*da0073e9SAndroid Build Coastguard Worker            input = torch.ones(size, dtype=dtype, device=device)
6479*da0073e9SAndroid Build Coastguard Worker            if is_channels_last:
6480*da0073e9SAndroid Build Coastguard Worker                if len(size) == 3:
6481*da0073e9SAndroid Build Coastguard Worker                    input = input.transpose(1, 2).contiguous().transpose(1, 2)
6482*da0073e9SAndroid Build Coastguard Worker                elif len(size) == 4:
6483*da0073e9SAndroid Build Coastguard Worker                    input = input.to(memory_format=torch.channels_last)
6484*da0073e9SAndroid Build Coastguard Worker                else:
6485*da0073e9SAndroid Build Coastguard Worker                    input = input.to(memory_format=torch.channels_last_3d)
6486*da0073e9SAndroid Build Coastguard Worker            output1 = F.interpolate(input, 2, mode=mode, align_corners=True)
6487*da0073e9SAndroid Build Coastguard Worker            # reset the corner value and expect the output to change as well;
6488*da0073e9SAndroid Build Coastguard Worker            # the output would not change if a buffer overflow occurred
6489*da0073e9SAndroid Build Coastguard Worker            input[(-1,) * len(size)] = 0.5
6490*da0073e9SAndroid Build Coastguard Worker            output2 = F.interpolate(input, 2, mode=mode, align_corners=True)
6491*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(output1, output2)
6492*da0073e9SAndroid Build Coastguard Worker
6493*da0073e9SAndroid Build Coastguard Worker        size_dtype_list = []
6494*da0073e9SAndroid Build Coastguard Worker        # We set the size larger than the range in which integers are exactly representable
6495*da0073e9SAndroid Build Coastguard Worker        # float: exact representable range (-2**24, 2**24)
6496*da0073e9SAndroid Build Coastguard Worker        size_dtype_list.append(([1, 10, 2**24 + 4], torch.float))
6497*da0073e9SAndroid Build Coastguard Worker        size_dtype_list.append(([1, 10, 2, 2**24 + 4], torch.float))
6498*da0073e9SAndroid Build Coastguard Worker        size_dtype_list.append(([1, 10, 2, 2, 2**24 + 4], torch.float))
6499*da0073e9SAndroid Build Coastguard Worker        # bfloat16: exact representable range (-2**8, 2**8)
6500*da0073e9SAndroid Build Coastguard Worker        size_dtype_list.append(([1, 10, 2**8 + 4], torch.bfloat16))
6501*da0073e9SAndroid Build Coastguard Worker        size_dtype_list.append(([1, 10, 2, 2**8 + 4], torch.bfloat16))
6502*da0073e9SAndroid Build Coastguard Worker        size_dtype_list.append(([1, 10, 2, 2, 2**8 + 4], torch.bfloat16))
6503*da0073e9SAndroid Build Coastguard Worker        # half: exact representable range (-2**11, 2**11)
6504*da0073e9SAndroid Build Coastguard Worker        size_dtype_list.append(([1, 10, 2**11 + 4], torch.half))
6505*da0073e9SAndroid Build Coastguard Worker        size_dtype_list.append(([1, 10, 2, 2**11 + 4], torch.half))
6506*da0073e9SAndroid Build Coastguard Worker        size_dtype_list.append(([1, 10, 2, 2, 2**11 + 4], torch.half))
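
        # Beyond these thresholds consecutive integers are no longer exactly
        # representable, so index arithmetic done in that dtype can silently round.
        # A minimal sketch of the effect for float32:
        #
        #   x = torch.tensor(2**24, dtype=torch.float)
        #   bool(x + 1 == x)  # True: 2**24 + 1 rounds back down to 2**24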
6507*da0073e9SAndroid Build Coastguard Worker
6508*da0073e9SAndroid Build Coastguard Worker        # TODO: turn on cuda test after buffer overflow issue is fixed in cuda kernel
6509*da0073e9SAndroid Build Coastguard Worker        # devices = ['cpu'] + (['cuda'] if torch.cuda.is_available() else [])
6510*da0073e9SAndroid Build Coastguard Worker        devices = ['cpu']
6511*da0073e9SAndroid Build Coastguard Worker
6512*da0073e9SAndroid Build Coastguard Worker        for mode in ('linear', 'bilinear', 'bicubic', 'trilinear'):
6513*da0073e9SAndroid Build Coastguard Worker            for size_dtype in size_dtype_list:
6514*da0073e9SAndroid Build Coastguard Worker                size, dtype = size_dtype
6515*da0073e9SAndroid Build Coastguard Worker                if (
6516*da0073e9SAndroid Build Coastguard Worker                    mode == 'linear' and len(size) != 3
6517*da0073e9SAndroid Build Coastguard Worker                    or (mode == 'bilinear' and len(size) != 4)
6518*da0073e9SAndroid Build Coastguard Worker                    or (mode == 'bicubic' and len(size) != 4)
6519*da0073e9SAndroid Build Coastguard Worker                    or (mode == 'trilinear' and len(size) != 5)
6520*da0073e9SAndroid Build Coastguard Worker                ):
6521*da0073e9SAndroid Build Coastguard Worker                    continue
6522*da0073e9SAndroid Build Coastguard Worker                for device in devices:
6523*da0073e9SAndroid Build Coastguard Worker                    if (
6524*da0073e9SAndroid Build Coastguard Worker                        device == 'cpu' and dtype == torch.half
6525*da0073e9SAndroid Build Coastguard Worker                        or (device == 'cuda' and dtype == torch.bfloat16)
6526*da0073e9SAndroid Build Coastguard Worker                    ):
6527*da0073e9SAndroid Build Coastguard Worker                        # half is not supported on cpu and bfloat16 is not supported on cuda yet
6528*da0073e9SAndroid Build Coastguard Worker                        continue
6529*da0073e9SAndroid Build Coastguard Worker                    for is_channels_last in (True, False):
6530*da0073e9SAndroid Build Coastguard Worker                        helper(size, dtype, mode, device, is_channels_last)
6531*da0073e9SAndroid Build Coastguard Worker
6532*da0073e9SAndroid Build Coastguard Worker
6533*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
6534*da0073e9SAndroid Build Coastguard Worker    def test_interpolate(self):
6535*da0073e9SAndroid Build Coastguard Worker        def _test_interpolate_non_integer_size_warning(in_t, out_size, dim, **kwargs):
6536*da0073e9SAndroid Build Coastguard Worker            test_sizes = [float(out_size),
6537*da0073e9SAndroid Build Coastguard Worker                          torch.tensor(out_size, dtype=torch.float)]
6538*da0073e9SAndroid Build Coastguard Worker            for size in test_sizes:
6539*da0073e9SAndroid Build Coastguard Worker                self.assertRaisesRegex(TypeError,
6540*da0073e9SAndroid Build Coastguard Worker                                       "(expected size to be one of int or).*",
6541*da0073e9SAndroid Build Coastguard Worker                                       F.interpolate, in_t, size=(size,) * dim, **kwargs)
6542*da0073e9SAndroid Build Coastguard Worker
6543*da0073e9SAndroid Build Coastguard Worker        def _test_interpolate_helper(in_t, scale_factor, layer):
6544*da0073e9SAndroid Build Coastguard Worker            out_size = int(math.floor(in_t.shape[-1] * scale_factor))
6545*da0073e9SAndroid Build Coastguard Worker            dim = len(in_t.shape) - 2
6546*da0073e9SAndroid Build Coastguard Worker            out_shape = [1, 1] + [out_size] * dim
6547*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True) as w:
6548*da0073e9SAndroid Build Coastguard Worker                out_t = layer(in_t)
6549*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.ones(out_shape), out_t)
6550*da0073e9SAndroid Build Coastguard Worker
6551*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
6552*da0073e9SAndroid Build Coastguard Worker                F.interpolate(in_t, (out_size,) * dim, **kwargs),
6553*da0073e9SAndroid Build Coastguard Worker                F.interpolate(in_t, scale_factor=scale_factor, **kwargs))
6554*da0073e9SAndroid Build Coastguard Worker            gradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [in_t], nondet_tol=GRADCHECK_NONDET_TOL)
6555*da0073e9SAndroid Build Coastguard Worker            gradgradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [in_t], nondet_tol=GRADCHECK_NONDET_TOL)
6556*da0073e9SAndroid Build Coastguard Worker            _test_interpolate_non_integer_size_warning(in_t, out_size, dim, **kwargs)
6557*da0073e9SAndroid Build Coastguard Worker
6558*da0073e9SAndroid Build Coastguard Worker        def _make_input(dim, device):
6559*da0073e9SAndroid Build Coastguard Worker            size = [1, 1]
6560*da0073e9SAndroid Build Coastguard Worker            size += [2] * dim
6561*da0073e9SAndroid Build Coastguard Worker            return torch.ones(size, requires_grad=True, device=device)
6562*da0073e9SAndroid Build Coastguard Worker
6563*da0073e9SAndroid Build Coastguard Worker        device_list = ['cpu']
6564*da0073e9SAndroid Build Coastguard Worker        if TEST_CUDA:
6565*da0073e9SAndroid Build Coastguard Worker            device_list.append('cuda')
6566*da0073e9SAndroid Build Coastguard Worker
6567*da0073e9SAndroid Build Coastguard Worker        for device in device_list:
6568*da0073e9SAndroid Build Coastguard Worker            for scale_factor in [0.5, 1.5, 2]:
6569*da0073e9SAndroid Build Coastguard Worker                for mode in ['nearest', 'area']:
6570*da0073e9SAndroid Build Coastguard Worker                    kwargs = dict(mode=mode)
6571*da0073e9SAndroid Build Coastguard Worker                    m = nn.Upsample(scale_factor=scale_factor, **kwargs).to(device)
6572*da0073e9SAndroid Build Coastguard Worker                    for input in [_make_input(1, device), _make_input(2, device), _make_input(3, device)]:
6573*da0073e9SAndroid Build Coastguard Worker                        _test_interpolate_helper(input, scale_factor, m)
6574*da0073e9SAndroid Build Coastguard Worker
6575*da0073e9SAndroid Build Coastguard Worker                for align_corners in [True, False]:
6576*da0073e9SAndroid Build Coastguard Worker                    kwargs = dict(mode='linear', align_corners=align_corners)
6577*da0073e9SAndroid Build Coastguard Worker                    m = nn.Upsample(scale_factor=scale_factor, **kwargs).to(device)
6578*da0073e9SAndroid Build Coastguard Worker                    _test_interpolate_helper(_make_input(1, device), scale_factor, m)
6579*da0073e9SAndroid Build Coastguard Worker
6580*da0073e9SAndroid Build Coastguard Worker                    kwargs = dict(mode='bilinear', align_corners=align_corners)
6581*da0073e9SAndroid Build Coastguard Worker                    m = nn.Upsample(scale_factor=scale_factor, **kwargs).to(device)
6582*da0073e9SAndroid Build Coastguard Worker                    _test_interpolate_helper(_make_input(2, device), scale_factor, m)
6583*da0073e9SAndroid Build Coastguard Worker
6584*da0073e9SAndroid Build Coastguard Worker                    kwargs = dict(mode='bicubic', align_corners=align_corners)
6585*da0073e9SAndroid Build Coastguard Worker
6586*da0073e9SAndroid Build Coastguard Worker                    def m(t):
6587*da0073e9SAndroid Build Coastguard Worker                        return F.interpolate(t, scale_factor=scale_factor, **kwargs).to(device)
6588*da0073e9SAndroid Build Coastguard Worker                    _test_interpolate_helper(_make_input(2, device), scale_factor, m)
6589*da0073e9SAndroid Build Coastguard Worker
6590*da0073e9SAndroid Build Coastguard Worker                    kwargs = dict(mode='trilinear', align_corners=align_corners)
6591*da0073e9SAndroid Build Coastguard Worker                    m = nn.Upsample(scale_factor=scale_factor, **kwargs).to(device)
6592*da0073e9SAndroid Build Coastguard Worker                    _test_interpolate_helper(_make_input(3, device), scale_factor, m)
6593*da0073e9SAndroid Build Coastguard Worker
6594*da0073e9SAndroid Build Coastguard Worker    def test_linear_broadcasting(self):
6595*da0073e9SAndroid Build Coastguard Worker        m = nn.Linear(5, 8)
6596*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn(2, 3, 5)
6597*da0073e9SAndroid Build Coastguard Worker        expected = m(inp.view(6, 5)).view(2, 3, 8)
6598*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, m(inp))
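
        # nn.Linear is applied over the last dimension only, so any leading batch
        # dimensions are broadcast through unchanged; a minimal sketch of the
        # equivalence exercised above:
        #
        #   # m(inp) == inp @ m.weight.t() + m.bias   (up to floating point error)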
6599*da0073e9SAndroid Build Coastguard Worker
6600*da0073e9SAndroid Build Coastguard Worker    def test_linear_raise_on_scalar_input(self):
6601*da0073e9SAndroid Build Coastguard Worker        # This used to cause an int underflow issue when reshaping the input
6602*da0073e9SAndroid Build Coastguard Worker        # see https://github.com/pytorch/pytorch/issues/119161
6603*da0073e9SAndroid Build Coastguard Worker        m = nn.Linear(1, 1)
6604*da0073e9SAndroid Build Coastguard Worker        inp = torch.ones(1).squeeze()
6605*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, ".*both arguments.*1D.*"):
6606*da0073e9SAndroid Build Coastguard Worker            m(inp)
6607*da0073e9SAndroid Build Coastguard Worker
6608*da0073e9SAndroid Build Coastguard Worker    @parametrize_test('device', ['cpu'] + (['cuda'] if TEST_CUDA else []))
6609*da0073e9SAndroid Build Coastguard Worker    @parametrize_test('bias', [
6610*da0073e9SAndroid Build Coastguard Worker        subtest(False, name='nobias'), subtest(True, name='bias')])
6611*da0073e9SAndroid Build Coastguard Worker    @parametrize_test('weight_layout', [
6612*da0073e9SAndroid Build Coastguard Worker        subtest(torch.strided, name='weightStrided'),
6613*da0073e9SAndroid Build Coastguard Worker        subtest(torch.sparse_coo, name='weightCOO'),
6614*da0073e9SAndroid Build Coastguard Worker        subtest(torch.sparse_csr, name='weightCSR'),
6615*da0073e9SAndroid Build Coastguard Worker        subtest(torch.sparse_csc, name='weightCSC'),
6616*da0073e9SAndroid Build Coastguard Worker        # TODO: addmm: computation on CPU is not implemented for Strided + Strided @ SparseBsr
6617*da0073e9SAndroid Build Coastguard Worker        # subtest(torch.sparse_bsr, name='weightBSR'),
6618*da0073e9SAndroid Build Coastguard Worker        # subtest(torch.sparse_bsc, name='weightBSC'),
6619*da0073e9SAndroid Build Coastguard Worker    ])
6620*da0073e9SAndroid Build Coastguard Worker    def test_linear_autograd(self, device, bias, weight_layout):
6621*da0073e9SAndroid Build Coastguard Worker        module = nn.Linear(4, 4, bias=bias, device=device)
6622*da0073e9SAndroid Build Coastguard Worker        if weight_layout == torch.strided:
6623*da0073e9SAndroid Build Coastguard Worker            pass
6624*da0073e9SAndroid Build Coastguard Worker        elif weight_layout == torch.sparse_csr:
6625*da0073e9SAndroid Build Coastguard Worker            module.weight = nn.Parameter(module.weight.to_sparse_csr())
6626*da0073e9SAndroid Build Coastguard Worker        elif weight_layout == torch.sparse_csc:
6627*da0073e9SAndroid Build Coastguard Worker            module.weight = nn.Parameter(module.weight.to_sparse_csc())
6628*da0073e9SAndroid Build Coastguard Worker        elif weight_layout == torch.sparse_bsr:
6629*da0073e9SAndroid Build Coastguard Worker            module.weight = nn.Parameter(module.weight.to_sparse_bsr((2, 2)))
6630*da0073e9SAndroid Build Coastguard Worker        elif weight_layout == torch.sparse_bsc:
6631*da0073e9SAndroid Build Coastguard Worker            module.weight = nn.Parameter(module.weight.to_sparse_bsc((2, 2)))
6632*da0073e9SAndroid Build Coastguard Worker        elif weight_layout == torch.sparse_coo:
6633*da0073e9SAndroid Build Coastguard Worker            module.weight = nn.Parameter(module.weight.to_sparse_coo())
6634*da0073e9SAndroid Build Coastguard Worker        else:
6635*da0073e9SAndroid Build Coastguard Worker            raise AssertionError
6636*da0073e9SAndroid Build Coastguard Worker
6637*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn(4, requires_grad=True, device=device)
6638*da0073e9SAndroid Build Coastguard Worker        res = module(inp)
6639*da0073e9SAndroid Build Coastguard Worker        if bias:
6640*da0073e9SAndroid Build Coastguard Worker            expected = (torch.einsum("i,ji->j", inp, module.weight.to_dense())) + module.bias
6641*da0073e9SAndroid Build Coastguard Worker        else:
6642*da0073e9SAndroid Build Coastguard Worker            expected = (torch.einsum("i,ji->j", inp, module.weight.to_dense()))
6643*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, expected)
6644*da0073e9SAndroid Build Coastguard Worker
6645*da0073e9SAndroid Build Coastguard Worker        grad_output = torch.randn(4, device=device)
6646*da0073e9SAndroid Build Coastguard Worker        grads = torch.autograd.grad(res, [module.weight, inp], grad_output)
6647*da0073e9SAndroid Build Coastguard Worker        grads_expected = torch.autograd.grad(expected, [module.weight, inp], grad_output)
6648*da0073e9SAndroid Build Coastguard Worker
6649*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(grads_expected[0].layout, weight_layout)
6650*da0073e9SAndroid Build Coastguard Worker
6651*da0073e9SAndroid Build Coastguard Worker        for g, ge in zip(grads, grads_expected):
6652*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(g, ge)
6653*da0073e9SAndroid Build Coastguard Worker
6654*da0073e9SAndroid Build Coastguard Worker    def test_bilinear(self):
6655*da0073e9SAndroid Build Coastguard Worker        module = nn.Bilinear(10, 10, 8)
6656*da0073e9SAndroid Build Coastguard Worker        input1 = torch.randn(4, 10, requires_grad=True)
6657*da0073e9SAndroid Build Coastguard Worker        input2 = torch.randn(4, 10, requires_grad=True)
6658*da0073e9SAndroid Build Coastguard Worker        grad_output = torch.randn(4, 8)
6659*da0073e9SAndroid Build Coastguard Worker        res = module(input1, input2)
6660*da0073e9SAndroid Build Coastguard Worker        expected = (torch.einsum("bi,kij,bj->bk", input1, module.weight, input2) +
6661*da0073e9SAndroid Build Coastguard Worker                    module.bias)
6662*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, expected)
6663*da0073e9SAndroid Build Coastguard Worker        grads = torch.autograd.grad(res, [module.weight, module.bias, input1, input2], grad_output)
6664*da0073e9SAndroid Build Coastguard Worker        grads_expected = torch.autograd.grad(expected, [module.weight, module.bias, input1, input2], grad_output)
6665*da0073e9SAndroid Build Coastguard Worker        for g, ge in zip(grads, grads_expected):
6666*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(g, ge)
6667*da0073e9SAndroid Build Coastguard Worker
6668*da0073e9SAndroid Build Coastguard Worker    def test_bilinear_non_contiguous(self):
6669*da0073e9SAndroid Build Coastguard Worker        module = nn.Bilinear(7, 7, 5)
6670*da0073e9SAndroid Build Coastguard Worker        input1 = torch.randn(4, 7, 10, requires_grad=True)
6671*da0073e9SAndroid Build Coastguard Worker        input2 = torch.randn(4, 7, 10, requires_grad=True)
6672*da0073e9SAndroid Build Coastguard Worker        input1_tp = input1.transpose(1, 2)
6673*da0073e9SAndroid Build Coastguard Worker        input2_tp = input2.transpose(1, 2)
6674*da0073e9SAndroid Build Coastguard Worker
6675*da0073e9SAndroid Build Coastguard Worker        grad_output = torch.randn(4, 10, 5)
6676*da0073e9SAndroid Build Coastguard Worker
6677*da0073e9SAndroid Build Coastguard Worker        def run(input1_tp, input2_tp):
6678*da0073e9SAndroid Build Coastguard Worker            input1.grad = input2.grad = None
6679*da0073e9SAndroid Build Coastguard Worker            output = module(input1_tp, input2_tp)
6680*da0073e9SAndroid Build Coastguard Worker            output.backward(grad_output)
6681*da0073e9SAndroid Build Coastguard Worker
6682*da0073e9SAndroid Build Coastguard Worker            return output.data, input1.grad.data, input2.grad.data
6683*da0073e9SAndroid Build Coastguard Worker
6684*da0073e9SAndroid Build Coastguard Worker        out_nc, g1_nc, g2_nc = run(input1_tp, input2_tp)
6685*da0073e9SAndroid Build Coastguard Worker        input1_tp = input1_tp.contiguous()
6686*da0073e9SAndroid Build Coastguard Worker        input2_tp = input2_tp.contiguous()
6687*da0073e9SAndroid Build Coastguard Worker        out, g1, g2 = run(input1_tp, input2_tp)
6688*da0073e9SAndroid Build Coastguard Worker
6689*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, out_nc)
6690*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(g1, g1_nc)
6691*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(g2, g2_nc)
6692*da0073e9SAndroid Build Coastguard Worker
6693*da0073e9SAndroid Build Coastguard Worker    def test_bilinear_no_bias(self):
6694*da0073e9SAndroid Build Coastguard Worker        module = nn.Bilinear(10, 10, 8, dtype=torch.double)
6695*da0073e9SAndroid Build Coastguard Worker        module_no_bias = nn.Bilinear(10, 10, 8, False, dtype=torch.double)
6696*da0073e9SAndroid Build Coastguard Worker
6697*da0073e9SAndroid Build Coastguard Worker        module.bias.data.zero_()
6698*da0073e9SAndroid Build Coastguard Worker        module.weight.data.copy_(module_no_bias.weight)
6699*da0073e9SAndroid Build Coastguard Worker
6700*da0073e9SAndroid Build Coastguard Worker        input1 = torch.randn(4, 10, requires_grad=True, dtype=torch.double)
6701*da0073e9SAndroid Build Coastguard Worker        input2 = torch.randn(4, 10, requires_grad=True, dtype=torch.double)
6702*da0073e9SAndroid Build Coastguard Worker        grad_output = torch.randn(4, 8, dtype=torch.double)
6703*da0073e9SAndroid Build Coastguard Worker
6704*da0073e9SAndroid Build Coastguard Worker        def run(net):
6705*da0073e9SAndroid Build Coastguard Worker            input1.grad = input2.grad = None
6706*da0073e9SAndroid Build Coastguard Worker            output = net(input1, input2)
6707*da0073e9SAndroid Build Coastguard Worker            output.backward(grad_output)
6708*da0073e9SAndroid Build Coastguard Worker
6709*da0073e9SAndroid Build Coastguard Worker            return output.data, input1.grad.data, input2.grad.data
6710*da0073e9SAndroid Build Coastguard Worker
6711*da0073e9SAndroid Build Coastguard Worker        out, g1, g2 = run(module)
6712*da0073e9SAndroid Build Coastguard Worker        out_nb, g1_nb, g2_nb = run(module_no_bias)
6713*da0073e9SAndroid Build Coastguard Worker
6714*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, out_nb)
6715*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(g1, g1_nb)
6716*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(g2, g2_nb)
6717*da0073e9SAndroid Build Coastguard Worker
6718*da0073e9SAndroid Build Coastguard Worker        _assertGradAndGradgradChecks(self,
6719*da0073e9SAndroid Build Coastguard Worker                                     lambda x1, x2: F.bilinear(x1, x2, module_no_bias.weight, module_no_bias.bias),
6720*da0073e9SAndroid Build Coastguard Worker                                     (input1, input2))
6721*da0073e9SAndroid Build Coastguard Worker
6722*da0073e9SAndroid Build Coastguard Worker    def test_bilinear_broadcasting(self):
6723*da0073e9SAndroid Build Coastguard Worker        m = nn.Bilinear(5, 6, 8)
6724*da0073e9SAndroid Build Coastguard Worker        input1 = torch.randn(2, 3, 5)
6725*da0073e9SAndroid Build Coastguard Worker        input2 = torch.randn(2, 3, 6)
6726*da0073e9SAndroid Build Coastguard Worker        expected = m(input1.view(6, 5), input2.view(6, 6)).view(2, 3, 8)
6727*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, m(input1, input2))
6728*da0073e9SAndroid Build Coastguard Worker
6729*da0073e9SAndroid Build Coastguard Worker    def test_fold_invalid_arg(self):
6730*da0073e9SAndroid Build Coastguard Worker        # input.size(1) not divisible by \prod(kernel_size)
6731*da0073e9SAndroid Build Coastguard Worker
6732*da0073e9SAndroid Build Coastguard Worker        fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 3))
6733*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"be divisible by the product of kernel_size"):
6734*da0073e9SAndroid Build Coastguard Worker            fold(torch.randn(1, 5, 9))
6735*da0073e9SAndroid Build Coastguard Worker
6736*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"be divisible by the product of kernel_size"):
6737*da0073e9SAndroid Build Coastguard Worker            fold(torch.randn(1, 19, 9))
6738*da0073e9SAndroid Build Coastguard Worker
6739*da0073e9SAndroid Build Coastguard Worker        # input.size(2) not matching the total number of sliding blocks
6740*da0073e9SAndroid Build Coastguard Worker
6741*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"match the calculated number of sliding blocks"):
6742*da0073e9SAndroid Build Coastguard Worker            fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 3))
6743*da0073e9SAndroid Build Coastguard Worker            fold(torch.randn(1, 6, 10))
6744*da0073e9SAndroid Build Coastguard Worker
6745*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"match the calculated number of sliding blocks"):
6746*da0073e9SAndroid Build Coastguard Worker            fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 3), stride=(2, 2))
6747*da0073e9SAndroid Build Coastguard Worker            fold(torch.randn(1, 6, 5))
6748*da0073e9SAndroid Build Coastguard Worker
6749*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"match the calculated number of sliding blocks"):
6750*da0073e9SAndroid Build Coastguard Worker            fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 3), stride=(2, 2), dilation=(1, 2), padding=(2, 0))
6751*da0073e9SAndroid Build Coastguard Worker            fold(torch.randn(1, 6, 5))  # should be 4 * 1 = 4 sliding blocks
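
        # The expected number of sliding blocks per spatial dimension follows the
        # convolution output-size formula
        #   L_d = (output_size[d] + 2 * padding[d] - dilation[d] * (kernel_size[d] - 1) - 1) // stride[d] + 1,
        # so for the Fold above: ((4 + 4 - 1 - 1) // 2 + 1) * ((5 + 0 - 4 - 1) // 2 + 1) = 4 * 1 = 4,
        # while the input supplies 5 blocks, hence the error.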
6752*da0073e9SAndroid Build Coastguard Worker
6753*da0073e9SAndroid Build Coastguard Worker        fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 2), stride=1, dilation=8, padding=0)
6754*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"calculated shape of the array of sliding blocks as"):
6755*da0073e9SAndroid Build Coastguard Worker            fold(torch.randn(1, 12, 12))
6756*da0073e9SAndroid Build Coastguard Worker
6757*da0073e9SAndroid Build Coastguard Worker    def test_unfold_invalid_arg(self):
6758*da0073e9SAndroid Build Coastguard Worker        # input wrong dimension
6759*da0073e9SAndroid Build Coastguard Worker
6760*da0073e9SAndroid Build Coastguard Worker        unfold = nn.Unfold(kernel_size=(2, 3))
6761*da0073e9SAndroid Build Coastguard Worker
6762*da0073e9SAndroid Build Coastguard Worker        # calculated output shape is too small
6763*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"its components must be at least one"):
6764*da0073e9SAndroid Build Coastguard Worker            unfold = nn.Unfold(kernel_size=(2, 3))
6765*da0073e9SAndroid Build Coastguard Worker            unfold(torch.randn(1, 2, 2, 2))
6766*da0073e9SAndroid Build Coastguard Worker
6767*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"its components must be at least one"):
6768*da0073e9SAndroid Build Coastguard Worker            unfold = nn.Unfold(kernel_size=(5, 3), padding=(1, 1))
6769*da0073e9SAndroid Build Coastguard Worker            unfold(torch.randn(1, 2, 2, 3))
6770*da0073e9SAndroid Build Coastguard Worker
6771*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"its components must be at least one"):
6772*da0073e9SAndroid Build Coastguard Worker            unfold = nn.Unfold(kernel_size=(1, 3), padding=(1, 1), dilation=(1, 2))
6773*da0073e9SAndroid Build Coastguard Worker            unfold(torch.randn(1, 2, 2, 2))
6774*da0073e9SAndroid Build Coastguard Worker
6775*da0073e9SAndroid Build Coastguard Worker    def test_softmin(self):
6776*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 16)
6777*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(F.softmin(x, 1), F.softmax(-x, 1))
6778*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(F.softmin(x, 0), F.softmax(-x, 0))
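
        # softmin is defined as softmin(x)_i = exp(-x_i) / sum_j exp(-x_j), which is
        # exactly softmax applied to the negated input, hence the equalities above.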
6779*da0073e9SAndroid Build Coastguard Worker
6780*da0073e9SAndroid Build Coastguard Worker    def test_adaptive_log_softmax(self):
6781*da0073e9SAndroid Build Coastguard Worker        # args validation
6782*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
6783*da0073e9SAndroid Build Coastguard Worker            _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 15, 15], div_value=2.)
6784*da0073e9SAndroid Build Coastguard Worker
6785*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
6786*da0073e9SAndroid Build Coastguard Worker            _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 15, 10], div_value=2.)
6787*da0073e9SAndroid Build Coastguard Worker
6788*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
6789*da0073e9SAndroid Build Coastguard Worker            _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 25], div_value=2.)
6790*da0073e9SAndroid Build Coastguard Worker
6791*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "cutoffs should be a sequence of unique,"):
6792*da0073e9SAndroid Build Coastguard Worker            _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 20], div_value=2.)
6793*da0073e9SAndroid Build Coastguard Worker
6794*da0073e9SAndroid Build Coastguard Worker        # should not raise
6795*da0073e9SAndroid Build Coastguard Worker        _ = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 19], div_value=2.)
6796*da0073e9SAndroid Build Coastguard Worker
6797*da0073e9SAndroid Build Coastguard Worker        # input shapes
6798*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"Input and target should have the same size"):
6799*da0073e9SAndroid Build Coastguard Worker            asfm = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 15], div_value=2.)
6800*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(2, 16)
6801*da0073e9SAndroid Build Coastguard Worker            y = torch.tensor([0, 5, 10])
6802*da0073e9SAndroid Build Coastguard Worker            asfm(x, y)
6803*da0073e9SAndroid Build Coastguard Worker
6804*da0073e9SAndroid Build Coastguard Worker        # out-of-bound targets
6805*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"Target values should be in"):
6806*da0073e9SAndroid Build Coastguard Worker            asfm = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 15], div_value=2.)
6807*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(2, 16)
6808*da0073e9SAndroid Build Coastguard Worker            y = torch.tensor([0, 20])
6809*da0073e9SAndroid Build Coastguard Worker            asfm(x, y)
6810*da0073e9SAndroid Build Coastguard Worker
6811*da0073e9SAndroid Build Coastguard Worker        # cluster sizes
6812*da0073e9SAndroid Build Coastguard Worker        asfm = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 15], div_value=2.)
6813*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 16)
6814*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor([0, 17])
6815*da0073e9SAndroid Build Coastguard Worker
6816*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(asfm.head.weight.size(), (5 + 3, 16))   # 5 targets in head, 3 clusters, dimensionality 16
6817*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(asfm.tail[0][1].weight.size(), (5, 8))  # 5 targets in this cluster, dimensionality 8
6818*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(asfm.tail[1][1].weight.size(), (5, 4))
6819*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(asfm.tail[2][1].weight.size(), (5, 2))
6820*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(asfm(x, y).output.size(), (2, ))
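
        # With div_value=2, tail cluster i first projects the 16-dim input down to
        # in_features // div_value ** (i + 1) features, which is why the cluster
        # weights above are (5, 8), (5, 4) and (5, 2): 16 // 2 = 8, 16 // 4 = 4, 16 // 8 = 2.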
6821*da0073e9SAndroid Build Coastguard Worker
6822*da0073e9SAndroid Build Coastguard Worker        # test no_batch_dim support
6823*da0073e9SAndroid Build Coastguard Worker        asfm = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 15], div_value=2.)
6824*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 16)
6825*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor([17])
6826*da0073e9SAndroid Build Coastguard Worker        x2 = x.squeeze(0)
6827*da0073e9SAndroid Build Coastguard Worker        y2 = y.squeeze(0)
6828*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(asfm(x, y).output.squeeze(0), asfm(x2, y2).output)
6829*da0073e9SAndroid Build Coastguard Worker
6830*da0073e9SAndroid Build Coastguard Worker        # log_prob actually returns log-probabilities
6831*da0073e9SAndroid Build Coastguard Worker        asfm = nn.AdaptiveLogSoftmaxWithLoss(8, 4, [2], div_value=2.)
6832*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, 8)
6833*da0073e9SAndroid Build Coastguard Worker        logprob_out = asfm.log_prob(x)
6834*da0073e9SAndroid Build Coastguard Worker
6835*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.exp(logprob_out).data.sum(1), torch.ones(4))
6836*da0073e9SAndroid Build Coastguard Worker
6837*da0073e9SAndroid Build Coastguard Worker        # forward returns the same thing as log_prob
6838*da0073e9SAndroid Build Coastguard Worker        for v in [0, 1, 2, 3]:
6839*da0073e9SAndroid Build Coastguard Worker            y = torch.full((4,), v, dtype=torch.long)
6840*da0073e9SAndroid Build Coastguard Worker            out, loss = asfm(x, y)
6841*da0073e9SAndroid Build Coastguard Worker
6842*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, logprob_out.gather(1, y.unsqueeze(1)).squeeze())
6843*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(loss, F.nll_loss(logprob_out, y))
6844*da0073e9SAndroid Build Coastguard Worker
6845*da0073e9SAndroid Build Coastguard Worker        # predict
6846*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(64, 8).abs_()
6847*da0073e9SAndroid Build Coastguard Worker
6848*da0073e9SAndroid Build Coastguard Worker        # argmax in shortlist
6849*da0073e9SAndroid Build Coastguard Worker        asfm = nn.AdaptiveLogSoftmaxWithLoss(8, 10, [4, 8], div_value=2., head_bias=True)
6850*da0073e9SAndroid Build Coastguard Worker        asfm.head.weight.data.abs_()
6851*da0073e9SAndroid Build Coastguard Worker        asfm.head.bias.data.abs_()
6852*da0073e9SAndroid Build Coastguard Worker        asfm.head.weight.data[asfm.shortlist_size:, :].zero_()
6853*da0073e9SAndroid Build Coastguard Worker
6854*da0073e9SAndroid Build Coastguard Worker        out = asfm.predict(x)
6855*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, asfm.log_prob(x).argmax(dim=1))
6856*da0073e9SAndroid Build Coastguard Worker
6857*da0073e9SAndroid Build Coastguard Worker        # argmax outside of shortlist
6858*da0073e9SAndroid Build Coastguard Worker        asfm = nn.AdaptiveLogSoftmaxWithLoss(8, 10, [4, 8], div_value=2., head_bias=True)
6859*da0073e9SAndroid Build Coastguard Worker        asfm.head.weight.data.abs_()
6860*da0073e9SAndroid Build Coastguard Worker        asfm.head.bias.data.abs_()
6861*da0073e9SAndroid Build Coastguard Worker        asfm.head.weight.data[:asfm.shortlist_size, :].zero_()
6862*da0073e9SAndroid Build Coastguard Worker
6863*da0073e9SAndroid Build Coastguard Worker        out = asfm.predict(x)
6864*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, asfm.log_prob(x).argmax(dim=1))
6865*da0073e9SAndroid Build Coastguard Worker
6866*da0073e9SAndroid Build Coastguard Worker        # half of the argmaxes in the shortlist, half in the clusters
6867*da0073e9SAndroid Build Coastguard Worker        asfm = nn.AdaptiveLogSoftmaxWithLoss(8, 10, [4, 8], div_value=2., head_bias=True)
6868*da0073e9SAndroid Build Coastguard Worker        asfm.head.weight.data.abs_()
6869*da0073e9SAndroid Build Coastguard Worker        asfm.head.bias.data.abs_()
6870*da0073e9SAndroid Build Coastguard Worker
6871*da0073e9SAndroid Build Coastguard Worker        x[:32, :asfm.shortlist_size].zero_()
6872*da0073e9SAndroid Build Coastguard Worker        x[32:, asfm.shortlist_size:].zero_()
6873*da0073e9SAndroid Build Coastguard Worker
6874*da0073e9SAndroid Build Coastguard Worker        asfm.head.weight.data[:asfm.shortlist_size, asfm.shortlist_size:].zero_()
6875*da0073e9SAndroid Build Coastguard Worker        asfm.head.weight.data[asfm.shortlist_size:, :asfm.shortlist_size].zero_()
6876*da0073e9SAndroid Build Coastguard Worker
6877*da0073e9SAndroid Build Coastguard Worker        out = asfm.predict(x)
6878*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, asfm.log_prob(x).argmax(dim=1))
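        # Per the module's documentation, predict(x) is equivalent to
        # log_prob(x).argmax(dim=1) but can be cheaper, since rows whose head
        # argmax already falls in the shortlist never need the tail clusters.
        # A minimal usage sketch (illustrative only):
        #   top_class = asfm.predict(torch.randn(3, 8))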
6879*da0073e9SAndroid Build Coastguard Worker
6880*da0073e9SAndroid Build Coastguard Worker    def test_cross_entropy_loss(self, dtype=torch.bfloat16):
6881*da0073e9SAndroid Build Coastguard Worker        loss_cpu = nn.CrossEntropyLoss().cpu()
6882*da0073e9SAndroid Build Coastguard Worker        inputf = torch.randn(15, 10, device="cpu", dtype=torch.float, requires_grad=True)
6883*da0073e9SAndroid Build Coastguard Worker        input = inputf.to(dtype).detach().requires_grad_(True)
6884*da0073e9SAndroid Build Coastguard Worker        target = torch.empty(15, dtype=torch.long).random_(10)
6885*da0073e9SAndroid Build Coastguard Worker
6886*da0073e9SAndroid Build Coastguard Worker        outf = loss_cpu(inputf, target)
6887*da0073e9SAndroid Build Coastguard Worker        out = loss_cpu(input, target)
6888*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, outf.to(dtype=dtype), atol=1e-1, rtol=0)
6889*da0073e9SAndroid Build Coastguard Worker
6890*da0073e9SAndroid Build Coastguard Worker        outf.backward()
6891*da0073e9SAndroid Build Coastguard Worker        out.backward()
6892*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(input.grad, inputf.grad.to(dtype=dtype), atol=1e-1, rtol=0)
6893*da0073e9SAndroid Build Coastguard Worker
6894*da0073e9SAndroid Build Coastguard Worker    def test_cross_entropy_loss_precision(self):
6895*da0073e9SAndroid Build Coastguard Worker        # Regression test for #55657
6896*da0073e9SAndroid Build Coastguard Worker        loss_cpu = nn.CrossEntropyLoss().cpu()
6897*da0073e9SAndroid Build Coastguard Worker        inputf = torch.randn(128, 2, 768, 768, device="cpu", dtype=torch.float)
6898*da0073e9SAndroid Build Coastguard Worker        inputd = inputf.double()
6899*da0073e9SAndroid Build Coastguard Worker        target = torch.randint(2, (128, 768, 768), dtype=torch.long)
6900*da0073e9SAndroid Build Coastguard Worker
6901*da0073e9SAndroid Build Coastguard Worker        outf = loss_cpu(inputf, target)
6902*da0073e9SAndroid Build Coastguard Worker        outd = loss_cpu(inputd, target)
6903*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(outf, outd, exact_dtype=False)
6904*da0073e9SAndroid Build Coastguard Worker
6905*da0073e9SAndroid Build Coastguard Worker    def test_cross_entropy_loss_zero_div(self):
6906*da0073e9SAndroid Build Coastguard Worker        # Test for issue #73165
6907*da0073e9SAndroid Build Coastguard Worker        input_1 = torch.rand([5, 0], dtype=torch.float32)
6908*da0073e9SAndroid Build Coastguard Worker        input_2 = torch.rand([5, 0], dtype=torch.float32)
6909*da0073e9SAndroid Build Coastguard Worker        torch.nn.CrossEntropyLoss()(input_1, input_2)
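        # No assertion is needed: this is a regression test, and simply running
        # the call above without crashing (the issue above reported a
        # division-by-zero with zero-sized inputs) is the pass condition.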
6910*da0073e9SAndroid Build Coastguard Worker
6911*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
6912*da0073e9SAndroid Build Coastguard Worker    def test_convert_sync_batchnorm(self):
6913*da0073e9SAndroid Build Coastguard Worker        module = torch.nn.Sequential(
6914*da0073e9SAndroid Build Coastguard Worker            torch.nn.BatchNorm1d(100),
6915*da0073e9SAndroid Build Coastguard Worker            torch.nn.InstanceNorm1d(100)
6916*da0073e9SAndroid Build Coastguard Worker        ).cuda()
6917*da0073e9SAndroid Build Coastguard Worker
6918*da0073e9SAndroid Build Coastguard Worker        # necessary to have an anchor point for comparison, in case
6919*da0073e9SAndroid Build Coastguard Worker        # convert_sync_batchnorm updates the module in place
6920*da0073e9SAndroid Build Coastguard Worker        comp_module = torch.nn.Sequential(
6921*da0073e9SAndroid Build Coastguard Worker            torch.nn.BatchNorm1d(100),
6922*da0073e9SAndroid Build Coastguard Worker            torch.nn.InstanceNorm1d(100)
6923*da0073e9SAndroid Build Coastguard Worker        ).cuda()
6924*da0073e9SAndroid Build Coastguard Worker        comp_module.load_state_dict(module.state_dict())
6925*da0073e9SAndroid Build Coastguard Worker
6926*da0073e9SAndroid Build Coastguard Worker        sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module)
6927*da0073e9SAndroid Build Coastguard Worker        children = list(sync_bn_module.children())
6928*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(children[0].__class__, torch.nn.SyncBatchNorm)
6929*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(children[1].__class__, torch.nn.InstanceNorm1d)
6930*da0073e9SAndroid Build Coastguard Worker
6931*da0073e9SAndroid Build Coastguard Worker        for layer, converted_layer in zip(comp_module.children(), sync_bn_module.children()):
6932*da0073e9SAndroid Build Coastguard Worker            for key in layer.state_dict().keys():
6933*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(layer.state_dict()[key].device, converted_layer.state_dict()[key].device)
6934*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(layer.state_dict()[key], converted_layer.state_dict()[key])
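        # Typical usage sketch (illustrative only; `model` is an assumption,
        # not part of this test):
        #   model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        # Only BatchNorm*D layers are converted; other layers such as the
        # InstanceNorm1d above are returned unchanged.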
6935*da0073e9SAndroid Build Coastguard Worker
6936*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
6937*da0073e9SAndroid Build Coastguard Worker    def test_sync_batchnorm_backward_elemt(self):
6938*da0073e9SAndroid Build Coastguard Worker        device = 'cuda'
6939*da0073e9SAndroid Build Coastguard Worker        saved_input = torch.rand(2, 3, 2, 1, device=device)
6940*da0073e9SAndroid Build Coastguard Worker        grad_output = torch.rand(2, 3, 2, 1, device=device)
6941*da0073e9SAndroid Build Coastguard Worker        mean = torch.rand(3, device=device)
6942*da0073e9SAndroid Build Coastguard Worker        invstd = torch.rand(3, device=device)
6943*da0073e9SAndroid Build Coastguard Worker        weight = torch.rand(3, device=device)
6944*da0073e9SAndroid Build Coastguard Worker        sum_dy = torch.rand(3, device=device)
6945*da0073e9SAndroid Build Coastguard Worker        sum_dy_xmu = torch.rand(3, device=device)
6946*da0073e9SAndroid Build Coastguard Worker        count_tensor = torch.tensor([5, 5, 5], dtype=torch.int32, device=device)
6947*da0073e9SAndroid Build Coastguard Worker
6948*da0073e9SAndroid Build Coastguard Worker        gI_contiguous = torch.batch_norm_backward_elemt(
6949*da0073e9SAndroid Build Coastguard Worker            grad_output,
6950*da0073e9SAndroid Build Coastguard Worker            saved_input,
6951*da0073e9SAndroid Build Coastguard Worker            mean,
6952*da0073e9SAndroid Build Coastguard Worker            invstd,
6953*da0073e9SAndroid Build Coastguard Worker            weight,
6954*da0073e9SAndroid Build Coastguard Worker            sum_dy,
6955*da0073e9SAndroid Build Coastguard Worker            sum_dy_xmu,
6956*da0073e9SAndroid Build Coastguard Worker            count_tensor
6957*da0073e9SAndroid Build Coastguard Worker        )
6958*da0073e9SAndroid Build Coastguard Worker
6959*da0073e9SAndroid Build Coastguard Worker        # Test that batch_norm_backward_elemt gives the same answer for all
6960*da0073e9SAndroid Build Coastguard Worker        # combinations of contiguous and channels_last inputs
6961*da0073e9SAndroid Build Coastguard Worker        for a, b in [
6962*da0073e9SAndroid Build Coastguard Worker                (torch.channels_last, torch.contiguous_format),
6963*da0073e9SAndroid Build Coastguard Worker                (torch.contiguous_format, torch.channels_last),
6964*da0073e9SAndroid Build Coastguard Worker                (torch.channels_last, torch.channels_last),
6965*da0073e9SAndroid Build Coastguard Worker        ]:
6966*da0073e9SAndroid Build Coastguard Worker            gI_actual = torch.batch_norm_backward_elemt(
6967*da0073e9SAndroid Build Coastguard Worker                grad_output.contiguous(memory_format=a),
6968*da0073e9SAndroid Build Coastguard Worker                saved_input.contiguous(memory_format=b),
6969*da0073e9SAndroid Build Coastguard Worker                mean,
6970*da0073e9SAndroid Build Coastguard Worker                invstd,
6971*da0073e9SAndroid Build Coastguard Worker                weight,
6972*da0073e9SAndroid Build Coastguard Worker                sum_dy,
6973*da0073e9SAndroid Build Coastguard Worker                sum_dy_xmu,
6974*da0073e9SAndroid Build Coastguard Worker                count_tensor
6975*da0073e9SAndroid Build Coastguard Worker            )
6976*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(gI_actual, gI_contiguous)
6977*da0073e9SAndroid Build Coastguard Worker
6978*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
6979*da0073e9SAndroid Build Coastguard Worker    def test_sync_batchnorm_accuracy_cuda(self):
6980*da0073e9SAndroid Build Coastguard Worker        # The goal of this test is to verify the functionality and accuracy of
6981*da0073e9SAndroid Build Coastguard Worker        #   the single-GPU CUDA kernels used in SyncBatchNorm.
6982*da0073e9SAndroid Build Coastguard Worker        # They are:
6983*da0073e9SAndroid Build Coastguard Worker        #   fwd: torch.batch_norm_stats, torch.batch_norm_gather_stats_with_counts, torch.batch_norm_elemt
6984*da0073e9SAndroid Build Coastguard Worker        #   bwd: torch.batch_norm_backward_reduce, torch.batch_norm_backward_elemt
6985*da0073e9SAndroid Build Coastguard Worker
6986*da0073e9SAndroid Build Coastguard Worker        def _batch_norm_stats(data, memory_format, mean_axes):
6987*da0073e9SAndroid Build Coastguard Worker            mean1, _ = torch.batch_norm_stats(data, 1e-5)
6988*da0073e9SAndroid Build Coastguard Worker            mean2, _ = torch.batch_norm_stats(data.to(memory_format=memory_format), 1e-5)
6989*da0073e9SAndroid Build Coastguard Worker            mean_ref = torch.mean(data, mean_axes, keepdim=False)
6990*da0073e9SAndroid Build Coastguard Worker
6991*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(mean_ref, mean1)
6992*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(mean_ref, mean2)
6993*da0073e9SAndroid Build Coastguard Worker
6994*da0073e9SAndroid Build Coastguard Worker        _batch_norm_stats(torch.randn(1, 96, 112, 112, dtype=torch.float, device='cuda'), torch.channels_last, (0, 2, 3))
6995*da0073e9SAndroid Build Coastguard Worker        _batch_norm_stats(torch.randn(1, 96, 112, 112, 112, dtype=torch.float, device='cuda'), torch.channels_last_3d, (0, 2, 3, 4))
6996*da0073e9SAndroid Build Coastguard Worker
6997*da0073e9SAndroid Build Coastguard Worker    def test_flatten(self):
6998*da0073e9SAndroid Build Coastguard Worker        tensor_input = torch.randn(2, 1, 2, 3)
6999*da0073e9SAndroid Build Coastguard Worker
7000*da0073e9SAndroid Build Coastguard Worker        # Flatten Tensor
7001*da0073e9SAndroid Build Coastguard Worker
7002*da0073e9SAndroid Build Coastguard Worker        flatten = nn.Flatten(start_dim=1, end_dim=-1)
7003*da0073e9SAndroid Build Coastguard Worker        tensor_output = flatten(tensor_input)
7004*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor_output.size(), torch.Size([2, 6]))
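        # For reference (illustrative, not executed): other dim ranges flatten
        # different spans of the same input, e.g.
        #   nn.Flatten(start_dim=0)(tensor_input).size()  # torch.Size([12])
        #   nn.Flatten(start_dim=2)(tensor_input).size()  # torch.Size([2, 1, 6])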
7005*da0073e9SAndroid Build Coastguard Worker
7006*da0073e9SAndroid Build Coastguard Worker    def test_unflatten(self):
7007*da0073e9SAndroid Build Coastguard Worker        tensor_input = torch.randn(2, 50)
7008*da0073e9SAndroid Build Coastguard Worker
7009*da0073e9SAndroid Build Coastguard Worker        # Unflatten Tensor (unflattened_size as a tuple of ints or a list of ints)
7010*da0073e9SAndroid Build Coastguard Worker
7011*da0073e9SAndroid Build Coastguard Worker        for us in ((2, 5, 5), [2, 5, 5]):
7012*da0073e9SAndroid Build Coastguard Worker            unflatten = nn.Unflatten(dim=1, unflattened_size=us)
7013*da0073e9SAndroid Build Coastguard Worker            tensor_output = unflatten(tensor_input)
7014*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tensor_output.size(), torch.Size([2, 2, 5, 5]))
7015*da0073e9SAndroid Build Coastguard Worker
7016*da0073e9SAndroid Build Coastguard Worker        # Unflatten NamedTensor
7017*da0073e9SAndroid Build Coastguard Worker
7018*da0073e9SAndroid Build Coastguard Worker        unflatten = nn.Unflatten(dim='features', unflattened_size=(('C', 2), ('H', 5), ('W', 5)))
7019*da0073e9SAndroid Build Coastguard Worker        named_tensor_input = tensor_input.refine_names('N', 'features')
7020*da0073e9SAndroid Build Coastguard Worker        named_tensor_output = unflatten(named_tensor_input)
7021*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(named_tensor_output.size(), torch.Size([2, 2, 5, 5]))
7022*da0073e9SAndroid Build Coastguard Worker
7023*da0073e9SAndroid Build Coastguard Worker    def test_unflatten_invalid_arg(self):
7024*da0073e9SAndroid Build Coastguard Worker        # Wrong type for unflattened_size (tuple of floats)
7025*da0073e9SAndroid Build Coastguard Worker
7026*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
7027*da0073e9SAndroid Build Coastguard Worker                TypeError,
7028*da0073e9SAndroid Build Coastguard Worker                r"unflattened_size must be tuple of ints, but found element of type float at pos 2"):
7029*da0073e9SAndroid Build Coastguard Worker            nn.Unflatten(dim=1, unflattened_size=(2, 5, 5.0))
7030*da0073e9SAndroid Build Coastguard Worker
7031*da0073e9SAndroid Build Coastguard Worker        # Wrong type for unflattened_size (list of lists and list of tuples)
7032*da0073e9SAndroid Build Coastguard Worker        for us in ([['C', 2], ['W', 5], ['H', 5]], [('C', 2), ('W', 5), ('H', 5)]):
7033*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
7034*da0073e9SAndroid Build Coastguard Worker                    TypeError,
7035*da0073e9SAndroid Build Coastguard Worker                    r"unflattened_size must be a tuple of tuples, but found type list"):
7036*da0073e9SAndroid Build Coastguard Worker                nn.Unflatten(dim='features', unflattened_size=us)
7037*da0073e9SAndroid Build Coastguard Worker
7038*da0073e9SAndroid Build Coastguard Worker        # Wrong type for unflattened_size (tuple of lists)
7039*da0073e9SAndroid Build Coastguard Worker
7040*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
7041*da0073e9SAndroid Build Coastguard Worker                TypeError,
7042*da0073e9SAndroid Build Coastguard Worker                r"unflattened_size must be tuple of tuples, but found element of type list at pos 0"):
7043*da0073e9SAndroid Build Coastguard Worker            nn.Unflatten(dim='features', unflattened_size=(['C', 2], ['W', 5], ['H', 5]))
7044*da0073e9SAndroid Build Coastguard Worker
7045*da0073e9SAndroid Build Coastguard Worker        # Wrong type for unflattened_size (tuple of dicts)
7046*da0073e9SAndroid Build Coastguard Worker
7047*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
7048*da0073e9SAndroid Build Coastguard Worker                TypeError,
7049*da0073e9SAndroid Build Coastguard Worker                r"unflattened_size must be tuple of tuples, but found element of type dict at pos 0"):
7050*da0073e9SAndroid Build Coastguard Worker            nn.Unflatten(dim='features', unflattened_size=({'C': 2}, {'W': 5}, {'H': 5}))
7051*da0073e9SAndroid Build Coastguard Worker
7052*da0073e9SAndroid Build Coastguard Worker    def test_layer_norm_grads_with_create_graph_flag(self):
7053*da0073e9SAndroid Build Coastguard Worker        atol = 1e-5
7054*da0073e9SAndroid Build Coastguard Worker        rtol = 1e-3
7055*da0073e9SAndroid Build Coastguard Worker
7056*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((4, 4, 16), requires_grad=True)
7057*da0073e9SAndroid Build Coastguard Worker        layer_norm = nn.LayerNorm((16,), 1e-5, True)
7058*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
7059*da0073e9SAndroid Build Coastguard Worker            layer_norm.weight = torch.nn.Parameter(0.1 * torch.ones_like(layer_norm.weight))
7060*da0073e9SAndroid Build Coastguard Worker
7061*da0073e9SAndroid Build Coastguard Worker        grads1 = torch.autograd.grad(layer_norm(x).sum(), x, create_graph=False)[0]
7062*da0073e9SAndroid Build Coastguard Worker        grads2 = torch.autograd.grad(layer_norm(x).sum(), x, create_graph=True)[0]
7063*da0073e9SAndroid Build Coastguard Worker
7064*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(grads1, grads2, rtol=rtol, atol=atol)
7065*da0073e9SAndroid Build Coastguard Worker
7066*da0073e9SAndroid Build Coastguard Worker        if TEST_CUDA:
7067*da0073e9SAndroid Build Coastguard Worker            x = x.to('cuda')
7068*da0073e9SAndroid Build Coastguard Worker            layer_norm = layer_norm.to('cuda')
7069*da0073e9SAndroid Build Coastguard Worker
7070*da0073e9SAndroid Build Coastguard Worker            grads1 = torch.autograd.grad(layer_norm(x).sum(), x, create_graph=False)[0]
7071*da0073e9SAndroid Build Coastguard Worker            grads2 = torch.autograd.grad(layer_norm(x).sum(), x, create_graph=True)[0]
7072*da0073e9SAndroid Build Coastguard Worker
7073*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(grads1, grads2, rtol=rtol, atol=atol)
7074*da0073e9SAndroid Build Coastguard Worker
7075*da0073e9SAndroid Build Coastguard Worker    def test_layer_norm_eps(self):
7076*da0073e9SAndroid Build Coastguard Worker        # test for https://github.com/pytorch/pytorch/issues/108072
7077*da0073e9SAndroid Build Coastguard Worker        x = torch.Tensor([[[2.0, 2.0], [14.0, 14.0]], [[2.0, 2.0], [14.0, 14.0]]])
7078*da0073e9SAndroid Build Coastguard Worker        ln = torch.nn.LayerNorm(2, eps=1e-6, elementwise_affine=False)
7079*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ln.forward(x), torch.zeros_like(x))
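        # Why all-zeros is expected: every normalized slice ([2.0, 2.0] or
        # [14.0, 14.0]) has zero variance, so (x - mean) / sqrt(var + eps)
        # is exactly 0 for every element regardless of eps.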
7080*da0073e9SAndroid Build Coastguard Worker
7081*da0073e9SAndroid Build Coastguard Worker    def test_padding_list(self):
7082*da0073e9SAndroid Build Coastguard Worker        # Padding can be a list or a tuple (regression test for gh-54452)
7083*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, 8, 32, 32)
7084*da0073e9SAndroid Build Coastguard Worker        net = torch.nn.ConvTranspose2d(8, 16, kernel_size=3, padding=[3, 3])
7085*da0073e9SAndroid Build Coastguard Worker        y = net(x)
7086*da0073e9SAndroid Build Coastguard Worker
7087*da0073e9SAndroid Build Coastguard Worker        net = torch.nn.ConvTranspose2d(8, 16, kernel_size=3, padding=(3, 3))
7088*da0073e9SAndroid Build Coastguard Worker        y = net(x)
7089*da0073e9SAndroid Build Coastguard Worker
7090*da0073e9SAndroid Build Coastguard Worker    def test_fractional_max_pool2d_invalid_output_ratio(self):
7091*da0073e9SAndroid Build Coastguard Worker        arg_1 = [2, 1]
7092*da0073e9SAndroid Build Coastguard Worker        arg_2 = [0.5, 0.5, 0.6]
7093*da0073e9SAndroid Build Coastguard Worker        arg_class = torch.nn.FractionalMaxPool2d(kernel_size=arg_1, output_ratio=arg_2,)
7094*da0073e9SAndroid Build Coastguard Worker        arg_3_0_tensor = torch.rand([20, 16, 50, 32], dtype=torch.float32)
7095*da0073e9SAndroid Build Coastguard Worker        arg_3_0 = arg_3_0_tensor.clone()
7096*da0073e9SAndroid Build Coastguard Worker        arg_3 = [arg_3_0,]
7097*da0073e9SAndroid Build Coastguard Worker
7098*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError,
7099*da0073e9SAndroid Build Coastguard Worker                                    "fractional_max_pool2d requires output_ratio to either be a single Int or tuple of Ints."):
7100*da0073e9SAndroid Build Coastguard Worker            res = arg_class(*arg_3)
7101*da0073e9SAndroid Build Coastguard Worker
7102*da0073e9SAndroid Build Coastguard Worker    def test_max_pool1d_invalid_output_size(self):
7103*da0073e9SAndroid Build Coastguard Worker        arg_1 = 3
7104*da0073e9SAndroid Build Coastguard Worker        arg_2 = 255
7105*da0073e9SAndroid Build Coastguard Worker        arg_3 = False
7106*da0073e9SAndroid Build Coastguard Worker        arg_class = torch.nn.MaxPool1d(kernel_size=arg_1, stride=arg_2, return_indices=arg_3)
7107*da0073e9SAndroid Build Coastguard Worker        arg_4_0 = torch.as_tensor([[0.3204]])
7108*da0073e9SAndroid Build Coastguard Worker        arg_4 = [arg_4_0,]
7109*da0073e9SAndroid Build Coastguard Worker
7110*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
7111*da0073e9SAndroid Build Coastguard Worker            res = arg_class(*arg_4)
7112*da0073e9SAndroid Build Coastguard Worker
7113*da0073e9SAndroid Build Coastguard Worker    def test_pickle_module_no_weights_only_warning(self):
7114*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
7115*da0073e9SAndroid Build Coastguard Worker            pickle.loads(pickle.dumps(torch.nn.Linear(10, 10)))
7116*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(w), 0)
7117*da0073e9SAndroid Build Coastguard Worker
7118*da0073e9SAndroid Build Coastguard Workerclass TestFusionEval(TestCase):
7119*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
7120*da0073e9SAndroid Build Coastguard Worker    @given(X=hu.tensor(shapes=((5, 3, 5, 5),), dtype=np.double),
7121*da0073e9SAndroid Build Coastguard Worker           running_mean=hu.tensor(shapes=(6,), dtype=np.double),
7122*da0073e9SAndroid Build Coastguard Worker           running_var=hu.tensor(shapes=(6,), dtype=np.double))
7123*da0073e9SAndroid Build Coastguard Worker    def test_fuse_module_eval_numerics(self, X, running_mean, running_var):
7124*da0073e9SAndroid Build Coastguard Worker        inputs, _ = X
7125*da0073e9SAndroid Build Coastguard Worker
7126*da0073e9SAndroid Build Coastguard Worker        iC, oC = inputs.shape[1], len(running_mean[0])
7127*da0073e9SAndroid Build Coastguard Worker        inputs = torch.from_numpy(inputs)
7128*da0073e9SAndroid Build Coastguard Worker        kernel_size = (3, 3)
7129*da0073e9SAndroid Build Coastguard Worker
7130*da0073e9SAndroid Build Coastguard Worker        conv_ref = torch.nn.Conv2d(iC, oC, bias=True, kernel_size=kernel_size)
7131*da0073e9SAndroid Build Coastguard Worker        bn_ref = torch.nn.BatchNorm2d(oC)
7132*da0073e9SAndroid Build Coastguard Worker        bn_ref.running_mean = torch.from_numpy(running_mean[0])
7133*da0073e9SAndroid Build Coastguard Worker        bn_ref.running_var = torch.from_numpy(running_var[0])
7134*da0073e9SAndroid Build Coastguard Worker
7135*da0073e9SAndroid Build Coastguard Worker        conv_ref.eval()
7136*da0073e9SAndroid Build Coastguard Worker        bn_ref.eval()
7137*da0073e9SAndroid Build Coastguard Worker
7138*da0073e9SAndroid Build Coastguard Worker        Y_ref = bn_ref(conv_ref(inputs))
7139*da0073e9SAndroid Build Coastguard Worker        conv_bn_fused = torch.nn.utils.fusion.fuse_conv_bn_eval(conv_ref,
7140*da0073e9SAndroid Build Coastguard Worker                                                                bn_ref)
7141*da0073e9SAndroid Build Coastguard Worker        Y_hat = conv_bn_fused(inputs)
7142*da0073e9SAndroid Build Coastguard Worker
7143*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Y_ref, Y_hat, msg="Conv+BN fusion results are off")
7144*da0073e9SAndroid Build Coastguard Worker
7145*da0073e9SAndroid Build Coastguard Worker        na_bn_ref = torch.nn.BatchNorm2d(oC, affine=False)
7146*da0073e9SAndroid Build Coastguard Worker        na_bn_ref.running_mean = torch.from_numpy(running_mean[0])
7147*da0073e9SAndroid Build Coastguard Worker        na_bn_ref.running_var = torch.from_numpy(running_var[0])
7148*da0073e9SAndroid Build Coastguard Worker        na_bn_ref.eval()
7149*da0073e9SAndroid Build Coastguard Worker
7150*da0073e9SAndroid Build Coastguard Worker        Y_ref = na_bn_ref(conv_ref(inputs))
7151*da0073e9SAndroid Build Coastguard Worker        conv_na_bn_fused = torch.nn.utils.fusion.fuse_conv_bn_eval(conv_ref,
7152*da0073e9SAndroid Build Coastguard Worker                                                                   na_bn_ref)
7153*da0073e9SAndroid Build Coastguard Worker        Y_hat = conv_na_bn_fused(inputs)
7154*da0073e9SAndroid Build Coastguard Worker
7155*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Y_ref, Y_hat, msg="Conv+BN(non-affine) fusion results are off")
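        # For reference, the fusion exercised above folds the BN statistics
        # into the conv parameters roughly as follows (a sketch that ignores
        # dtype and broadcasting details; non-affine BN uses weight=1, bias=0):
        #   scale   = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        #   w_fused = conv.weight * scale.reshape(-1, 1, 1, 1)
        #   b_fused = (conv.bias - bn.running_mean) * scale + bn.bias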
7156*da0073e9SAndroid Build Coastguard Worker
7157*da0073e9SAndroid Build Coastguard Worker
7158*da0073e9SAndroid Build Coastguard Workerclass TestConstantPadNd(TestCase):
7159*da0073e9SAndroid Build Coastguard Worker    def test_constant_pad_nd(self):
7160*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([[1, 2], [3, 4]])
7161*da0073e9SAndroid Build Coastguard Worker        res = torch.constant_pad_nd(a, [1, 2, 1, 0], 9)
7162*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor([
7163*da0073e9SAndroid Build Coastguard Worker            [9, 9, 9, 9, 9],
7164*da0073e9SAndroid Build Coastguard Worker            [9, 1, 2, 9, 9],
7165*da0073e9SAndroid Build Coastguard Worker            [9, 3, 4, 9, 9]
7166*da0073e9SAndroid Build Coastguard Worker        ])
7167*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, expected)
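        # The pad spec is read in pairs from the last dimension inward:
        # [1, 2, 1, 0] pads the last dim by (1 before, 2 after) and the
        # second-to-last dim by (1 before, 0 after), turning the 2x2 input
        # into the 3x5 result above.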
7168*da0073e9SAndroid Build Coastguard Worker
7169*da0073e9SAndroid Build Coastguard Worker    def test_preserves_memory_format(self):
7170*da0073e9SAndroid Build Coastguard Worker        nchw_tensor = torch.rand((1, 2, 5, 3))
7171*da0073e9SAndroid Build Coastguard Worker        nchw_padded = torch.constant_pad_nd(nchw_tensor, [1, 2], 0.5)
7172*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(nchw_padded.is_contiguous(memory_format=torch.contiguous_format))
7173*da0073e9SAndroid Build Coastguard Worker
7174*da0073e9SAndroid Build Coastguard Worker        nhwc_tensor = nchw_tensor.contiguous(memory_format=torch.channels_last)
7175*da0073e9SAndroid Build Coastguard Worker        nhwc_padded = torch.constant_pad_nd(nhwc_tensor, [1, 2], 0.5)
7176*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(nhwc_padded.is_contiguous(memory_format=torch.channels_last))
7177*da0073e9SAndroid Build Coastguard Worker
7178*da0073e9SAndroid Build Coastguard Worker
7179*da0073e9SAndroid Build Coastguard Workerclass TestAddRelu(TestCase):
7180*da0073e9SAndroid Build Coastguard Worker    def test_add_relu(self):
7181*da0073e9SAndroid Build Coastguard Worker        a = torch.rand((7, 11))
7182*da0073e9SAndroid Build Coastguard Worker        b = torch.rand((7, 11))
7183*da0073e9SAndroid Build Coastguard Worker        a = a.float()
7184*da0073e9SAndroid Build Coastguard Worker        b = b.float()
7185*da0073e9SAndroid Build Coastguard Worker        a = a * -10
7186*da0073e9SAndroid Build Coastguard Worker        a = a + 5
7187*da0073e9SAndroid Build Coastguard Worker        add_res = a + b
7188*da0073e9SAndroid Build Coastguard Worker        relu_res = torch.relu(add_res)
7189*da0073e9SAndroid Build Coastguard Worker        add_relu_res = torch._VF._add_relu(a, b)
7190*da0073e9SAndroid Build Coastguard Worker
7191*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(add_relu_res, relu_res)
7192*da0073e9SAndroid Build Coastguard Worker
7193*da0073e9SAndroid Build Coastguard Worker    def test_add_relu_broadcasting(self):
7194*da0073e9SAndroid Build Coastguard Worker        a = torch.rand((1, 32))
7195*da0073e9SAndroid Build Coastguard Worker        b = 1
7196*da0073e9SAndroid Build Coastguard Worker        b_scalar = torch.ones(1, 32)
7197*da0073e9SAndroid Build Coastguard Worker        res = torch._VF._add_relu(a, b)
7198*da0073e9SAndroid Build Coastguard Worker        broadcasted_res = torch._VF._add_relu(a, b_scalar)
7199*da0073e9SAndroid Build Coastguard Worker
7200*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(broadcasted_res, res)
7201*da0073e9SAndroid Build Coastguard Worker
7202*da0073e9SAndroid Build Coastguard Worker
7203*da0073e9SAndroid Build Coastguard Workerdef add_test(test, decorator=None):
7204*da0073e9SAndroid Build Coastguard Worker    def add(test_name, fn):
7205*da0073e9SAndroid Build Coastguard Worker        if hasattr(TestNN, test_name):
7206*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError('Found two tests with the same name: ' + test_name)
7207*da0073e9SAndroid Build Coastguard Worker        if decorator is not None:
7208*da0073e9SAndroid Build Coastguard Worker            fn = decorator(fn)
7209*da0073e9SAndroid Build Coastguard Worker        setattr(TestNN, test_name, fn)
7210*da0073e9SAndroid Build Coastguard Worker
7211*da0073e9SAndroid Build Coastguard Worker    test_name = test.get_name()
7212*da0073e9SAndroid Build Coastguard Worker    if not hasattr(test, 'test_cpu') or test.test_cpu:
7213*da0073e9SAndroid Build Coastguard Worker        add(test_name, lambda self, test=test: test(self))
7214*da0073e9SAndroid Build Coastguard Worker    cuda_test_name = test_name + '_cuda'
7215*da0073e9SAndroid Build Coastguard Worker    # With dtype enabled, it's good enough to test against three floating types
7216*da0073e9SAndroid Build Coastguard Worker    kwargs = {}
7217*da0073e9SAndroid Build Coastguard Worker    if 'extra_args' in get_function_arglist(test.test_cuda):
7218*da0073e9SAndroid Build Coastguard Worker        kwargs['extra_args'] = test.extra_args
7219*da0073e9SAndroid Build Coastguard Worker
7220*da0073e9SAndroid Build Coastguard Worker    if 'dtype' in get_function_arglist(test.test_cuda):
7221*da0073e9SAndroid Build Coastguard Worker        if tf32_is_not_fp32() and test.with_tf32:
7222*da0073e9SAndroid Build Coastguard Worker
7223*da0073e9SAndroid Build Coastguard Worker            def with_tf32_off(self, test=test, kwargs=kwargs):
7224*da0073e9SAndroid Build Coastguard Worker                with tf32_off():
7225*da0073e9SAndroid Build Coastguard Worker                    test.test_cuda(self, dtype=torch.float, **kwargs)
7226*da0073e9SAndroid Build Coastguard Worker
7227*da0073e9SAndroid Build Coastguard Worker            add(cuda_test_name + '_fp32', with_tf32_off)
7228*da0073e9SAndroid Build Coastguard Worker
7229*da0073e9SAndroid Build Coastguard Worker            def with_tf32_on(self, test=test, kwargs=kwargs):
7230*da0073e9SAndroid Build Coastguard Worker                with tf32_on(self, test.tf32_precision):
7231*da0073e9SAndroid Build Coastguard Worker                    test.test_cuda(self, dtype=torch.float, **kwargs)
7232*da0073e9SAndroid Build Coastguard Worker
7233*da0073e9SAndroid Build Coastguard Worker            add(cuda_test_name + '_tf32', with_tf32_on)
7234*da0073e9SAndroid Build Coastguard Worker        else:
7235*da0073e9SAndroid Build Coastguard Worker            add(cuda_test_name + '_float', lambda self,
7236*da0073e9SAndroid Build Coastguard Worker                test=test, kwargs=kwargs: test.test_cuda(self, dtype=torch.float, **kwargs))
7237*da0073e9SAndroid Build Coastguard Worker        add(cuda_test_name + '_double', lambda self,
7238*da0073e9SAndroid Build Coastguard Worker            test=test, kwargs=kwargs: test.test_cuda(self, dtype=torch.double, **kwargs))
7239*da0073e9SAndroid Build Coastguard Worker
7240*da0073e9SAndroid Build Coastguard Worker        def test_half(self, test=test, kwargs=kwargs):
7241*da0073e9SAndroid Build Coastguard Worker            test.test_cuda(self, dtype=torch.half, **kwargs)
7242*da0073e9SAndroid Build Coastguard Worker        if getattr(test, 'check_half', True):
7243*da0073e9SAndroid Build Coastguard Worker            add(cuda_test_name + '_half', test_half)
7244*da0073e9SAndroid Build Coastguard Worker
7245*da0073e9SAndroid Build Coastguard Worker        def test_bfloat16(self, test=test, kwargs=kwargs):
7246*da0073e9SAndroid Build Coastguard Worker            test.test_cuda(self, dtype=torch.bfloat16, **kwargs)
7247*da0073e9SAndroid Build Coastguard Worker        if getattr(test, 'check_bfloat16', True):
7248*da0073e9SAndroid Build Coastguard Worker            add(cuda_test_name + '_bfloat16', test_bfloat16)
7249*da0073e9SAndroid Build Coastguard Worker
7250*da0073e9SAndroid Build Coastguard Worker        def test_cfloat(self, test=test, kwargs=kwargs):
7251*da0073e9SAndroid Build Coastguard Worker            test.test_cuda(self, dtype=torch.cfloat, **kwargs)
7252*da0073e9SAndroid Build Coastguard Worker
7253*da0073e9SAndroid Build Coastguard Worker        def test_cdouble(self, test=test, kwargs=kwargs):
7254*da0073e9SAndroid Build Coastguard Worker            test.test_cuda(self, dtype=torch.cdouble, **kwargs)
7255*da0073e9SAndroid Build Coastguard Worker        if getattr(test, 'check_complex', False):
7256*da0073e9SAndroid Build Coastguard Worker            add(cuda_test_name + '_cfloat', test_cfloat)
7257*da0073e9SAndroid Build Coastguard Worker            add(cuda_test_name + '_cdouble', test_cdouble)
7258*da0073e9SAndroid Build Coastguard Worker
7259*da0073e9SAndroid Build Coastguard Worker    else:
7260*da0073e9SAndroid Build Coastguard Worker        def with_tf32_off(self, test=test, kwargs=kwargs):
7261*da0073e9SAndroid Build Coastguard Worker            with tf32_off():
7262*da0073e9SAndroid Build Coastguard Worker                test.test_cuda(self, **kwargs)
7263*da0073e9SAndroid Build Coastguard Worker
7264*da0073e9SAndroid Build Coastguard Worker        if tf32_is_not_fp32() and test.with_tf32:
7265*da0073e9SAndroid Build Coastguard Worker            add(cuda_test_name + '_fp32', with_tf32_off)
7266*da0073e9SAndroid Build Coastguard Worker
7267*da0073e9SAndroid Build Coastguard Worker            def with_tf32_on(self, test=test, kwargs=kwargs):
7268*da0073e9SAndroid Build Coastguard Worker                with tf32_on(self, test.tf32_precision):
7269*da0073e9SAndroid Build Coastguard Worker                    test.test_cuda(self, **kwargs)
7270*da0073e9SAndroid Build Coastguard Worker
7271*da0073e9SAndroid Build Coastguard Worker            add(cuda_test_name + '_tf32', with_tf32_on)
7272*da0073e9SAndroid Build Coastguard Worker        else:
7273*da0073e9SAndroid Build Coastguard Worker            add(cuda_test_name, with_tf32_off)
7274*da0073e9SAndroid Build Coastguard Worker
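# Roughly, for each module test entry the loop below registers a CPU variant plus
# dtype-specific CUDA variants on TestNN, e.g. test_Linear, test_Linear_cuda_float,
# test_Linear_cuda_double, test_Linear_cuda_half and test_Linear_cuda_bfloat16
# (the exact set depends on the flags carried by each test entry; the names here
# are illustrative).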
7275*da0073e9SAndroid Build Coastguard Workerfor test_params in module_tests + new_module_tests:
7276*da0073e9SAndroid Build Coastguard Worker    # TODO: CUDA is not implemented yet
7277*da0073e9SAndroid Build Coastguard Worker    if 'constructor' not in test_params:
7278*da0073e9SAndroid Build Coastguard Worker        name = test_params.pop('module_name')
7279*da0073e9SAndroid Build Coastguard Worker        test_params['constructor'] = getattr(nn, name)
7280*da0073e9SAndroid Build Coastguard Worker    decorator = test_params.pop('decorator', None)
7281*da0073e9SAndroid Build Coastguard Worker    test = NewModuleTest(**test_params)
7282*da0073e9SAndroid Build Coastguard Worker    add_test(test, decorator)
7283*da0073e9SAndroid Build Coastguard Worker    if 'check_eval' in test_params:
7284*da0073e9SAndroid Build Coastguard Worker        # create a new test that is identical but sets module.training to False
7285*da0073e9SAndroid Build Coastguard Worker        desc = test_params.get('desc', None)
7286*da0073e9SAndroid Build Coastguard Worker        test_params['desc'] = 'eval' if desc is None else desc + '_eval'
7287*da0073e9SAndroid Build Coastguard Worker
7288*da0073e9SAndroid Build Coastguard Worker        def gen_eval_constructor(constructor):
7289*da0073e9SAndroid Build Coastguard Worker            def eval_constructor(*args, **kwargs):
7290*da0073e9SAndroid Build Coastguard Worker                cons = constructor(*args, **kwargs)
7291*da0073e9SAndroid Build Coastguard Worker                cons.training = False
7292*da0073e9SAndroid Build Coastguard Worker                return cons
7293*da0073e9SAndroid Build Coastguard Worker            eval_constructor.__name__ = constructor.__name__
7294*da0073e9SAndroid Build Coastguard Worker            return eval_constructor
7295*da0073e9SAndroid Build Coastguard Worker
7296*da0073e9SAndroid Build Coastguard Worker        test_params['constructor'] = gen_eval_constructor(test_params['constructor'])
7297*da0073e9SAndroid Build Coastguard Worker        test = NewModuleTest(**test_params)
7298*da0073e9SAndroid Build Coastguard Worker        add_test(test, decorator)
7299*da0073e9SAndroid Build Coastguard Worker    if 'check_with_long_tensor' in test_params:
7300*da0073e9SAndroid Build Coastguard Worker        fullname = test_params.get('fullname', None)
7301*da0073e9SAndroid Build Coastguard Worker        if fullname:
7302*da0073e9SAndroid Build Coastguard Worker            test_params['fullname'] = fullname + '_with_long_tensor'
7303*da0073e9SAndroid Build Coastguard Worker        else:
7304*da0073e9SAndroid Build Coastguard Worker            desc = test_params.get('desc', None)
7305*da0073e9SAndroid Build Coastguard Worker            test_params['desc'] = 'with_long_tensor' if desc is None else desc + '_with_long_tensor'
7306*da0073e9SAndroid Build Coastguard Worker
7307*da0073e9SAndroid Build Coastguard Worker        def double_equivalent_of_long_tensor(size):
7308*da0073e9SAndroid Build Coastguard Worker            return torch.randint(-1000, 1000, size=size).double()
7309*da0073e9SAndroid Build Coastguard Worker
7310*da0073e9SAndroid Build Coastguard Worker        def apply_to_cons(t):
7311*da0073e9SAndroid Build Coastguard Worker            if t.is_floating_point():
7312*da0073e9SAndroid Build Coastguard Worker                if isinstance(t, Parameter):
7313*da0073e9SAndroid Build Coastguard Worker                    return Parameter(double_equivalent_of_long_tensor(t.size()))
7314*da0073e9SAndroid Build Coastguard Worker                elif isinstance(t, torch.Tensor):
7315*da0073e9SAndroid Build Coastguard Worker                    return double_equivalent_of_long_tensor(t.size())
7316*da0073e9SAndroid Build Coastguard Worker            else:
7317*da0073e9SAndroid Build Coastguard Worker                return t
7318*da0073e9SAndroid Build Coastguard Worker
7319*da0073e9SAndroid Build Coastguard Worker        def gen_long_tensor_constructor(constructor):
7320*da0073e9SAndroid Build Coastguard Worker            def long_tensor_constructor(*args, **kwargs):
7321*da0073e9SAndroid Build Coastguard Worker                cons = constructor(*args, **kwargs)
7322*da0073e9SAndroid Build Coastguard Worker                cons._apply(apply_to_cons)
7323*da0073e9SAndroid Build Coastguard Worker                return cons
7324*da0073e9SAndroid Build Coastguard Worker            long_tensor_constructor.__name__ = constructor.__name__
7325*da0073e9SAndroid Build Coastguard Worker            return long_tensor_constructor
7326*da0073e9SAndroid Build Coastguard Worker
7327*da0073e9SAndroid Build Coastguard Worker        def gen_long_tensor_input(input_size):
7328*da0073e9SAndroid Build Coastguard Worker            def input_func():
7329*da0073e9SAndroid Build Coastguard Worker                return double_equivalent_of_long_tensor(input_size)
7330*da0073e9SAndroid Build Coastguard Worker            return input_func
7331*da0073e9SAndroid Build Coastguard Worker
7332*da0073e9SAndroid Build Coastguard Worker        def reference_fn(i, p, m):
7333*da0073e9SAndroid Build Coastguard Worker            # For bad reasons this would create LongTensors that require gradients
7334*da0073e9SAndroid Build Coastguard Worker            # Remove requires_grad to avoid this
7335*da0073e9SAndroid Build Coastguard Worker            for p in m.parameters():
7336*da0073e9SAndroid Build Coastguard Worker                p.requires_grad_(False)
7337*da0073e9SAndroid Build Coastguard Worker            m._apply(lambda t: t.long())
7338*da0073e9SAndroid Build Coastguard Worker            input = i.long()
7339*da0073e9SAndroid Build Coastguard Worker            out = m.forward(input)
7340*da0073e9SAndroid Build Coastguard Worker            return out
7341*da0073e9SAndroid Build Coastguard Worker
7342*da0073e9SAndroid Build Coastguard Worker        test_params['constructor'] = gen_long_tensor_constructor(test_params['constructor'])
7343*da0073e9SAndroid Build Coastguard Worker        test_params['input_fn'] = gen_long_tensor_input(test_params['input_size'])
7344*da0073e9SAndroid Build Coastguard Worker        test_params['reference_fn'] = reference_fn
7345*da0073e9SAndroid Build Coastguard Worker        test_params['check_forward_only'] = True
7346*da0073e9SAndroid Build Coastguard Worker        # Currently we don't support conv2d/conv3d for LongTensor in CUDA
7347*da0073e9SAndroid Build Coastguard Worker        test_params['test_cuda'] = False
7348*da0073e9SAndroid Build Coastguard Worker        test = NewModuleTest(**test_params)
7349*da0073e9SAndroid Build Coastguard Worker
7350*da0073e9SAndroid Build Coastguard Worker        add_test(test, decorator)
7351*da0073e9SAndroid Build Coastguard Worker
7352*da0073e9SAndroid Build Coastguard Workerfor test_params in criterion_tests:
7353*da0073e9SAndroid Build Coastguard Worker    if 'constructor' not in test_params:
7354*da0073e9SAndroid Build Coastguard Worker        name = test_params.pop('module_name')
7355*da0073e9SAndroid Build Coastguard Worker        test_params['constructor'] = getattr(nn, name)
7356*da0073e9SAndroid Build Coastguard Worker    test = CriterionTest(**test_params)
7357*da0073e9SAndroid Build Coastguard Worker    decorator = test_params.pop('decorator', None)
7358*da0073e9SAndroid Build Coastguard Worker    add_test(test, decorator)
7359*da0073e9SAndroid Build Coastguard Worker    if 'check_sum_reduction' in test_params:
7360*da0073e9SAndroid Build Coastguard Worker        desc = test_params.get('desc', None)
7361*da0073e9SAndroid Build Coastguard Worker        test_params['desc'] = 'sum_reduction' if desc is None else desc + '_sum_reduction'
7362*da0073e9SAndroid Build Coastguard Worker
7363*da0073e9SAndroid Build Coastguard Worker        def gen_sum_reduction_constructor(constructor):
7364*da0073e9SAndroid Build Coastguard Worker            def sum_reduction_constructor(*args, **kwargs):
7365*da0073e9SAndroid Build Coastguard Worker                cons = constructor(*args, reduction='sum', **kwargs)
7366*da0073e9SAndroid Build Coastguard Worker                return cons
7367*da0073e9SAndroid Build Coastguard Worker            sum_reduction_constructor.__name__ = constructor.__name__
7368*da0073e9SAndroid Build Coastguard Worker            return sum_reduction_constructor
7369*da0073e9SAndroid Build Coastguard Worker
7370*da0073e9SAndroid Build Coastguard Worker        test_params['constructor'] = gen_sum_reduction_constructor(test_params['constructor'])
7371*da0073e9SAndroid Build Coastguard Worker        test = CriterionTest(**test_params)
7372*da0073e9SAndroid Build Coastguard Worker        add_test(test, decorator)
7373*da0073e9SAndroid Build Coastguard Worker
7374*da0073e9SAndroid Build Coastguard Worker
7375*da0073e9SAndroid Build Coastguard Workerclass UnpoolingNet(nn.Module):
7376*da0073e9SAndroid Build Coastguard Worker    def __init__(self, pool, unpool):
7377*da0073e9SAndroid Build Coastguard Worker        super().__init__()
7378*da0073e9SAndroid Build Coastguard Worker        self.pool = pool
7379*da0073e9SAndroid Build Coastguard Worker        self.unpool = unpool
7380*da0073e9SAndroid Build Coastguard Worker
7381*da0073e9SAndroid Build Coastguard Worker    def forward(self, input):
7382*da0073e9SAndroid Build Coastguard Worker        return self.unpool(*self.pool(input))
7383*da0073e9SAndroid Build Coastguard Worker
7384*da0073e9SAndroid Build Coastguard Worker
7385*da0073e9SAndroid Build Coastguard Workeradd_test(NewModuleTest(
7386*da0073e9SAndroid Build Coastguard Worker    constructor=lambda: UnpoolingNet(
7387*da0073e9SAndroid Build Coastguard Worker        nn.MaxPool1d(2, return_indices=True),
7388*da0073e9SAndroid Build Coastguard Worker        nn.MaxUnpool1d(2)),
7389*da0073e9SAndroid Build Coastguard Worker    input_size=(1, 1, 4),
7390*da0073e9SAndroid Build Coastguard Worker    fullname='MaxUnpool1d_net',
7391*da0073e9SAndroid Build Coastguard Worker    default_dtype=torch.double,))
7392*da0073e9SAndroid Build Coastguard Workeradd_test(NewModuleTest(
7393*da0073e9SAndroid Build Coastguard Worker    constructor=lambda: UnpoolingNet(
7394*da0073e9SAndroid Build Coastguard Worker        nn.MaxPool2d(2, return_indices=True),
7395*da0073e9SAndroid Build Coastguard Worker        nn.MaxUnpool2d(2)),
7396*da0073e9SAndroid Build Coastguard Worker    input_size=(1, 1, 2, 4),
7397*da0073e9SAndroid Build Coastguard Worker    fullname='MaxUnpool2d_net',
7398*da0073e9SAndroid Build Coastguard Worker    default_dtype=torch.double,))
7399*da0073e9SAndroid Build Coastguard Workeradd_test(NewModuleTest(
7400*da0073e9SAndroid Build Coastguard Worker    constructor=lambda: UnpoolingNet(
7401*da0073e9SAndroid Build Coastguard Worker        nn.MaxPool3d(2, return_indices=True),
7402*da0073e9SAndroid Build Coastguard Worker        nn.MaxUnpool3d(2)),
7403*da0073e9SAndroid Build Coastguard Worker    input_size=(1, 1, 2, 4, 6),
7404*da0073e9SAndroid Build Coastguard Worker    fullname='MaxUnpool3d_net',
7405*da0073e9SAndroid Build Coastguard Worker    check_gradgrad=False,
7406*da0073e9SAndroid Build Coastguard Worker    default_dtype=torch.double,))
7407*da0073e9SAndroid Build Coastguard Worker
7408*da0073e9SAndroid Build Coastguard Workeradd_test(NewModuleTest(
7409*da0073e9SAndroid Build Coastguard Worker    constructor=lambda: UnpoolingNet(
7410*da0073e9SAndroid Build Coastguard Worker        nn.MaxPool1d(2, return_indices=True),
7411*da0073e9SAndroid Build Coastguard Worker        nn.MaxUnpool1d(2)),
7412*da0073e9SAndroid Build Coastguard Worker    input_size=(1, 4),
7413*da0073e9SAndroid Build Coastguard Worker    reference_fn=single_batch_reference_fn,
7414*da0073e9SAndroid Build Coastguard Worker    fullname='MaxUnpool1d_net_no_batch_dim',
7415*da0073e9SAndroid Build Coastguard Worker    default_dtype=torch.double,))
7416*da0073e9SAndroid Build Coastguard Workeradd_test(NewModuleTest(
7417*da0073e9SAndroid Build Coastguard Worker    constructor=lambda: UnpoolingNet(
7418*da0073e9SAndroid Build Coastguard Worker        nn.MaxPool2d(2, return_indices=True),
7419*da0073e9SAndroid Build Coastguard Worker        nn.MaxUnpool2d(2)),
7420*da0073e9SAndroid Build Coastguard Worker    input_size=(1, 2, 4),
7421*da0073e9SAndroid Build Coastguard Worker    reference_fn=single_batch_reference_fn,
7422*da0073e9SAndroid Build Coastguard Worker    fullname='MaxUnpool2d_net_no_batch_dim',
7423*da0073e9SAndroid Build Coastguard Worker    default_dtype=torch.double,))
7424*da0073e9SAndroid Build Coastguard Worker
7425*da0073e9SAndroid Build Coastguard Workeradd_test(NewModuleTest(
7426*da0073e9SAndroid Build Coastguard Worker    constructor=lambda: UnpoolingNet(
7427*da0073e9SAndroid Build Coastguard Worker        nn.MaxPool3d(2, return_indices=True),
7428*da0073e9SAndroid Build Coastguard Worker        nn.MaxUnpool3d(2)),
7429*da0073e9SAndroid Build Coastguard Worker    input_size=(1, 2, 4, 6),
7430*da0073e9SAndroid Build Coastguard Worker    reference_fn=single_batch_reference_fn,
7431*da0073e9SAndroid Build Coastguard Worker    fullname='MaxUnpool3d_net_no_batch_dim',
7432*da0073e9SAndroid Build Coastguard Worker    check_gradgrad=False,
7433*da0073e9SAndroid Build Coastguard Worker    default_dtype=torch.double,))
7434*da0073e9SAndroid Build Coastguard Worker
7435*da0073e9SAndroid Build Coastguard Workerclass _AdaptiveLogSoftmaxWithLoss(nn.AdaptiveLogSoftmaxWithLoss):
7436*da0073e9SAndroid Build Coastguard Worker    def __call__(self, input):
7437*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor([0, 1, 4, 8]).to(input.device)
7438*da0073e9SAndroid Build Coastguard Worker        return nn.AdaptiveLogSoftmaxWithLoss.__call__(self, input, t).output
7439*da0073e9SAndroid Build Coastguard Worker
7440*da0073e9SAndroid Build Coastguard Workeradd_test(NewModuleTest(
7441*da0073e9SAndroid Build Coastguard Worker    constructor=lambda: _AdaptiveLogSoftmaxWithLoss(16, 10, [2, 6]),
7442*da0073e9SAndroid Build Coastguard Worker    input_size=(4, 16),
7443*da0073e9SAndroid Build Coastguard Worker    fullname='AdaptiveLogSoftmax',
7444*da0073e9SAndroid Build Coastguard Worker    with_tf32=True,
7445*da0073e9SAndroid Build Coastguard Worker    tf32_precision=0.005,
7446*da0073e9SAndroid Build Coastguard Worker    default_dtype=torch.double))
7447*da0073e9SAndroid Build Coastguard Worker
7448*da0073e9SAndroid Build Coastguard Worker
7449*da0073e9SAndroid Build Coastguard Worker# The following are helpers for TestNN.test_affine_*
7450*da0073e9SAndroid Build Coastguard Workerif torch.cuda.is_available():
7451*da0073e9SAndroid Build Coastguard Worker    def device_():
7452*da0073e9SAndroid Build Coastguard Worker        return ['cpu', 'cuda']
7453*da0073e9SAndroid Build Coastguard Workerelse:
7454*da0073e9SAndroid Build Coastguard Worker    def device_():
7455*da0073e9SAndroid Build Coastguard Worker        return ['cpu']
7456*da0073e9SAndroid Build Coastguard Worker
7457*da0073e9SAndroid Build Coastguard Worker
7458*da0073e9SAndroid Build Coastguard Workerdef angle_rad_():
7459*da0073e9SAndroid Build Coastguard Worker    return [r * math.pi * 2 for r in [0.0, 0.5, 0.25, 0.125, random.random()]]
7460*da0073e9SAndroid Build Coastguard Worker
7461*da0073e9SAndroid Build Coastguard Worker
7462*da0073e9SAndroid Build Coastguard Workerdef axis_vector_():
7463*da0073e9SAndroid Build Coastguard Worker    t = (random.random(), random.random(), random.random())
7464*da0073e9SAndroid Build Coastguard Worker    l = sum(x ** 2 for x in t) ** 0.5
7465*da0073e9SAndroid Build Coastguard Worker
7466*da0073e9SAndroid Build Coastguard Worker    return [(1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0), tuple(x / l for x in t)]
7467*da0073e9SAndroid Build Coastguard Worker
7468*da0073e9SAndroid Build Coastguard Worker
7469*da0073e9SAndroid Build Coastguard Workerdef input_size2d_():
7470*da0073e9SAndroid Build Coastguard Worker    return [[1, 1, 3, 5], [1, 1, 3, 3], [1, 1, 4, 4], [1, 1, 3, 4]]
7471*da0073e9SAndroid Build Coastguard Worker
7472*da0073e9SAndroid Build Coastguard Worker
7473*da0073e9SAndroid Build Coastguard Workerdef output_size2d_():
7474*da0073e9SAndroid Build Coastguard Worker    return [[1, 1, 5, 3], [1, 1, 3, 5], [1, 1, 4, 3], [1, 1, 5, 5], [1, 1, 6, 6]]
7475*da0073e9SAndroid Build Coastguard Worker
7476*da0073e9SAndroid Build Coastguard Worker
7477*da0073e9SAndroid Build Coastguard Workerdef input_size2dsq_():
7478*da0073e9SAndroid Build Coastguard Worker    return [[1, 1, 2, 2], [1, 1, 3, 3], [1, 1, 4, 4], [1, 1, 6, 6]]
7479*da0073e9SAndroid Build Coastguard Worker
7480*da0073e9SAndroid Build Coastguard Worker
7481*da0073e9SAndroid Build Coastguard Workerdef output_size2dsq_():
7482*da0073e9SAndroid Build Coastguard Worker    return [[1, 1, 2, 2], [1, 1, 3, 3], [1, 1, 4, 4], [1, 1, 5, 5], [1, 1, 6, 6]]
7483*da0073e9SAndroid Build Coastguard Worker
7484*da0073e9SAndroid Build Coastguard Worker
7485*da0073e9SAndroid Build Coastguard Workerdef input_size3d_():
7486*da0073e9SAndroid Build Coastguard Worker    return [[1, 1, 2, 2, 2], [1, 1, 2, 3, 4], [1, 1, 3, 3, 3], [1, 1, 4, 4, 4], [1, 1, 3, 4, 5]]
7487*da0073e9SAndroid Build Coastguard Worker
7488*da0073e9SAndroid Build Coastguard Worker
7489*da0073e9SAndroid Build Coastguard Workerdef input_size3dsq_():
7490*da0073e9SAndroid Build Coastguard Worker    return [[1, 1, 2, 2, 2], [1, 1, 3, 3, 3], [1, 1, 4, 4, 4], [1, 1, 6, 6, 6]]
7491*da0073e9SAndroid Build Coastguard Worker
7492*da0073e9SAndroid Build Coastguard Worker
7493*da0073e9SAndroid Build Coastguard Workerdef output_size3dsq_():
7494*da0073e9SAndroid Build Coastguard Worker    return [[1, 1, 2, 2, 2], [1, 1, 3, 3, 3], [1, 1, 4, 4, 4], [1, 1, 5, 5, 5], [1, 1, 6, 6, 6]]
7495*da0073e9SAndroid Build Coastguard Worker
7496*da0073e9SAndroid Build Coastguard Worker
7497*da0073e9SAndroid Build Coastguard Workerdef output_size3d_():
7498*da0073e9SAndroid Build Coastguard Worker    return [[1, 1, 2, 2, 2], [1, 1, 3, 3, 3], [1, 1, 3, 4, 5], [1, 1, 4, 3, 2], [1, 1, 5, 5, 5], [1, 1, 6, 6, 6]]
7499*da0073e9SAndroid Build Coastguard Worker
7500*da0073e9SAndroid Build Coastguard Worker
7501*da0073e9SAndroid Build Coastguard Workerdef _buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad):
7502*da0073e9SAndroid Build Coastguard Worker    input_center = [(x - 1) / 2.0 for x in input_size]
7503*da0073e9SAndroid Build Coastguard Worker    output_center = [(x - 1) / 2.0 for x in output_size]
7504*da0073e9SAndroid Build Coastguard Worker
7505*da0073e9SAndroid Build Coastguard Worker    s = math.sin(angle_rad)
7506*da0073e9SAndroid Build Coastguard Worker    c = math.cos(angle_rad)
7507*da0073e9SAndroid Build Coastguard Worker
7508*da0073e9SAndroid Build Coastguard Worker    intrans_ary = np.array([
7509*da0073e9SAndroid Build Coastguard Worker        [1, 0, input_center[2]],
7510*da0073e9SAndroid Build Coastguard Worker        [0, 1, input_center[3]],
7511*da0073e9SAndroid Build Coastguard Worker        [0, 0, 1],
7512*da0073e9SAndroid Build Coastguard Worker    ], dtype=np.float64)
7513*da0073e9SAndroid Build Coastguard Worker
7514*da0073e9SAndroid Build Coastguard Worker    inscale_ary = np.array([
7515*da0073e9SAndroid Build Coastguard Worker        [input_center[2], 0, 0],
7516*da0073e9SAndroid Build Coastguard Worker        [0, input_center[3], 0],
7517*da0073e9SAndroid Build Coastguard Worker        [0, 0, 1],
7518*da0073e9SAndroid Build Coastguard Worker    ], dtype=np.float64)
7519*da0073e9SAndroid Build Coastguard Worker
7520*da0073e9SAndroid Build Coastguard Worker    rotation_ary = np.array([
7521*da0073e9SAndroid Build Coastguard Worker        [c, -s, 0],
7522*da0073e9SAndroid Build Coastguard Worker        [s, c, 0],
7523*da0073e9SAndroid Build Coastguard Worker        [0, 0, 1],
7524*da0073e9SAndroid Build Coastguard Worker    ], dtype=np.float64)
7525*da0073e9SAndroid Build Coastguard Worker
7526*da0073e9SAndroid Build Coastguard Worker    outscale_ary = np.array([
7527*da0073e9SAndroid Build Coastguard Worker        [1.0 / output_center[2], 0, 0],
7528*da0073e9SAndroid Build Coastguard Worker        [0, 1.0 / output_center[3], 0],
7529*da0073e9SAndroid Build Coastguard Worker        [0, 0, 1],
7530*da0073e9SAndroid Build Coastguard Worker    ], dtype=np.float64)
7531*da0073e9SAndroid Build Coastguard Worker
7532*da0073e9SAndroid Build Coastguard Worker    outtrans_ary = np.array([
7533*da0073e9SAndroid Build Coastguard Worker        [1, 0, -output_center[2]],
7534*da0073e9SAndroid Build Coastguard Worker        [0, 1, -output_center[3]],
7535*da0073e9SAndroid Build Coastguard Worker        [0, 0, 1],
7536*da0073e9SAndroid Build Coastguard Worker    ], dtype=np.float64)
7537*da0073e9SAndroid Build Coastguard Worker
7538*da0073e9SAndroid Build Coastguard Worker    reorder_ary = np.array([
7539*da0073e9SAndroid Build Coastguard Worker        [0, 1, 0],
7540*da0073e9SAndroid Build Coastguard Worker        [1, 0, 0],
7541*da0073e9SAndroid Build Coastguard Worker        [0, 0, 1],
7542*da0073e9SAndroid Build Coastguard Worker    ], dtype=np.float64)
7543*da0073e9SAndroid Build Coastguard Worker
7544*da0073e9SAndroid Build Coastguard Worker    transform_ary = np.dot(np.dot(np.dot(np.dot(
7545*da0073e9SAndroid Build Coastguard Worker        intrans_ary,
7546*da0073e9SAndroid Build Coastguard Worker        inscale_ary),
7547*da0073e9SAndroid Build Coastguard Worker        rotation_ary.T),
7548*da0073e9SAndroid Build Coastguard Worker        outscale_ary),
7549*da0073e9SAndroid Build Coastguard Worker        outtrans_ary)
7550*da0073e9SAndroid Build Coastguard Worker    grid_ary = np.dot(np.dot(np.dot(reorder_ary, rotation_ary.T), outscale_ary), outtrans_ary)
7551*da0073e9SAndroid Build Coastguard Worker
7552*da0073e9SAndroid Build Coastguard Worker    transform_tensor = torch.from_numpy(rotation_ary).to(device, torch.float32)
7553*da0073e9SAndroid Build Coastguard Worker    transform_tensor = transform_tensor[:2].unsqueeze(0)
7554*da0073e9SAndroid Build Coastguard Worker
7555*da0073e9SAndroid Build Coastguard Worker    return transform_tensor, transform_ary, grid_ary
7556*da0073e9SAndroid Build Coastguard Worker
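# Illustrative sketch (not part of the upstream test file; the helper name and the
# 3x3 -> 5x5 sizes are assumptions for this example): with angle_rad == 0 the
# composed transform above is a pure rescale + recenter, so the corners of the
# output grid should map exactly onto the corners of the input grid.
def _sketch_affine2d_identity_check():
    _, transform_ary, _ = _buildEquivalentAffineTransforms2d(
        'cpu', [1, 1, 3, 3], [1, 1, 5, 5], 0.)
    # output corner (0, 0) -> input corner (0, 0); output corner (4, 4) -> input corner (2, 2)
    assert np.allclose(np.dot(transform_ary, [0, 0, 1])[:2], [0, 0])
    assert np.allclose(np.dot(transform_ary, [4, 4, 1])[:2], [2, 2])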
7557*da0073e9SAndroid Build Coastguard Worker
7558*da0073e9SAndroid Build Coastguard Workerdef _buildEquivalentAffineTransforms3d(device, input_size, output_size, angle_rad, axis_vector):
7559*da0073e9SAndroid Build Coastguard Worker    input_center = [(x - 1) / 2.0 for x in input_size]
7560*da0073e9SAndroid Build Coastguard Worker    output_center = [(x - 1) / 2.0 for x in output_size]
7561*da0073e9SAndroid Build Coastguard Worker
7562*da0073e9SAndroid Build Coastguard Worker    s = math.sin(angle_rad)
7563*da0073e9SAndroid Build Coastguard Worker    c = math.cos(angle_rad)
7564*da0073e9SAndroid Build Coastguard Worker    c1 = 1 - c
7565*da0073e9SAndroid Build Coastguard Worker
7566*da0073e9SAndroid Build Coastguard Worker    intrans_ary = np.array([
7567*da0073e9SAndroid Build Coastguard Worker        [1, 0, 0, input_center[2]],
7568*da0073e9SAndroid Build Coastguard Worker        [0, 1, 0, input_center[3]],
7569*da0073e9SAndroid Build Coastguard Worker        [0, 0, 1, input_center[4]],
7570*da0073e9SAndroid Build Coastguard Worker        [0, 0, 0, 1],
7571*da0073e9SAndroid Build Coastguard Worker    ], dtype=np.float64)
7572*da0073e9SAndroid Build Coastguard Worker
7573*da0073e9SAndroid Build Coastguard Worker    inscale_ary = np.array([
7574*da0073e9SAndroid Build Coastguard Worker        [input_center[2], 0, 0, 0],
7575*da0073e9SAndroid Build Coastguard Worker        [0, input_center[3], 0, 0],
7576*da0073e9SAndroid Build Coastguard Worker        [0, 0, input_center[4], 0],
7577*da0073e9SAndroid Build Coastguard Worker        [0, 0, 0, 1],
7578*da0073e9SAndroid Build Coastguard Worker    ], dtype=np.float64)
7579*da0073e9SAndroid Build Coastguard Worker
7580*da0073e9SAndroid Build Coastguard Worker    l, m, n = axis_vector
7581*da0073e9SAndroid Build Coastguard Worker    scipyRotation_ary = np.array([
7582*da0073e9SAndroid Build Coastguard Worker        [l * l * c1 + c, m * l * c1 - n * s, n * l * c1 + m * s, 0],
7583*da0073e9SAndroid Build Coastguard Worker        [l * m * c1 + n * s, m * m * c1 + c, n * m * c1 - l * s, 0],
7584*da0073e9SAndroid Build Coastguard Worker        [l * n * c1 - m * s, m * n * c1 + l * s, n * n * c1 + c, 0],
7585*da0073e9SAndroid Build Coastguard Worker        [0, 0, 0, 1],
7586*da0073e9SAndroid Build Coastguard Worker    ], dtype=np.float64)
7587*da0073e9SAndroid Build Coastguard Worker
7588*da0073e9SAndroid Build Coastguard Worker    z, y, x = axis_vector
7589*da0073e9SAndroid Build Coastguard Worker    torchRotation_ary = np.array([
7590*da0073e9SAndroid Build Coastguard Worker        [x * x * c1 + c, y * x * c1 - z * s, z * x * c1 + y * s, 0],
7591*da0073e9SAndroid Build Coastguard Worker        [x * y * c1 + z * s, y * y * c1 + c, z * y * c1 - x * s, 0],
7592*da0073e9SAndroid Build Coastguard Worker        [x * z * c1 - y * s, y * z * c1 + x * s, z * z * c1 + c, 0],
7593*da0073e9SAndroid Build Coastguard Worker        [0, 0, 0, 1],
7594*da0073e9SAndroid Build Coastguard Worker    ], dtype=np.float64)
7595*da0073e9SAndroid Build Coastguard Worker
7596*da0073e9SAndroid Build Coastguard Worker    outscale_ary = np.array([
7597*da0073e9SAndroid Build Coastguard Worker        [1.0 / output_center[2], 0, 0, 0],
7598*da0073e9SAndroid Build Coastguard Worker        [0, 1.0 / output_center[3], 0, 0],
7599*da0073e9SAndroid Build Coastguard Worker        [0, 0, 1.0 / output_center[4], 0],
7600*da0073e9SAndroid Build Coastguard Worker        [0, 0, 0, 1],
7601*da0073e9SAndroid Build Coastguard Worker    ], dtype=np.float64)
7602*da0073e9SAndroid Build Coastguard Worker
7603*da0073e9SAndroid Build Coastguard Worker    outtrans_ary = np.array([
7604*da0073e9SAndroid Build Coastguard Worker        [1, 0, 0, -output_center[2]],
7605*da0073e9SAndroid Build Coastguard Worker        [0, 1, 0, -output_center[3]],
7606*da0073e9SAndroid Build Coastguard Worker        [0, 0, 1, -output_center[4]],
7607*da0073e9SAndroid Build Coastguard Worker        [0, 0, 0, 1],
7608*da0073e9SAndroid Build Coastguard Worker    ], dtype=np.float64)
7609*da0073e9SAndroid Build Coastguard Worker
7610*da0073e9SAndroid Build Coastguard Worker    reorder_ary = np.array([
7611*da0073e9SAndroid Build Coastguard Worker        [0, 0, 1, 0],
7612*da0073e9SAndroid Build Coastguard Worker        [0, 1, 0, 0],
7613*da0073e9SAndroid Build Coastguard Worker        [1, 0, 0, 0],
7614*da0073e9SAndroid Build Coastguard Worker        [0, 0, 0, 1],
7615*da0073e9SAndroid Build Coastguard Worker    ], dtype=np.float64)
7616*da0073e9SAndroid Build Coastguard Worker
7617*da0073e9SAndroid Build Coastguard Worker    transform_ary = np.dot(np.dot(np.dot(np.dot(
7618*da0073e9SAndroid Build Coastguard Worker        intrans_ary,
7619*da0073e9SAndroid Build Coastguard Worker        inscale_ary),
7620*da0073e9SAndroid Build Coastguard Worker        np.linalg.inv(scipyRotation_ary)),
7621*da0073e9SAndroid Build Coastguard Worker        outscale_ary),
7622*da0073e9SAndroid Build Coastguard Worker        outtrans_ary)
7623*da0073e9SAndroid Build Coastguard Worker    grid_ary = np.dot(np.dot(np.dot(reorder_ary, np.linalg.inv(scipyRotation_ary)), outscale_ary), outtrans_ary)
7624*da0073e9SAndroid Build Coastguard Worker
7625*da0073e9SAndroid Build Coastguard Worker    transform_tensor = torch.from_numpy(torchRotation_ary).to(device, torch.float32)
7626*da0073e9SAndroid Build Coastguard Worker    transform_tensor = transform_tensor[:3].unsqueeze(0)
7627*da0073e9SAndroid Build Coastguard Worker
7628*da0073e9SAndroid Build Coastguard Worker    return transform_tensor, transform_ary, grid_ary
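
# Illustrative sketch (an assumption for this write-up, not invoked by any test):
# the 4x4 rotation blocks above are Rodrigues' rotation formula in homogeneous
# coordinates; for the unit axis (0, 0, 1) they reduce to a plain 2D rotation in
# the leading 2x2 block, which makes the helper easy to sanity-check by hand.
def _sketch_rodrigues_reduces_to_2d(theta=0.3):
    s, c = math.sin(theta), math.cos(theta)
    l, m, n = 0.0, 0.0, 1.0
    c1 = 1 - c
    rot = np.array([
        [l * l * c1 + c, m * l * c1 - n * s, n * l * c1 + m * s],
        [l * m * c1 + n * s, m * m * c1 + c, n * m * c1 - l * s],
        [l * n * c1 - m * s, m * n * c1 + l * s, n * n * c1 + c],
    ])
    assert np.allclose(rot, np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]]))
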
7629*da0073e9SAndroid Build Coastguard Worker# end TestNN.test_affine_* helpers
7630*da0073e9SAndroid Build Coastguard Worker
7631*da0073e9SAndroid Build Coastguard Worker
7632*da0073e9SAndroid Build Coastguard Workerclass TestNNDeviceType(NNTestCase):
7633*da0073e9SAndroid Build Coastguard Worker    def _test_InstanceNorm_general(self, cls, input, device, dtype=torch.float):
7634*da0073e9SAndroid Build Coastguard Worker        # default case track_running_stats=False
7635*da0073e9SAndroid Build Coastguard Worker        b, c = input.size(0), input.size(1)
7636*da0073e9SAndroid Build Coastguard Worker        input_var = input.to(device=device, dtype=dtype).requires_grad_()
7637*da0073e9SAndroid Build Coastguard Worker
7638*da0073e9SAndroid Build Coastguard Worker        IN = cls(c, eps=0).to(device, dtype)
7639*da0073e9SAndroid Build Coastguard Worker
7640*da0073e9SAndroid Build Coastguard Worker        output = IN(input_var)
7641*da0073e9SAndroid Build Coastguard Worker        out_reshaped = output.view(b * c, -1)
7642*da0073e9SAndroid Build Coastguard Worker
7643*da0073e9SAndroid Build Coastguard Worker        mean = out_reshaped.mean(1)
7644*da0073e9SAndroid Build Coastguard Worker        var = out_reshaped.var(1, unbiased=False)
7645*da0073e9SAndroid Build Coastguard Worker
7646*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.abs(mean.data).mean(), 0, atol=1e-5, rtol=0)
7647*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.abs(var.data).mean(), 1, atol=1e-5, rtol=0)
7648*da0073e9SAndroid Build Coastguard Worker
7649*da0073e9SAndroid Build Coastguard Worker        # check that eval mode doesn't change behavior
7650*da0073e9SAndroid Build Coastguard Worker        grad_out = torch.randn_like(output)
7651*da0073e9SAndroid Build Coastguard Worker        res1 = output.data.clone()
7652*da0073e9SAndroid Build Coastguard Worker        output.backward(grad_out)
7653*da0073e9SAndroid Build Coastguard Worker        grad1 = input_var.grad.data.clone()
7654*da0073e9SAndroid Build Coastguard Worker
7655*da0073e9SAndroid Build Coastguard Worker        IN.eval()
7656*da0073e9SAndroid Build Coastguard Worker        output = IN(input_var)
7657*da0073e9SAndroid Build Coastguard Worker        input_var.grad = None
7658*da0073e9SAndroid Build Coastguard Worker        output.backward(grad_out)
7659*da0073e9SAndroid Build Coastguard Worker        res2 = output.data
7660*da0073e9SAndroid Build Coastguard Worker        grad2 = input_var.grad.data
7661*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res2)
7662*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(grad1, grad2)
7663*da0073e9SAndroid Build Coastguard Worker
7664*da0073e9SAndroid Build Coastguard Worker        # If track_running_stats=True and momentum=1, running_mean/var should equal
7665*da0073e9SAndroid Build Coastguard Worker        # the mean/var of the input (with unbiased correction); see the sketch below.
7666*da0073e9SAndroid Build Coastguard Worker        IN = cls(c, momentum=1, eps=0, track_running_stats=True).to(device, dtype)
7667*da0073e9SAndroid Build Coastguard Worker
7668*da0073e9SAndroid Build Coastguard Worker        output = IN(input_var)
7669*da0073e9SAndroid Build Coastguard Worker
7670*da0073e9SAndroid Build Coastguard Worker        input_reshaped = input_var.transpose(1, 0).reshape(c, -1)
7671*da0073e9SAndroid Build Coastguard Worker        mean = input_reshaped.mean(1)
7672*da0073e9SAndroid Build Coastguard Worker
7673*da0073e9SAndroid Build Coastguard Worker        input_reshaped = input_var.transpose(1, 0).reshape(c, b, -1)
7674*da0073e9SAndroid Build Coastguard Worker        var = input_reshaped.var(2, unbiased=True)[:, :]
7675*da0073e9SAndroid Build Coastguard Worker
7676*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.abs(mean.data - IN.running_mean).mean(), 0, atol=1e-5, rtol=0)
7677*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.abs(var.data.mean(1) - IN.running_var).mean(), 0, atol=1e-5, rtol=0)
7678*da0073e9SAndroid Build Coastguard Worker
7679*da0073e9SAndroid Build Coastguard Worker        # in eval mode, adding X * std to an input channel should make the
7680*da0073e9SAndroid Build Coastguard Worker        # corresponding output channel have mean X
7681*da0073e9SAndroid Build Coastguard Worker        IN.eval()
7682*da0073e9SAndroid Build Coastguard Worker        delta = IN.running_var.sqrt() * torch.arange(c, device=device, dtype=dtype)
7683*da0073e9SAndroid Build Coastguard Worker        delta = delta.view(-1, *[1 for _ in range(2, input.dim())])
7684*da0073e9SAndroid Build Coastguard Worker        output = IN(input_var + delta)
7685*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.transpose(0, 1).reshape(c, -1).mean(1), torch.arange(c, dtype=dtype))
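
    # Minimal sketch of the running-stats property noted above (a hedged example,
    # not collected by the test runner): with momentum=1 and track_running_stats=True
    # a single forward pass copies the per-channel batch mean into running_mean.
    def _sketch_instancenorm_running_mean(self, device='cpu'):
        x = torch.randn(2, 3, 4, 4, device=device)
        m = nn.InstanceNorm2d(3, momentum=1, track_running_stats=True).to(device)
        m(x)
        expected = x.transpose(0, 1).reshape(3, -1).mean(1)
        assert torch.allclose(m.running_mean, expected, atol=1e-5)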
7686*da0073e9SAndroid Build Coastguard Worker
7687*da0073e9SAndroid Build Coastguard Worker    def _test_InstanceNorm_cuda_half(self, cls, input, device):
7688*da0073e9SAndroid Build Coastguard Worker        # THNN
7689*da0073e9SAndroid Build Coastguard Worker        input = input.to(device=device, dtype=torch.half).random_(1, 10).requires_grad_(True)
7690*da0073e9SAndroid Build Coastguard Worker        m = cls(input.size(1), affine=True, track_running_stats=True).to(device, torch.half)
7691*da0073e9SAndroid Build Coastguard Worker        thnn_output = m(input)
7692*da0073e9SAndroid Build Coastguard Worker        thnn_output.sum().backward()
7693*da0073e9SAndroid Build Coastguard Worker        thnn_input_grad = input.grad.data.clone()
7694*da0073e9SAndroid Build Coastguard Worker        self.assertEqualTypeString(thnn_output, input)
7695*da0073e9SAndroid Build Coastguard Worker        # cuDNN
7696*da0073e9SAndroid Build Coastguard Worker        if TEST_CUDNN:
7697*da0073e9SAndroid Build Coastguard Worker            input.grad = None
7698*da0073e9SAndroid Build Coastguard Worker            m = m.float()
7699*da0073e9SAndroid Build Coastguard Worker            cudnn_output = m(input)
7700*da0073e9SAndroid Build Coastguard Worker            cudnn_output.sum().backward()
7701*da0073e9SAndroid Build Coastguard Worker            cudnn_input_grad = input.grad.data.clone()
7702*da0073e9SAndroid Build Coastguard Worker            self.assertEqualTypeString(cudnn_output, input)
7703*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cudnn_output, thnn_output, atol=1e-4, rtol=0)
7704*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cudnn_input_grad, thnn_input_grad, atol=1e-3, rtol=0)
7705*da0073e9SAndroid Build Coastguard Worker
7706*da0073e9SAndroid Build Coastguard Worker    def _test_LayerNorm_general(self, device, dtype=torch.float):
7707*da0073e9SAndroid Build Coastguard Worker        for i in range(2, 6):
7708*da0073e9SAndroid Build Coastguard Worker            shape = torch.randint(3, 6, (i,), dtype=torch.long).tolist()
7709*da0073e9SAndroid Build Coastguard Worker            x = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
7710*da0073e9SAndroid Build Coastguard Worker            normalized_ndim = random.randint(1, i - 1)  # inclusive
7711*da0073e9SAndroid Build Coastguard Worker            normalized_shape = shape[-normalized_ndim:]
7712*da0073e9SAndroid Build Coastguard Worker            unnormalized_shape = shape[:-normalized_ndim]
7713*da0073e9SAndroid Build Coastguard Worker
7714*da0073e9SAndroid Build Coastguard Worker            # test that LN normalizes to mean 0 and stddev 1
7715*da0073e9SAndroid Build Coastguard Worker            ln = nn.LayerNorm(normalized_shape, eps=0).to(device, dtype)
7716*da0073e9SAndroid Build Coastguard Worker            ln.weight.data.fill_(1)
7717*da0073e9SAndroid Build Coastguard Worker            ln.bias.data.fill_(0)
7718*da0073e9SAndroid Build Coastguard Worker            output = ln(x)
7719*da0073e9SAndroid Build Coastguard Worker            out_reshaped = output.view(*(unnormalized_shape + [-1]))
7720*da0073e9SAndroid Build Coastguard Worker            mean = out_reshaped.mean(-1)
7721*da0073e9SAndroid Build Coastguard Worker            var = out_reshaped.var(-1, unbiased=False)
7722*da0073e9SAndroid Build Coastguard Worker
7723*da0073e9SAndroid Build Coastguard Worker            delta = 1e-1 if (dtype == torch.bfloat16 or dtype == torch.half) else 1e-5
7724*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.abs(mean.data).mean(), 0, atol=delta, rtol=0)
7725*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.abs(var.data).mean(), 1, atol=delta, rtol=0)
7726*da0073e9SAndroid Build Coastguard Worker
7727*da0073e9SAndroid Build Coastguard Worker            # test that LN applies weight and bias correctly
7728*da0073e9SAndroid Build Coastguard Worker            scale, bias = torch.empty(2).uniform_(0.2, 2).tolist()
7729*da0073e9SAndroid Build Coastguard Worker            ln.weight.data.fill_(scale)
7730*da0073e9SAndroid Build Coastguard Worker            ln.bias.data.fill_(bias)
7731*da0073e9SAndroid Build Coastguard Worker            output = ln(x)
7732*da0073e9SAndroid Build Coastguard Worker            out_reshaped = output.view(*(unnormalized_shape + [-1]))
7733*da0073e9SAndroid Build Coastguard Worker            mean = out_reshaped.mean(-1)
7734*da0073e9SAndroid Build Coastguard Worker            var = out_reshaped.var(-1, unbiased=False)
7735*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.abs(mean.data).mean(), bias, atol=delta, rtol=0)
7736*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.abs(var.data).mean(), scale ** 2, atol=delta, rtol=0)
7737*da0073e9SAndroid Build Coastguard Worker
7738*da0073e9SAndroid Build Coastguard Worker        bad_norm_shape_input_shape = {
7739*da0073e9SAndroid Build Coastguard Worker            (): (),
7740*da0073e9SAndroid Build Coastguard Worker            (2, 3): (3,),
7741*da0073e9SAndroid Build Coastguard Worker            (2,): (1, 2, 3),
7742*da0073e9SAndroid Build Coastguard Worker            (10,): (2, 3),
7743*da0073e9SAndroid Build Coastguard Worker            10: (2, 3),
7744*da0073e9SAndroid Build Coastguard Worker        }
7745*da0073e9SAndroid Build Coastguard Worker        for norm_shape, input_shape in bad_norm_shape_input_shape.items():
7746*da0073e9SAndroid Build Coastguard Worker            ln = nn.LayerNorm(norm_shape)
7747*da0073e9SAndroid Build Coastguard Worker            input = torch.empty(input_shape, device=device, dtype=dtype).uniform_(0, 10)
7748*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: ln(input))
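
    # Minimal sketch of the affine property checked above (an illustrative helper,
    # not collected by the test runner): after normalization the output has zero
    # mean and unit variance over normalized_shape, so filling the weight with
    # `scale` and the bias with `shift` yields mean `shift` and variance `scale ** 2`.
    def _sketch_layernorm_affine(self):
        ln = nn.LayerNorm([4, 5], eps=0)
        ln.weight.data.fill_(2.0)
        ln.bias.data.fill_(0.5)
        out = ln(torch.randn(3, 4, 5)).view(3, -1)
        assert torch.allclose(out.mean(-1), torch.full((3,), 0.5), atol=1e-5)
        assert torch.allclose(out.var(-1, unbiased=False), torch.full((3,), 4.0), atol=1e-4)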
7749*da0073e9SAndroid Build Coastguard Worker
7750*da0073e9SAndroid Build Coastguard Worker    def _test_LayerNorm_cuda_half(self, device):
7751*da0073e9SAndroid Build Coastguard Worker        input = torch.empty(2, 3, 3, 2, device=device, dtype=torch.half).random_(1, 10).requires_grad_(True)
7752*da0073e9SAndroid Build Coastguard Worker        m = nn.LayerNorm([3, 2]).to(device, torch.half)
7753*da0073e9SAndroid Build Coastguard Worker        output = m(input)
7754*da0073e9SAndroid Build Coastguard Worker        output.sum().backward()
7755*da0073e9SAndroid Build Coastguard Worker        self.assertEqualTypeString(output, input)
7756*da0073e9SAndroid Build Coastguard Worker
7757*da0073e9SAndroid Build Coastguard Worker    def _test_LayerNorm_cpu_mixed_dtype(self, device, dtype):
7758*da0073e9SAndroid Build Coastguard Worker        for elementwise_affine in [True, False]:
7759*da0073e9SAndroid Build Coastguard Worker            # layer norm views its input as an m x n matrix and the CPU kernel vectorizes
7760*da0073e9SAndroid Build Coastguard Worker            # over n, so make sure n exceeds the vector length
7761*da0073e9SAndroid Build Coastguard Worker            input = torch.empty(2, 3, 11, 3, device=device, dtype=dtype).random_(1, 10)
7762*da0073e9SAndroid Build Coastguard Worker            m = nn.LayerNorm([11, 3], elementwise_affine=elementwise_affine).to(device, dtype)
7763*da0073e9SAndroid Build Coastguard Worker
7764*da0073e9SAndroid Build Coastguard Worker            # fp32
7765*da0073e9SAndroid Build Coastguard Worker            m_fp32 = deepcopy(m).to(device, torch.float)
7766*da0073e9SAndroid Build Coastguard Worker            x_fp32 = input.clone().detach().float().requires_grad_()
7767*da0073e9SAndroid Build Coastguard Worker            out_fp32 = m_fp32(x_fp32)
7768*da0073e9SAndroid Build Coastguard Worker            out_fp32.sum().backward()
7769*da0073e9SAndroid Build Coastguard Worker
7770*da0073e9SAndroid Build Coastguard Worker            # bf16/half
7771*da0073e9SAndroid Build Coastguard Worker            m_bf16 = deepcopy(m)
7772*da0073e9SAndroid Build Coastguard Worker            x_bf16 = input.clone().detach().requires_grad_()
7773*da0073e9SAndroid Build Coastguard Worker            out_bf16 = m_bf16(x_bf16)
7774*da0073e9SAndroid Build Coastguard Worker            out_bf16.sum().backward()
7775*da0073e9SAndroid Build Coastguard Worker
7776*da0073e9SAndroid Build Coastguard Worker            # bf16/half mixed type
7777*da0073e9SAndroid Build Coastguard Worker            m_mix = deepcopy(m).to(device, torch.float)
7778*da0073e9SAndroid Build Coastguard Worker            x_mix = input.clone().detach().requires_grad_()
7779*da0073e9SAndroid Build Coastguard Worker            out_mix = m_mix(x_mix)
7780*da0073e9SAndroid Build Coastguard Worker            out_mix.sum().backward()
7781*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out_fp32.to(dtype=dtype), out_bf16)
7782*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out_fp32.to(dtype=dtype), out_mix)
7783*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_fp32.grad.to(dtype=dtype), x_bf16.grad, atol=1e-1, rtol=1e-1)
7784*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_fp32.grad.to(dtype=dtype), x_mix.grad, atol=1e-1, rtol=1e-1)
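
    # Sketch of the mixed-dtype path exercised above (assumption: CPU build with the
    # mixed bf16/fp32 LayerNorm kernel): parameters kept in float32 can consume a
    # bfloat16 input directly and the output comes back in bfloat16.
    def _sketch_layernorm_mixed_dtype(self):
        ln = nn.LayerNorm([11, 3]).float()
        x = torch.randn(2, 3, 11, 3, dtype=torch.bfloat16)
        out = ln(x)
        assert out.dtype == torch.bfloat16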
7785*da0073e9SAndroid Build Coastguard Worker
7786*da0073e9SAndroid Build Coastguard Worker    def _test_GroupNorm_general(self, device, dtype=torch.float):
7787*da0073e9SAndroid Build Coastguard Worker        good_shape_g = {
7788*da0073e9SAndroid Build Coastguard Worker            (1, 2, 3, 4): 2,
7789*da0073e9SAndroid Build Coastguard Worker            (2, 3, 10): 3,
7790*da0073e9SAndroid Build Coastguard Worker            (3, 1, 1, 1, 2): 1,
7791*da0073e9SAndroid Build Coastguard Worker            (2, 6, 4, 2, 2): 3,
7792*da0073e9SAndroid Build Coastguard Worker            (1, 256, 1, 1): 32,
7793*da0073e9SAndroid Build Coastguard Worker        }
7794*da0073e9SAndroid Build Coastguard Worker        for shape_g, grad in product(good_shape_g.items(), [True, False]):
7795*da0073e9SAndroid Build Coastguard Worker            shape, g = shape_g
7796*da0073e9SAndroid Build Coastguard Worker            x = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10)
7797*da0073e9SAndroid Build Coastguard Worker            x.requires_grad_(grad)
7798*da0073e9SAndroid Build Coastguard Worker            b = shape[0]
7799*da0073e9SAndroid Build Coastguard Worker            c = shape[1]
7800*da0073e9SAndroid Build Coastguard Worker
7801*da0073e9SAndroid Build Coastguard Worker            # test that GN normalizes to mean 0 and stddev 1
7802*da0073e9SAndroid Build Coastguard Worker            gn = nn.GroupNorm(g, c, eps=0).to(device, dtype)
7803*da0073e9SAndroid Build Coastguard Worker            gn.weight.data.fill_(1)
7804*da0073e9SAndroid Build Coastguard Worker            gn.bias.data.fill_(0)
7805*da0073e9SAndroid Build Coastguard Worker            output = gn(x)
7806*da0073e9SAndroid Build Coastguard Worker            out_reshaped = output.view(b, g, -1)
7807*da0073e9SAndroid Build Coastguard Worker            mean = out_reshaped.mean(-1)
7808*da0073e9SAndroid Build Coastguard Worker            var = out_reshaped.var(-1, unbiased=False)
7809*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.abs(mean).mean(), 0, atol=1e-5, rtol=0)
7810*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.abs(var).mean(), 1, atol=1e-5, rtol=0)
7811*da0073e9SAndroid Build Coastguard Worker
7812*da0073e9SAndroid Build Coastguard Worker            output.backward(torch.randn_like(output))
7813*da0073e9SAndroid Build Coastguard Worker            if output.is_cuda:
7814*da0073e9SAndroid Build Coastguard Worker                torch.cuda.synchronize()
7815*da0073e9SAndroid Build Coastguard Worker
7816*da0073e9SAndroid Build Coastguard Worker            # test that GN applies weight and bias correctly
7817*da0073e9SAndroid Build Coastguard Worker            scale = torch.empty(c, device=device, dtype=dtype).uniform_(0.2, 2)
7818*da0073e9SAndroid Build Coastguard Worker            bias = torch.empty(c, device=device, dtype=dtype).uniform_(0.2, 2)
7819*da0073e9SAndroid Build Coastguard Worker            gn.weight.data.copy_(scale)
7820*da0073e9SAndroid Build Coastguard Worker            gn.bias.data.copy_(bias)
7821*da0073e9SAndroid Build Coastguard Worker            output = gn(x)
7822*da0073e9SAndroid Build Coastguard Worker            out_reshaped = output.view(b, c, -1)
7823*da0073e9SAndroid Build Coastguard Worker            out_normed = (out_reshaped - bias.view(c, 1)) / scale.view(c, 1)
7824*da0073e9SAndroid Build Coastguard Worker            out_normed_reshaped = out_normed.view(b, g, -1)
7825*da0073e9SAndroid Build Coastguard Worker            mean = out_normed_reshaped.mean(-1)
7826*da0073e9SAndroid Build Coastguard Worker            var = out_normed_reshaped.var(-1, unbiased=False)
7827*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.abs(mean).mean(), 0, atol=1e-5, rtol=0)
7828*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.abs(var).mean(), 1, atol=1e-5, rtol=0)
7829*da0073e9SAndroid Build Coastguard Worker
7830*da0073e9SAndroid Build Coastguard Worker        bad_shape_g = {
7831*da0073e9SAndroid Build Coastguard Worker            (1, 2, 3, 4): 3,
7832*da0073e9SAndroid Build Coastguard Worker            (2, 3, 10): 2,
7833*da0073e9SAndroid Build Coastguard Worker            (3, 1, 1, 1, 2): 10,
7834*da0073e9SAndroid Build Coastguard Worker            (2, 6, 4, 2, 2): 4,
7835*da0073e9SAndroid Build Coastguard Worker        }
7836*da0073e9SAndroid Build Coastguard Worker        for shape, g in bad_shape_g.items():
7837*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(ValueError):
7838*da0073e9SAndroid Build Coastguard Worker                gn = nn.GroupNorm(g, shape[1])
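
    # Sketch of the constraint exercised by bad_shape_g above (an illustrative helper,
    # not a collected test): GroupNorm requires num_channels to be divisible by
    # num_groups and raises ValueError otherwise.
    def _sketch_groupnorm_divisibility(self):
        nn.GroupNorm(2, 6)  # ok: 6 channels split into 2 groups
        try:
            nn.GroupNorm(4, 6)  # 6 is not divisible by 4
        except ValueError:
            pass
        else:
            raise AssertionError("expected ValueError for indivisible channels/groups")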
7839*da0073e9SAndroid Build Coastguard Worker
7840*da0073e9SAndroid Build Coastguard Worker    def _test_GroupNorm_cuda_half(self):
7841*da0073e9SAndroid Build Coastguard Worker        input = torch.zeros(2, 4, 3, 2, requires_grad=True).cuda().half().random_(1, 10)
7842*da0073e9SAndroid Build Coastguard Worker        m = nn.GroupNorm(2, 4).to("cuda", torch.half)
7843*da0073e9SAndroid Build Coastguard Worker        output = m(input)
7844*da0073e9SAndroid Build Coastguard Worker        output.sum().backward()
7845*da0073e9SAndroid Build Coastguard Worker        self.assertEqualTypeString(output, input)
7846*da0073e9SAndroid Build Coastguard Worker
7847*da0073e9SAndroid Build Coastguard Worker    def _test_GroupNorm_cpu_mixed_dtype(self):
7848*da0073e9SAndroid Build Coastguard Worker        def helper(self, size, groups, memory_format, dtype):
7849*da0073e9SAndroid Build Coastguard Worker            channels = size[1]
7850*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(size).cpu().to(dtype=dtype)
7851*da0073e9SAndroid Build Coastguard Worker            input_bf1 = input.contiguous(memory_format=memory_format).detach().requires_grad_(True)
7852*da0073e9SAndroid Build Coastguard Worker            input_bf2 = input_bf1.clone().detach().requires_grad_(True)
7853*da0073e9SAndroid Build Coastguard Worker            input_f = input_bf1.float().detach().requires_grad_(True)
7854*da0073e9SAndroid Build Coastguard Worker            m_bf = nn.GroupNorm(groups, channels).cpu().to(dtype=dtype)
7855*da0073e9SAndroid Build Coastguard Worker            m_f = deepcopy(m_bf).float()
7856*da0073e9SAndroid Build Coastguard Worker            m_f2 = deepcopy(m_f)
7857*da0073e9SAndroid Build Coastguard Worker            # bfloat16/half input and bfloat16/half parameters
7858*da0073e9SAndroid Build Coastguard Worker            out = m_bf(input_bf1)
7859*da0073e9SAndroid Build Coastguard Worker            # bfloat16/half input and float parameters
7860*da0073e9SAndroid Build Coastguard Worker            out2 = m_f(input_bf2)
7861*da0073e9SAndroid Build Coastguard Worker            # float input and float parameters
7862*da0073e9SAndroid Build Coastguard Worker            out3 = m_f2(input_f)
7863*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, out2, atol=5e-3, rtol=5e-3)
7864*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out2.float(), out3, atol=5e-3, rtol=5e-3)
7865*da0073e9SAndroid Build Coastguard Worker            grad_out = torch.randn(out2.shape).cpu().to(dtype=dtype)
7866*da0073e9SAndroid Build Coastguard Worker            grad_out_bf1 = grad_out.contiguous(memory_format=memory_format).detach().requires_grad_(True)
7867*da0073e9SAndroid Build Coastguard Worker            grad_out_bf2 = grad_out_bf1.clone().detach().requires_grad_(True)
7868*da0073e9SAndroid Build Coastguard Worker            grad_out_f = grad_out_bf2.clone().float().detach().requires_grad_(True)
7869*da0073e9SAndroid Build Coastguard Worker            # bfloat16/half input grad and float parameters
7870*da0073e9SAndroid Build Coastguard Worker            out2.backward(grad_out_bf2, retain_graph=True)
7871*da0073e9SAndroid Build Coastguard Worker            # float input grad and float parameters
7872*da0073e9SAndroid Build Coastguard Worker            out3.backward(grad_out_f, retain_graph=True)
7873*da0073e9SAndroid Build Coastguard Worker            # bfloat16/half input grad and bfloat16/half parameters
7874*da0073e9SAndroid Build Coastguard Worker            out.backward(grad_out_bf1, retain_graph=True)
7875*da0073e9SAndroid Build Coastguard Worker            # Higher tolerances (atol=1e-4, rtol=1e-4) are needed on macOS
7876*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m_f.weight.grad, m_f2.weight.grad, atol=1e-4, rtol=1e-4)
7877*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m_f.bias.grad, m_f2.bias.grad, atol=1e-5, rtol=1e-5)
7878*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input_bf2.grad.float(), input_f.grad, atol=5e-5, rtol=5e-3)
7879*da0073e9SAndroid Build Coastguard Worker            # Full bf16/half has lower precision than the mixed bf16/half + fp32 path.
7880*da0073e9SAndroid Build Coastguard Worker            # Use AMP to keep module parameters in the accumulation dtype (float) for better numerical stability.
7881*da0073e9SAndroid Build Coastguard Worker            atol = None
7882*da0073e9SAndroid Build Coastguard Worker            rtol = None
7883*da0073e9SAndroid Build Coastguard Worker            if dtype == torch.bfloat16:
7884*da0073e9SAndroid Build Coastguard Worker                atol = 1e-2
7885*da0073e9SAndroid Build Coastguard Worker                rtol = 1.2e-1
7886*da0073e9SAndroid Build Coastguard Worker            else:
7887*da0073e9SAndroid Build Coastguard Worker                assert dtype == torch.half
7888*da0073e9SAndroid Build Coastguard Worker                atol = 5e-3
7889*da0073e9SAndroid Build Coastguard Worker                rtol = 1.5e-2
7890*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m_bf.weight.grad, m_f.weight.grad.to(dtype=dtype), atol=atol, rtol=rtol)
7891*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m_bf.bias.grad, m_f.bias.grad.to(dtype=dtype), atol=atol, rtol=rtol)
7892*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input_bf1.grad, input_bf2.grad, atol=atol, rtol=rtol)
7893*da0073e9SAndroid Build Coastguard Worker
7894*da0073e9SAndroid Build Coastguard Worker        cl_formats = {4: torch.channels_last, 5: torch.channels_last_3d}
7895*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.bfloat16, torch.half]:
7896*da0073e9SAndroid Build Coastguard Worker            for shape, g in [((1, 8, 4, 3), 2), ((1, 8, 3, 4), 4),
7897*da0073e9SAndroid Build Coastguard Worker                             ((4, 40, 40, 40), 2), ((4, 8, 40, 40), 4),
7898*da0073e9SAndroid Build Coastguard Worker                             ((1, 8, 40, 40), 4), ((1, 8, 40, 40), 2),
7899*da0073e9SAndroid Build Coastguard Worker                             ((1, 8, 50, 50), 2), ((1, 8, 50, 50), 4),
7900*da0073e9SAndroid Build Coastguard Worker                             ((1, 40, 50, 50), 2), ((1, 9, 3, 4, 5), 3),
7901*da0073e9SAndroid Build Coastguard Worker                             ((1, 60, 10, 10, 10), 3), ((1, 9, 10, 50, 50), 3),
7902*da0073e9SAndroid Build Coastguard Worker                             ((1, 60, 10, 50, 50), 3), ((1, 8, 65, 55), 2),
7903*da0073e9SAndroid Build Coastguard Worker                             ((1, 3, 65, 55), 1), ((1, 3, 20, 20), 1)]:
7904*da0073e9SAndroid Build Coastguard Worker                for is_cl in [False, True]:
7905*da0073e9SAndroid Build Coastguard Worker                    format = cl_formats[len(shape)] if is_cl else torch.contiguous_format
7906*da0073e9SAndroid Build Coastguard Worker                    helper(self, shape, g, format, dtype)
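
    # Sketch of the memory-format dimension iterated above (a hedged example, not a
    # collected test): GroupNorm should produce the same values for a channels_last
    # input as for a contiguous one, up to floating-point rounding.
    def _sketch_groupnorm_channels_last(self):
        gn = nn.GroupNorm(2, 8)
        x = torch.randn(2, 8, 5, 5)
        x_cl = x.contiguous(memory_format=torch.channels_last)
        assert torch.allclose(gn(x), gn(x_cl), atol=1e-6)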
7907*da0073e9SAndroid Build Coastguard Worker
7908*da0073e9SAndroid Build Coastguard Worker    def _test_module_empty_inputs(self, module, inputs):
7909*da0073e9SAndroid Build Coastguard Worker        for _inp in inputs:
7910*da0073e9SAndroid Build Coastguard Worker            _inp.requires_grad_(True)
7911*da0073e9SAndroid Build Coastguard Worker        out = module(*inputs)
7912*da0073e9SAndroid Build Coastguard Worker        gO = torch.rand_like(out)
7913*da0073e9SAndroid Build Coastguard Worker        out.backward(gO)
7914*da0073e9SAndroid Build Coastguard Worker
7915*da0073e9SAndroid Build Coastguard Worker        for p in module.parameters():
7916*da0073e9SAndroid Build Coastguard Worker            if p.requires_grad:
7917*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(p.grad, torch.zeros_like(p.grad))
7918*da0073e9SAndroid Build Coastguard Worker
7919*da0073e9SAndroid Build Coastguard Worker        for _inp in inputs:
7920*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(_inp.grad, torch.zeros_like(_inp))
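
    # Minimal sketch of what the helper above verifies (an illustrative example, not
    # a collected test): modules accept zero-element batches, and backward through an
    # empty output leaves the parameter gradients at zero.
    def _sketch_empty_batch(self):
        conv = nn.Conv2d(3, 8, 3)
        x = torch.randn(0, 3, 16, 16, requires_grad=True)
        out = conv(x)
        out.sum().backward()
        assert out.shape == (0, 8, 14, 14)
        if conv.weight.grad is not None:
            assert conv.weight.grad.abs().sum().item() == 0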
7921*da0073e9SAndroid Build Coastguard Worker
7922*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
7923*da0073e9SAndroid Build Coastguard Worker                     "Scipy v1.0 and/or numpy not found")
7924*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098
7925*da0073e9SAndroid Build Coastguard Worker    @tf32_on_and_off()
7926*da0073e9SAndroid Build Coastguard Worker    @bf32_on_and_off()
7927*da0073e9SAndroid Build Coastguard Worker    def test_affine_2d_rotate0(self, device):
7928*da0073e9SAndroid Build Coastguard Worker        # scipy versions before 1.0.0 do not support homogeneous coordinates in
7929*da0073e9SAndroid Build Coastguard Worker        # scipy.ndimage.affine_transform, so we need to skip them.
7930*da0073e9SAndroid Build Coastguard Worker        input_size = [1, 1, 3, 3]
7931*da0073e9SAndroid Build Coastguard Worker        input_ary = np.array(np.random.random(input_size), dtype=np.float32)
7932*da0073e9SAndroid Build Coastguard Worker        output_size = [1, 1, 5, 5]
7933*da0073e9SAndroid Build Coastguard Worker        angle_rad = 0.
7934*da0073e9SAndroid Build Coastguard Worker
7935*da0073e9SAndroid Build Coastguard Worker        transform_tensor, transform_ary, offset = \
7936*da0073e9SAndroid Build Coastguard Worker            _buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad)
7937*da0073e9SAndroid Build Coastguard Worker
7938*da0073e9SAndroid Build Coastguard Worker        scipy_ary = torch.from_numpy(scipy.ndimage.affine_transform(
7939*da0073e9SAndroid Build Coastguard Worker            input_ary[0, 0],
7940*da0073e9SAndroid Build Coastguard Worker            transform_ary,
7941*da0073e9SAndroid Build Coastguard Worker            offset=offset,
7942*da0073e9SAndroid Build Coastguard Worker            output_shape=output_size[2:],
7943*da0073e9SAndroid Build Coastguard Worker            order=1,
7944*da0073e9SAndroid Build Coastguard Worker            mode='nearest',
7945*da0073e9SAndroid Build Coastguard Worker            prefilter=False))
7946*da0073e9SAndroid Build Coastguard Worker
7947*da0073e9SAndroid Build Coastguard Worker        affine_tensor = torch.nn.functional.affine_grid(
7948*da0073e9SAndroid Build Coastguard Worker            transform_tensor,
7949*da0073e9SAndroid Build Coastguard Worker            torch.Size(output_size),
7950*da0073e9SAndroid Build Coastguard Worker            align_corners=True
7951*da0073e9SAndroid Build Coastguard Worker        )
7952*da0073e9SAndroid Build Coastguard Worker
7953*da0073e9SAndroid Build Coastguard Worker        gridsample_ary = torch.nn.functional.grid_sample(
7954*da0073e9SAndroid Build Coastguard Worker            torch.tensor(input_ary, device=device).to(device),
7955*da0073e9SAndroid Build Coastguard Worker            affine_tensor,
7956*da0073e9SAndroid Build Coastguard Worker            padding_mode='border',
7957*da0073e9SAndroid Build Coastguard Worker            align_corners=True
7958*da0073e9SAndroid Build Coastguard Worker        ).to('cpu')
7959*da0073e9SAndroid Build Coastguard Worker
7960*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scipy_ary.mean(), gridsample_ary.mean())
7961*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary))
7962*da0073e9SAndroid Build Coastguard Worker
7963*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
7964*da0073e9SAndroid Build Coastguard Worker                     "Scipy v1.0 and/or numpy not found")
7965*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098
7966*da0073e9SAndroid Build Coastguard Worker    @tf32_on_and_off(0.001)
7967*da0073e9SAndroid Build Coastguard Worker    @bf32_on_and_off(0.001)
7968*da0073e9SAndroid Build Coastguard Worker    def test_affine_2d_rotate90(self, device):
7969*da0073e9SAndroid Build Coastguard Worker        # scipy versions before 1.0.0 do not support homogeneous coordinates in
7970*da0073e9SAndroid Build Coastguard Worker        # scipy.ndimage.affine_transform, so we need to skip them.
7971*da0073e9SAndroid Build Coastguard Worker        for input_size2dsq, output_size2dsq in \
7972*da0073e9SAndroid Build Coastguard Worker                itertools.product(input_size2dsq_(), output_size2dsq_()):
7973*da0073e9SAndroid Build Coastguard Worker            input_size = input_size2dsq
7974*da0073e9SAndroid Build Coastguard Worker            input_ary = np.array(np.random.random(input_size), dtype=np.float32)
7975*da0073e9SAndroid Build Coastguard Worker            output_size = output_size2dsq
7976*da0073e9SAndroid Build Coastguard Worker            angle_rad = 0.25 * math.pi * 2
7977*da0073e9SAndroid Build Coastguard Worker
7978*da0073e9SAndroid Build Coastguard Worker            transform_tensor, transform_ary, offset = \
7979*da0073e9SAndroid Build Coastguard Worker                _buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad)
7980*da0073e9SAndroid Build Coastguard Worker
7981*da0073e9SAndroid Build Coastguard Worker            scipy_ary = torch.from_numpy(scipy.ndimage.affine_transform(
7982*da0073e9SAndroid Build Coastguard Worker                input_ary[0, 0],
7983*da0073e9SAndroid Build Coastguard Worker                transform_ary,
7984*da0073e9SAndroid Build Coastguard Worker                offset=offset,
7985*da0073e9SAndroid Build Coastguard Worker                output_shape=output_size[2:],
7986*da0073e9SAndroid Build Coastguard Worker                order=1,
7987*da0073e9SAndroid Build Coastguard Worker                mode='nearest',
7988*da0073e9SAndroid Build Coastguard Worker                prefilter=True))
7989*da0073e9SAndroid Build Coastguard Worker
7990*da0073e9SAndroid Build Coastguard Worker            if input_size2dsq == output_size2dsq:
7991*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(scipy_ary.mean(), input_ary.mean())
7992*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(scipy_ary[0, 0], input_ary[0, 0, 0, -1])
7993*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(scipy_ary[0, -1], input_ary[0, 0, -1, -1])
7994*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(scipy_ary[-1, -1], input_ary[0, 0, -1, 0])
7995*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(scipy_ary[-1, 0], input_ary[0, 0, 0, 0])
7996*da0073e9SAndroid Build Coastguard Worker
7997*da0073e9SAndroid Build Coastguard Worker            affine_tensor = torch.nn.functional.affine_grid(
7998*da0073e9SAndroid Build Coastguard Worker                transform_tensor,
7999*da0073e9SAndroid Build Coastguard Worker                torch.Size(output_size),
8000*da0073e9SAndroid Build Coastguard Worker                align_corners=True
8001*da0073e9SAndroid Build Coastguard Worker            )
8002*da0073e9SAndroid Build Coastguard Worker
8003*da0073e9SAndroid Build Coastguard Worker            gridsample_ary = torch.nn.functional.grid_sample(
8004*da0073e9SAndroid Build Coastguard Worker                torch.tensor(input_ary, device=device).to(device),
8005*da0073e9SAndroid Build Coastguard Worker                affine_tensor,
8006*da0073e9SAndroid Build Coastguard Worker                padding_mode='border',
8007*da0073e9SAndroid Build Coastguard Worker                align_corners=True
8008*da0073e9SAndroid Build Coastguard Worker            ).to('cpu')
8009*da0073e9SAndroid Build Coastguard Worker
8010*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(scipy_ary.mean(), gridsample_ary.mean())
8011*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary))
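
    # Sketch of the corner checks in test_affine_2d_rotate90 above (an illustrative
    # helper, not a collected test): a 90 degree rotation sends each corner of the
    # input square onto the next corner of the output, matching np.rot90.
    def _sketch_rotate90_corners(self):
        a = np.arange(9, dtype=np.float32).reshape(3, 3)
        r = np.rot90(a)
        assert r[0, 0] == a[0, -1]
        assert r[0, -1] == a[-1, -1]
        assert r[-1, -1] == a[-1, 0]
        assert r[-1, 0] == a[0, 0]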
8012*da0073e9SAndroid Build Coastguard Worker
8013*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
8014*da0073e9SAndroid Build Coastguard Worker                     "Scipy v1.0 and/or numpy not found")
8015*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098
8016*da0073e9SAndroid Build Coastguard Worker    @tf32_on_and_off(0.005)
8017*da0073e9SAndroid Build Coastguard Worker    @bf32_on_and_off(0.005)
8018*da0073e9SAndroid Build Coastguard Worker    def test_affine_2d_rotate45(self, device):
8019*da0073e9SAndroid Build Coastguard Worker        # scipy versions before 1.0.0 do not support homogeneous coordinates in
8020*da0073e9SAndroid Build Coastguard Worker        # scipy.ndimage.affine_transform, so we need to skip them.
8021*da0073e9SAndroid Build Coastguard Worker        input_size = [1, 1, 3, 3]
8022*da0073e9SAndroid Build Coastguard Worker        input_ary = np.array(np.zeros(input_size), dtype=np.float32)
8023*da0073e9SAndroid Build Coastguard Worker        input_ary[0, 0, 0, :] = 0.5
8024*da0073e9SAndroid Build Coastguard Worker        input_ary[0, 0, 2, 2] = 1.0
8025*da0073e9SAndroid Build Coastguard Worker        output_size = [1, 1, 3, 3]
8026*da0073e9SAndroid Build Coastguard Worker        angle_rad = 0.125 * math.pi * 2
8027*da0073e9SAndroid Build Coastguard Worker
8028*da0073e9SAndroid Build Coastguard Worker        transform_tensor, transform_ary, offset = \
8029*da0073e9SAndroid Build Coastguard Worker            _buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad)
8030*da0073e9SAndroid Build Coastguard Worker
8031*da0073e9SAndroid Build Coastguard Worker        scipy_ary = torch.from_numpy(scipy.ndimage.affine_transform(
8032*da0073e9SAndroid Build Coastguard Worker            input_ary[0, 0],
8033*da0073e9SAndroid Build Coastguard Worker            transform_ary,
8034*da0073e9SAndroid Build Coastguard Worker            offset=offset,
8035*da0073e9SAndroid Build Coastguard Worker            output_shape=output_size[2:],
8036*da0073e9SAndroid Build Coastguard Worker            order=1,
8037*da0073e9SAndroid Build Coastguard Worker            mode='nearest',
8038*da0073e9SAndroid Build Coastguard Worker            prefilter=False))
8039*da0073e9SAndroid Build Coastguard Worker
8040*da0073e9SAndroid Build Coastguard Worker        affine_tensor = torch.nn.functional.affine_grid(
8041*da0073e9SAndroid Build Coastguard Worker            transform_tensor,
8042*da0073e9SAndroid Build Coastguard Worker            torch.Size(output_size),
8043*da0073e9SAndroid Build Coastguard Worker            align_corners=True
8044*da0073e9SAndroid Build Coastguard Worker        )
8045*da0073e9SAndroid Build Coastguard Worker
8046*da0073e9SAndroid Build Coastguard Worker        gridsample_ary = torch.nn.functional.grid_sample(
8047*da0073e9SAndroid Build Coastguard Worker            torch.tensor(input_ary, device=device).to(device),
8048*da0073e9SAndroid Build Coastguard Worker            affine_tensor,
8049*da0073e9SAndroid Build Coastguard Worker            padding_mode='border',
8050*da0073e9SAndroid Build Coastguard Worker            align_corners=True
8051*da0073e9SAndroid Build Coastguard Worker        ).to('cpu')
8052*da0073e9SAndroid Build Coastguard Worker
8053*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary))
8054*da0073e9SAndroid Build Coastguard Worker
8055*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
8056*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest("60GB", "cpu")
8057*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest("16GB", "cuda")
8058*da0073e9SAndroid Build Coastguard Worker    def test_avg_pool_large_tensor(self, device):
8059*da0073e9SAndroid Build Coastguard Worker        # test for https://github.com/pytorch/pytorch/issues/113833
8060*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(128, 256, 256, 256, dtype=torch.half, device=device, requires_grad=True)
8061*da0073e9SAndroid Build Coastguard Worker        a_cpu = a.detach().cpu().float()
8062*da0073e9SAndroid Build Coastguard Worker        m = torch.nn.AvgPool2d(2)
8063*da0073e9SAndroid Build Coastguard Worker        o = m(a)
8064*da0073e9SAndroid Build Coastguard Worker        a_cpu.requires_grad = True
8065*da0073e9SAndroid Build Coastguard Worker        o.sum().backward()
8066*da0073e9SAndroid Build Coastguard Worker        o_cpu = m(a_cpu)
8067*da0073e9SAndroid Build Coastguard Worker        o_cpu.sum().backward()
8068*da0073e9SAndroid Build Coastguard Worker        # workaround for memory usage overhead of assertEqual
8069*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(a.grad.cpu(), a_cpu.grad.half()))
8070*da0073e9SAndroid Build Coastguard Worker
8071*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
8072*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest("48GB", "cpu")
8073*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest("48GB", "cuda")
8074*da0073e9SAndroid Build Coastguard Worker    def test_avg_pool_large_tensor2(self, device):
8075*da0073e9SAndroid Build Coastguard Worker        # test for https://github.com/pytorch/pytorch/issues/129785
8076*da0073e9SAndroid Build Coastguard Worker        out_size = [2048, 64, 104, 79]
8077*da0073e9SAndroid Build Coastguard Worker        size = [2048, 64, 209, 159]
8078*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn(size, device=device, requires_grad=True, dtype=torch.float)
8079*da0073e9SAndroid Build Coastguard Worker        inp_cpu = inp.detach().cpu()
8080*da0073e9SAndroid Build Coastguard Worker        m = torch.nn.AvgPool2d([2, 2], [2, 2], [0, 0], False, True, None)
8081*da0073e9SAndroid Build Coastguard Worker        o = m(inp)
8082*da0073e9SAndroid Build Coastguard Worker        inp_cpu.requires_grad = True
8083*da0073e9SAndroid Build Coastguard Worker        o.sum().backward()
8084*da0073e9SAndroid Build Coastguard Worker        o_cpu = m(inp_cpu)
8085*da0073e9SAndroid Build Coastguard Worker        o_cpu.sum().backward()
8086*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(o.shape, out_size)
8087*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(o_cpu.shape, out_size)
8088*da0073e9SAndroid Build Coastguard Worker        # reduce memory usage
8089*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(inp.grad.sum(), inp_cpu.grad.sum())
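
    # Worked check of the shapes hard-coded above (an illustrative helper, not a
    # collected test): with kernel 2, stride 2, no padding and ceil_mode=False the
    # spatial output size is floor((n - 2) / 2) + 1, so 209 -> 104 and 159 -> 79.
    def _sketch_avg_pool_output_size(self):
        m = torch.nn.AvgPool2d([2, 2], [2, 2], [0, 0], False, True, None)
        out = m(torch.randn(1, 1, 209, 159))
        assert out.shape == (1, 1, 104, 79)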
8090*da0073e9SAndroid Build Coastguard Worker
8091*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
8092*da0073e9SAndroid Build Coastguard Worker                     "Scipy v1.0 and/or numpy not found")
8093*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098
8094*da0073e9SAndroid Build Coastguard Worker    @tf32_on_and_off(0.005)
8095*da0073e9SAndroid Build Coastguard Worker    @bf32_on_and_off(0.005)
8096*da0073e9SAndroid Build Coastguard Worker    def test_affine_2d_rotateRandom(self, device):
8097*da0073e9SAndroid Build Coastguard Worker        # scipy versions before 1.0.0 do not support homogeneous coordinates in
8098*da0073e9SAndroid Build Coastguard Worker        # scipy.ndimage.affine_transform, so we need to skip them.
8099*da0073e9SAndroid Build Coastguard Worker        for angle_rad, input_size2d, output_size2d in \
8100*da0073e9SAndroid Build Coastguard Worker                itertools.product(angle_rad_(), input_size2d_(), output_size2d_()):
8101*da0073e9SAndroid Build Coastguard Worker
8102*da0073e9SAndroid Build Coastguard Worker            input_size = input_size2d
8103*da0073e9SAndroid Build Coastguard Worker            input_ary = np.array(np.random.random(input_size), dtype=np.float32).round(3)
8104*da0073e9SAndroid Build Coastguard Worker            output_size = output_size2d
8105*da0073e9SAndroid Build Coastguard Worker
8106*da0073e9SAndroid Build Coastguard Worker            input_ary[0, 0, 0, 0] = 2
8107*da0073e9SAndroid Build Coastguard Worker            input_ary[0, 0, 0, -1] = 4
8108*da0073e9SAndroid Build Coastguard Worker            input_ary[0, 0, -1, 0] = 6
8109*da0073e9SAndroid Build Coastguard Worker            input_ary[0, 0, -1, -1] = 8
8110*da0073e9SAndroid Build Coastguard Worker
8111*da0073e9SAndroid Build Coastguard Worker            transform_tensor, transform_ary, grid_ary = \
8112*da0073e9SAndroid Build Coastguard Worker                _buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad)
8113*da0073e9SAndroid Build Coastguard Worker
8114*da0073e9SAndroid Build Coastguard Worker            scipy_ary = torch.from_numpy(scipy.ndimage.affine_transform(
8115*da0073e9SAndroid Build Coastguard Worker                input_ary[0, 0],
8116*da0073e9SAndroid Build Coastguard Worker                transform_ary,
8117*da0073e9SAndroid Build Coastguard Worker                output_shape=output_size[2:],
8118*da0073e9SAndroid Build Coastguard Worker                order=1,
8119*da0073e9SAndroid Build Coastguard Worker                mode='nearest',
8120*da0073e9SAndroid Build Coastguard Worker                prefilter=False))
8121*da0073e9SAndroid Build Coastguard Worker
8122*da0073e9SAndroid Build Coastguard Worker            affine_tensor = torch.nn.functional.affine_grid(
8123*da0073e9SAndroid Build Coastguard Worker                transform_tensor,
8124*da0073e9SAndroid Build Coastguard Worker                torch.Size(output_size),
8125*da0073e9SAndroid Build Coastguard Worker                align_corners=True
8126*da0073e9SAndroid Build Coastguard Worker            )
8127*da0073e9SAndroid Build Coastguard Worker
8128*da0073e9SAndroid Build Coastguard Worker            gridsample_ary = torch.nn.functional.grid_sample(
8129*da0073e9SAndroid Build Coastguard Worker                torch.tensor(input_ary, device=device).to(device),
8130*da0073e9SAndroid Build Coastguard Worker                affine_tensor,
8131*da0073e9SAndroid Build Coastguard Worker                padding_mode='border',
8132*da0073e9SAndroid Build Coastguard Worker                align_corners=True
8133*da0073e9SAndroid Build Coastguard Worker            ).to('cpu')
8134*da0073e9SAndroid Build Coastguard Worker
8135*da0073e9SAndroid Build Coastguard Worker            affine_tensor = affine_tensor.to('cpu')
8136*da0073e9SAndroid Build Coastguard Worker
8137*da0073e9SAndroid Build Coastguard Worker            for r in range(affine_tensor.size(1)):
8138*da0073e9SAndroid Build Coastguard Worker                for c in range(affine_tensor.size(2)):
8139*da0073e9SAndroid Build Coastguard Worker                    grid_out = np.dot(grid_ary, [r, c, 1])
8140*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(affine_tensor[0, r, c], grid_out[:2], exact_dtype=False)
8141*da0073e9SAndroid Build Coastguard Worker
8142*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary))
8143*da0073e9SAndroid Build Coastguard Worker
8144*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
8145*da0073e9SAndroid Build Coastguard Worker                     "Scipy v1.0 and/or numpy not found")
8146*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # aten::grid_sampler_3d not implemented https://github.com/pytorch/pytorch/issues/77764
8147*da0073e9SAndroid Build Coastguard Worker    @tf32_on_and_off(0.005)
8148*da0073e9SAndroid Build Coastguard Worker    @bf32_on_and_off(0.005)
8149*da0073e9SAndroid Build Coastguard Worker    def test_affine_3d_rotateRandom(self, device):
8150*da0073e9SAndroid Build Coastguard Worker        # scipy versions before 1.0.0 do not support homogeneous coordinates in
8151*da0073e9SAndroid Build Coastguard Worker        # scipy.ndimage.affine_transform, so we need to skip them.
8152*da0073e9SAndroid Build Coastguard Worker        for angle_rad, axis_vector, input_size3d, output_size3d in \
8153*da0073e9SAndroid Build Coastguard Worker                itertools.product(angle_rad_(), axis_vector_(), input_size3d_(), output_size3d_()):
8154*da0073e9SAndroid Build Coastguard Worker            input_size = input_size3d
8155*da0073e9SAndroid Build Coastguard Worker            input_ary = np.array(np.random.random(input_size), dtype=np.float32)
8156*da0073e9SAndroid Build Coastguard Worker            output_size = output_size3d
8157*da0073e9SAndroid Build Coastguard Worker
8158*da0073e9SAndroid Build Coastguard Worker            input_ary[0, 0, 0, 0, 0] = 2
8159*da0073e9SAndroid Build Coastguard Worker            input_ary[0, 0, 0, 0, -1] = 3
8160*da0073e9SAndroid Build Coastguard Worker            input_ary[0, 0, 0, -1, 0] = 4
8161*da0073e9SAndroid Build Coastguard Worker            input_ary[0, 0, 0, -1, -1] = 5
8162*da0073e9SAndroid Build Coastguard Worker            input_ary[0, 0, -1, 0, 0] = 6
8163*da0073e9SAndroid Build Coastguard Worker            input_ary[0, 0, -1, 0, -1] = 7
8164*da0073e9SAndroid Build Coastguard Worker            input_ary[0, 0, -1, -1, 0] = 8
8165*da0073e9SAndroid Build Coastguard Worker            input_ary[0, 0, -1, -1, -1] = 9
8166*da0073e9SAndroid Build Coastguard Worker
8167*da0073e9SAndroid Build Coastguard Worker            transform_tensor, transform_ary, grid_ary = \
8168*da0073e9SAndroid Build Coastguard Worker                _buildEquivalentAffineTransforms3d(device, input_size, output_size, angle_rad, axis_vector)
8169*da0073e9SAndroid Build Coastguard Worker
8170*da0073e9SAndroid Build Coastguard Worker            scipy_ary = torch.from_numpy(scipy.ndimage.affine_transform(
8171*da0073e9SAndroid Build Coastguard Worker                input_ary[0, 0],
8172*da0073e9SAndroid Build Coastguard Worker                transform_ary,
8173*da0073e9SAndroid Build Coastguard Worker                output_shape=output_size[2:],
8174*da0073e9SAndroid Build Coastguard Worker                order=1,
8175*da0073e9SAndroid Build Coastguard Worker                mode='nearest',
8176*da0073e9SAndroid Build Coastguard Worker                prefilter=False))
8177*da0073e9SAndroid Build Coastguard Worker
8178*da0073e9SAndroid Build Coastguard Worker            affine_tensor = torch.nn.functional.affine_grid(
8179*da0073e9SAndroid Build Coastguard Worker                transform_tensor,
8180*da0073e9SAndroid Build Coastguard Worker                torch.Size(output_size),
8181*da0073e9SAndroid Build Coastguard Worker                align_corners=True
8182*da0073e9SAndroid Build Coastguard Worker            )
8183*da0073e9SAndroid Build Coastguard Worker
8184*da0073e9SAndroid Build Coastguard Worker            gridsample_ary = torch.nn.functional.grid_sample(
8185*da0073e9SAndroid Build Coastguard Worker                torch.tensor(input_ary, device=device),
8186*da0073e9SAndroid Build Coastguard Worker                affine_tensor,
8187*da0073e9SAndroid Build Coastguard Worker                padding_mode='border',
8188*da0073e9SAndroid Build Coastguard Worker                align_corners=True
8189*da0073e9SAndroid Build Coastguard Worker            ).to('cpu')
8190*da0073e9SAndroid Build Coastguard Worker
8191*da0073e9SAndroid Build Coastguard Worker            affine_tensor = affine_tensor.to('cpu')
8192*da0073e9SAndroid Build Coastguard Worker
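            # The grid is checked pointwise below: each affine_grid output location
            # should equal the homogeneous product grid_ary @ [i, r, c, 1] with the
            # trailing homogeneous coordinate dropped.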
8193*da0073e9SAndroid Build Coastguard Worker            for i in range(affine_tensor.size(1)):
8194*da0073e9SAndroid Build Coastguard Worker                for r in range(affine_tensor.size(2)):
8195*da0073e9SAndroid Build Coastguard Worker                    for c in range(affine_tensor.size(3)):
8196*da0073e9SAndroid Build Coastguard Worker                        grid_out = np.dot(grid_ary, [i, r, c, 1])
8197*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(affine_tensor[0, i, r, c], grid_out[:3], exact_dtype=False)
8198*da0073e9SAndroid Build Coastguard Worker
8199*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary))
8200*da0073e9SAndroid Build Coastguard Worker
8201*da0073e9SAndroid Build Coastguard Worker
8202*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
8203*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.half)
8204*da0073e9SAndroid Build Coastguard Worker    def test_batchnorm_large_batch(self, device, dtype):
8205*da0073e9SAndroid Build Coastguard Worker        bn = nn.BatchNorm2d(1).to(device, dtype)
8206*da0073e9SAndroid Build Coastguard Worker        data = torch.rand(880801, 1, 1, 1, device=device, dtype=dtype)
8207*da0073e9SAndroid Build Coastguard Worker        bn(data).sum().backward()
8208*da0073e9SAndroid Build Coastguard Worker
8209*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.float, torch.double, torch.half, torch.complex128)
8210*da0073e9SAndroid Build Coastguard Worker    @dtypesIfMPS(torch.float, torch.half, torch.complex64)
8211*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.bfloat16, torch.complex128)
8212*da0073e9SAndroid Build Coastguard Worker    def test_conv_empty_input(self, device, dtype):
8213*da0073e9SAndroid Build Coastguard Worker        def help(input, conv, memory_format):
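            # Converting only the module, only the input, or both to the requested
            # memory format should all match the reference (contiguous) result.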
8214*da0073e9SAndroid Build Coastguard Worker            ref_out = conv(input)
8215*da0073e9SAndroid Build Coastguard Worker            conv_cl = conv.to(memory_format=memory_format)
8216*da0073e9SAndroid Build Coastguard Worker            out_cl = conv_cl(input)
8217*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ref_out, out_cl)
8218*da0073e9SAndroid Build Coastguard Worker            input_cl = input.to(memory_format=memory_format)
8219*da0073e9SAndroid Build Coastguard Worker            out_cl2 = conv(input_cl)
8220*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out_cl, out_cl2)
8221*da0073e9SAndroid Build Coastguard Worker            out_cl3 = conv_cl(input_cl)
8222*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out_cl, out_cl3)
8223*da0073e9SAndroid Build Coastguard Worker
8224*da0073e9SAndroid Build Coastguard Worker        # channels_last case
8225*da0073e9SAndroid Build Coastguard Worker        input2d = torch.randn((0, 4, 20, 20)).to(device=device, dtype=dtype)
8226*da0073e9SAndroid Build Coastguard Worker        conv2d = torch.nn.Conv2d(4, 4, 3, 1).to(device=device, dtype=dtype)
8227*da0073e9SAndroid Build Coastguard Worker        help(input2d, conv2d, torch.channels_last)
8228*da0073e9SAndroid Build Coastguard Worker        # channels_last_3d case
8229*da0073e9SAndroid Build Coastguard Worker        input3d = torch.randn((0, 4, 20, 20, 20)).to(device=device, dtype=dtype)
8230*da0073e9SAndroid Build Coastguard Worker        conv3d = torch.nn.Conv3d(4, 4, 3, 1).to(device=device, dtype=dtype)
8231*da0073e9SAndroid Build Coastguard Worker        help(input3d, conv3d, torch.channels_last_3d)
8232*da0073e9SAndroid Build Coastguard Worker        # non-contiguous case
8233*da0073e9SAndroid Build Coastguard Worker        weight = torch.rand(4, 8, 3, 3)[:, ::2, :, :].to(device=device, dtype=dtype)
8234*da0073e9SAndroid Build Coastguard Worker        bias = torch.rand(4).to(device=device, dtype=dtype)
8235*da0073e9SAndroid Build Coastguard Worker        out = F.conv2d(input2d, weight, bias, (1, 1), 0, (1, 1), 1)
8236*da0073e9SAndroid Build Coastguard Worker        weight = weight.contiguous()
8237*da0073e9SAndroid Build Coastguard Worker        out_ref = F.conv2d(input2d, weight, bias, (1, 1), 0, (1, 1), 1)
8238*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref, out)
8239*da0073e9SAndroid Build Coastguard Worker        # sigfpe reported in https://github.com/pytorch/pytorch/issues/94125
8240*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
8241*da0073e9SAndroid Build Coastguard Worker            inp = torch.empty([1, 1, 1, 0], dtype=dtype, device=device)
8242*da0073e9SAndroid Build Coastguard Worker            weight = torch.empty([1, 0, 1], dtype=dtype, device=device)
8243*da0073e9SAndroid Build Coastguard Worker            torch._C._nn.slow_conv3d(inp, weight, 1)
8244*da0073e9SAndroid Build Coastguard Worker
8245*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, re.escape("2D kernel_size expected")):
8246*da0073e9SAndroid Build Coastguard Worker            torch._C._nn.thnn_conv2d(torch.rand([1, 1, 1, 1]), kernel_size=[], padding=[1, 1], stride=[1, 1],
8247*da0073e9SAndroid Build Coastguard Worker                                     weight=torch.rand([1, 1]))
8248*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, re.escape("2D stride expected")):
8249*da0073e9SAndroid Build Coastguard Worker            torch._C._nn.thnn_conv2d(torch.rand([1, 1, 1, 1]), kernel_size=[1, 1], padding=[1, 1], stride=[],
8250*da0073e9SAndroid Build Coastguard Worker                                     weight=torch.rand([1, 1]))
8251*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, re.escape("2D padding expected")):
8252*da0073e9SAndroid Build Coastguard Worker            torch._C._nn.thnn_conv2d(torch.rand([1, 1, 1, 1]), kernel_size=[1, 1], padding=[], stride=[1, 1],
8253*da0073e9SAndroid Build Coastguard Worker                                     weight=torch.rand([1, 1]))
8254*da0073e9SAndroid Build Coastguard Worker
8255*da0073e9SAndroid Build Coastguard Worker    def test_InstanceNorm1d_general(self, device):
8256*da0073e9SAndroid Build Coastguard Worker        b = random.randint(3, 5)
8257*da0073e9SAndroid Build Coastguard Worker        c = random.randint(3, 5)
8258*da0073e9SAndroid Build Coastguard Worker        d = random.randint(8, 10)
8259*da0073e9SAndroid Build Coastguard Worker
8260*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(b, c, d)
8261*da0073e9SAndroid Build Coastguard Worker        self._test_InstanceNorm_general(nn.InstanceNorm1d, input, device)
8262*da0073e9SAndroid Build Coastguard Worker
8263*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda':
8264*da0073e9SAndroid Build Coastguard Worker            self._test_InstanceNorm_cuda_half(nn.InstanceNorm1d, input, device)
8265*da0073e9SAndroid Build Coastguard Worker
8266*da0073e9SAndroid Build Coastguard Worker    def test_InstanceNorm2d_general(self, device):
8267*da0073e9SAndroid Build Coastguard Worker        b = random.randint(3, 5)
8268*da0073e9SAndroid Build Coastguard Worker        c = random.randint(3, 5)
8269*da0073e9SAndroid Build Coastguard Worker        w = random.randint(3, 6)
8270*da0073e9SAndroid Build Coastguard Worker        h = random.randint(6, 8)
8271*da0073e9SAndroid Build Coastguard Worker
8272*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(b, c, h, w)
8273*da0073e9SAndroid Build Coastguard Worker        self._test_InstanceNorm_general(nn.InstanceNorm2d, input, device)
8274*da0073e9SAndroid Build Coastguard Worker
8275*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda':
8276*da0073e9SAndroid Build Coastguard Worker            self._test_InstanceNorm_cuda_half(nn.InstanceNorm2d, input, device)
8277*da0073e9SAndroid Build Coastguard Worker
8278*da0073e9SAndroid Build Coastguard Worker    def test_InstanceNorm3d_general(self, device):
8279*da0073e9SAndroid Build Coastguard Worker        b = random.randint(3, 5)
8280*da0073e9SAndroid Build Coastguard Worker        c = random.randint(3, 5)
8281*da0073e9SAndroid Build Coastguard Worker        w = random.randint(2, 5)
8282*da0073e9SAndroid Build Coastguard Worker        h = random.randint(2, 5)
8283*da0073e9SAndroid Build Coastguard Worker        d = random.randint(2, 5)
8284*da0073e9SAndroid Build Coastguard Worker
8285*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(b, c, h, w, d)
8286*da0073e9SAndroid Build Coastguard Worker        self._test_InstanceNorm_general(nn.InstanceNorm3d, input, device)
8287*da0073e9SAndroid Build Coastguard Worker
8288*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda':
8289*da0073e9SAndroid Build Coastguard Worker            self._test_InstanceNorm_cuda_half(nn.InstanceNorm3d, input, device)
8290*da0073e9SAndroid Build Coastguard Worker
8291*da0073e9SAndroid Build Coastguard Worker    @parametrize_test("instance_norm_cls", [nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d], name_fn=lambda c: c.__name__)
8292*da0073e9SAndroid Build Coastguard Worker    @parametrize_test("no_batch_dim", [True, False])
8293*da0073e9SAndroid Build Coastguard Worker    @parametrize_test("affine", [True, False])
8294*da0073e9SAndroid Build Coastguard Worker    def test_instancenorm_raises_error_if_input_channels_is_not_num_features(self, device, instance_norm_cls, no_batch_dim, affine):
8295*da0073e9SAndroid Build Coastguard Worker        inst_norm = instance_norm_cls(4, affine=affine)
8296*da0073e9SAndroid Build Coastguard Worker        size = [2] * inst_norm._get_no_batch_dim()
8297*da0073e9SAndroid Build Coastguard Worker        if not no_batch_dim:
8298*da0073e9SAndroid Build Coastguard Worker            size = [3] + size
8299*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(size)
8300*da0073e9SAndroid Build Coastguard Worker        if affine:
8301*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(ValueError, "expected input's size at dim="):
8302*da0073e9SAndroid Build Coastguard Worker                inst_norm(t)
8303*da0073e9SAndroid Build Coastguard Worker        else:
8304*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True) as w:
8305*da0073e9SAndroid Build Coastguard Worker                inst_norm(t)
8306*da0073e9SAndroid Build Coastguard Worker            self.assertIn("which is not used because affine=False", str(w[0].message))
8307*da0073e9SAndroid Build Coastguard Worker
8308*da0073e9SAndroid Build Coastguard Worker    def test_instancenorm_raises_error_if_less_than_one_value_per_channel(self, device):
8309*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(10)[None, :, None]
8310*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
8311*da0073e9SAndroid Build Coastguard Worker            torch.nn.InstanceNorm1d(10)(x).to(device)
8312*da0073e9SAndroid Build Coastguard Worker
8313*da0073e9SAndroid Build Coastguard Worker    def test_instancenorm_raises_error_for_single_spatial_element_during_training(self, device):
8314*da0073e9SAndroid Build Coastguard Worker        BATCH_SIZE = 10
8315*da0073e9SAndroid Build Coastguard Worker        NUM_CHANNELS = 3
8316*da0073e9SAndroid Build Coastguard Worker        norms = [torch.nn.InstanceNorm1d, torch.nn.InstanceNorm2d, torch.nn.InstanceNorm3d]
8317*da0073e9SAndroid Build Coastguard Worker        for i, norm in enumerate(norms):
8318*da0073e9SAndroid Build Coastguard Worker            m = norm(NUM_CHANNELS, track_running_stats=True)
8319*da0073e9SAndroid Build Coastguard Worker            m.to(device)
8320*da0073e9SAndroid Build Coastguard Worker
8321*da0073e9SAndroid Build Coastguard Worker            # Create an appropriately-sized input with a single spatial element.
8322*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(BATCH_SIZE, NUM_CHANNELS, *[1 for _ in range(i + 1)],
8323*da0073e9SAndroid Build Coastguard Worker                                device=device)
8324*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(ValueError):
8325*da0073e9SAndroid Build Coastguard Worker                m(input)
8326*da0073e9SAndroid Build Coastguard Worker
8327*da0073e9SAndroid Build Coastguard Worker            # Single spatial element should be fine in eval.
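            # (In eval mode the module normalizes with its running statistics rather
            # than per-instance statistics, so no variance over a single element is
            # ever computed.)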
8328*da0073e9SAndroid Build Coastguard Worker            m.eval()
8329*da0073e9SAndroid Build Coastguard Worker            m(input)
8330*da0073e9SAndroid Build Coastguard Worker
8331*da0073e9SAndroid Build Coastguard Worker    def test_LayerNorm_general(self, device):
8332*da0073e9SAndroid Build Coastguard Worker        self._test_LayerNorm_general(device)
8333*da0073e9SAndroid Build Coastguard Worker
8334*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda' or self.device_type == 'cpu':
8335*da0073e9SAndroid Build Coastguard Worker            for dtype in [torch.half, torch.bfloat16]:
8336*da0073e9SAndroid Build Coastguard Worker                self._test_LayerNorm_general(device, dtype=dtype)
8337*da0073e9SAndroid Build Coastguard Worker
8338*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda':
8339*da0073e9SAndroid Build Coastguard Worker            self._test_LayerNorm_cuda_half(device)
8340*da0073e9SAndroid Build Coastguard Worker
8341*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cpu':
8342*da0073e9SAndroid Build Coastguard Worker            for dtype in [torch.half, torch.bfloat16]:
8343*da0073e9SAndroid Build Coastguard Worker                self._test_LayerNorm_cpu_mixed_dtype(device, dtype=dtype)
8344*da0073e9SAndroid Build Coastguard Worker
8345*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
8346*da0073e9SAndroid Build Coastguard Worker    def test_LayerNorm_numeric(self, device):
8347*da0073e9SAndroid Build Coastguard Worker        def layer_norm_ref(X, gamma, beta, normalized_shape, eps):
8348*da0073e9SAndroid Build Coastguard Worker            feature_size = np.prod(normalized_shape)
8349*da0073e9SAndroid Build Coastguard Worker            X_view = X.view(-1, feature_size)
8350*da0073e9SAndroid Build Coastguard Worker            mean = X_view.mean(dim=-1, keepdim=True)
8351*da0073e9SAndroid Build Coastguard Worker            var = X_view.var(dim=-1, unbiased=False, keepdim=True)
8352*da0073e9SAndroid Build Coastguard Worker            Y = (X_view - mean) / torch.sqrt(var + eps)
8353*da0073e9SAndroid Build Coastguard Worker            Y = Y * gamma.view(-1) + beta.view(-1)
8354*da0073e9SAndroid Build Coastguard Worker            return Y.view(*X.size())
8355*da0073e9SAndroid Build Coastguard Worker
8356*da0073e9SAndroid Build Coastguard Worker        normalized_shape = [256, 256, 144]
8357*da0073e9SAndroid Build Coastguard Worker        layer_norm = nn.LayerNorm(normalized_shape).float().to(device)
8358*da0073e9SAndroid Build Coastguard Worker        X = torch.rand(2, *normalized_shape, dtype=torch.float32,
8359*da0073e9SAndroid Build Coastguard Worker                       device=device)
8360*da0073e9SAndroid Build Coastguard Worker
8361*da0073e9SAndroid Build Coastguard Worker        Y = layer_norm(X)
8362*da0073e9SAndroid Build Coastguard Worker        Y_ref = layer_norm_ref(X, layer_norm.weight.data, layer_norm.bias.data,
8363*da0073e9SAndroid Build Coastguard Worker                               normalized_shape, layer_norm.eps)
8364*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Y, Y_ref, rtol=0, atol=1e-5)
8365*da0073e9SAndroid Build Coastguard Worker
8366*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda':
8367*da0073e9SAndroid Build Coastguard Worker            layer_norm.cpu()
8368*da0073e9SAndroid Build Coastguard Worker            Y_cpu = layer_norm(X.cpu())
8369*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(Y_cpu, Y, rtol=0, atol=1e-5)
8370*da0073e9SAndroid Build Coastguard Worker
8371*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
8372*da0073e9SAndroid Build Coastguard Worker    def test_glu_bfloat16(self, device):
8373*da0073e9SAndroid Build Coastguard Worker        def test_dtype(fn, input, dtype):
8374*da0073e9SAndroid Build Coastguard Worker            input = input.detach().clone().to(dtype=dtype).requires_grad_(True)
8375*da0073e9SAndroid Build Coastguard Worker            input2 = input.detach().clone().float().requires_grad_(True)
8376*da0073e9SAndroid Build Coastguard Worker            out = fn(input)
8377*da0073e9SAndroid Build Coastguard Worker            out.sum().backward()
8378*da0073e9SAndroid Build Coastguard Worker            out2 = fn(input2)
8379*da0073e9SAndroid Build Coastguard Worker            out2.sum().backward()
8380*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out.dtype, dtype)
8381*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input.grad.dtype, dtype)
8382*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, out2, exact_dtype=False)
8383*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input.grad, input2.grad, atol=1e-2, rtol=0, exact_dtype=False)
8384*da0073e9SAndroid Build Coastguard Worker
8385*da0073e9SAndroid Build Coastguard Worker        def func(device):
8386*da0073e9SAndroid Build Coastguard Worker            return torch.nn.GLU(dim=-1).to(device)
8387*da0073e9SAndroid Build Coastguard Worker
8388*da0073e9SAndroid Build Coastguard Worker        shapes = [[1, 3, 1, 6], [1, 3, 1, 128], [1, 3, 256, 256]]
8389*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
8390*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(shape, device=device)
8391*da0073e9SAndroid Build Coastguard Worker            test_dtype(func(device), x, torch.bfloat16)
8392*da0073e9SAndroid Build Coastguard Worker
8393*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
8394*da0073e9SAndroid Build Coastguard Worker    def test_GroupNorm_general(self, device):
8395*da0073e9SAndroid Build Coastguard Worker        self._test_GroupNorm_general(device)
8396*da0073e9SAndroid Build Coastguard Worker
8397*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda':
8398*da0073e9SAndroid Build Coastguard Worker            self._test_GroupNorm_cuda_half()
8399*da0073e9SAndroid Build Coastguard Worker
8400*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cpu':
8401*da0073e9SAndroid Build Coastguard Worker            self._test_GroupNorm_cpu_mixed_dtype()
8402*da0073e9SAndroid Build Coastguard Worker
8403*da0073e9SAndroid Build Coastguard Worker    def test_GroupNorm_raises_error_if_one_value_per_group(self, device):
8404*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(10)[None, :, None]
8405*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
8406*da0073e9SAndroid Build Coastguard Worker            torch.nn.GroupNorm(10, 10)(x).to(device)
8407*da0073e9SAndroid Build Coastguard Worker
8408*da0073e9SAndroid Build Coastguard Worker    def test_GroupNorm_empty(self, device):
8409*da0073e9SAndroid Build Coastguard Worker        mod = torch.nn.GroupNorm(2, 4).to(device)
8410*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn(0, 4, 2, 2, device=device)
8411*da0073e9SAndroid Build Coastguard Worker        _test_module_empty_input(self, mod, inp)
8412*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda' and self.has_cudnn():
8413*da0073e9SAndroid Build Coastguard Worker            with torch.backends.cudnn.flags(enabled=False):
8414*da0073e9SAndroid Build Coastguard Worker                _test_module_empty_input(self, mod, inp)
8415*da0073e9SAndroid Build Coastguard Worker
8416*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
8417*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.bfloat16, torch.half)
8418*da0073e9SAndroid Build Coastguard Worker    def test_groupnorm_nhwc(self, device, dtype):
8419*da0073e9SAndroid Build Coastguard Worker        def helper(self, size, groups, memory_format, is_mixed):
8420*da0073e9SAndroid Build Coastguard Worker            channels = size[1]
8421*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(size, dtype=dtype, device=device, requires_grad=True)
8422*da0073e9SAndroid Build Coastguard Worker            input = input.contiguous(memory_format=memory_format)
8423*da0073e9SAndroid Build Coastguard Worker            input.retain_grad()
8424*da0073e9SAndroid Build Coastguard Worker            grad = torch.randn(size, dtype=dtype, device=device)
8425*da0073e9SAndroid Build Coastguard Worker            grad = grad.contiguous(memory_format=memory_format)
8426*da0073e9SAndroid Build Coastguard Worker            if dtype == torch.bfloat16 and is_mixed:
8427*da0073e9SAndroid Build Coastguard Worker                gn = nn.GroupNorm(groups, channels).to(device).to(torch.float)
8428*da0073e9SAndroid Build Coastguard Worker            else:
8429*da0073e9SAndroid Build Coastguard Worker                gn = nn.GroupNorm(groups, channels).to(device).to(dtype)
8430*da0073e9SAndroid Build Coastguard Worker            gn.weight.data.uniform_()
8431*da0073e9SAndroid Build Coastguard Worker            gn.bias.data.uniform_()
8432*da0073e9SAndroid Build Coastguard Worker
8433*da0073e9SAndroid Build Coastguard Worker            ref_input = input.detach().clone().contiguous(memory_format=torch.contiguous_format).requires_grad_(True)
8434*da0073e9SAndroid Build Coastguard Worker            ref_grad = grad.detach().clone().contiguous(memory_format=torch.contiguous_format)
8435*da0073e9SAndroid Build Coastguard Worker            if dtype == torch.bfloat16 and is_mixed:
8436*da0073e9SAndroid Build Coastguard Worker                ref_gn = nn.GroupNorm(groups, channels).to(device).to(torch.float)
8437*da0073e9SAndroid Build Coastguard Worker            else:
8438*da0073e9SAndroid Build Coastguard Worker                ref_gn = nn.GroupNorm(groups, channels).to(device).to(dtype)
8439*da0073e9SAndroid Build Coastguard Worker            ref_gn.load_state_dict(gn.state_dict())
8440*da0073e9SAndroid Build Coastguard Worker            out = gn(input)
8441*da0073e9SAndroid Build Coastguard Worker            out.backward(grad)
8442*da0073e9SAndroid Build Coastguard Worker            ref_out = ref_gn(ref_input)
8443*da0073e9SAndroid Build Coastguard Worker            ref_out.backward(ref_grad)
8444*da0073e9SAndroid Build Coastguard Worker
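            # GroupNorm is expected to propagate the input memory format to its
            # output, while the contiguous reference stays contiguous; the numerics
            # should agree across layouts.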
8445*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(out.is_contiguous(memory_format=memory_format))
8446*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(ref_out.is_contiguous(memory_format=torch.contiguous_format))
8447*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, ref_out)
8448*da0073e9SAndroid Build Coastguard Worker            # parameters in bfloat16/half are not recommended, so compare with looser tolerances
8449*da0073e9SAndroid Build Coastguard Worker            atol = 5e-4
8450*da0073e9SAndroid Build Coastguard Worker            rtol = 8e-3
8451*da0073e9SAndroid Build Coastguard Worker
8452*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(gn.weight.grad, ref_gn.weight.grad, atol=atol, rtol=rtol)
8453*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(gn.bias.grad, ref_gn.bias.grad, atol=atol, rtol=rtol)
8454*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input.grad, ref_input.grad, atol=atol, rtol=rtol)
8455*da0073e9SAndroid Build Coastguard Worker
8456*da0073e9SAndroid Build Coastguard Worker        for is_mixed in [True, False]:
8457*da0073e9SAndroid Build Coastguard Worker            helper(self, (4, 8, 10, 10), 4, torch.channels_last, is_mixed)
8458*da0073e9SAndroid Build Coastguard Worker            helper(self, (2, 30, 9, 9), 3, torch.channels_last, is_mixed)
8459*da0073e9SAndroid Build Coastguard Worker            helper(self, (4, 8, 40, 40), 4, torch.channels_last, is_mixed)
8460*da0073e9SAndroid Build Coastguard Worker            helper(self, (4, 40, 40, 40), 2, torch.channels_last, is_mixed)
8461*da0073e9SAndroid Build Coastguard Worker            helper(self, (2, 30, 50, 50), 3, torch.channels_last, is_mixed)
8462*da0073e9SAndroid Build Coastguard Worker            helper(self, (2, 60, 50, 50), 3, torch.channels_last, is_mixed)
8463*da0073e9SAndroid Build Coastguard Worker            helper(self, (2, 9, 7, 11, 15), 3, torch.channels_last_3d, is_mixed)
8464*da0073e9SAndroid Build Coastguard Worker            helper(self, (2, 9, 7, 200, 15), 3, torch.channels_last_3d, is_mixed)
8465*da0073e9SAndroid Build Coastguard Worker            helper(self, (2, 60, 7, 200, 15), 3, torch.channels_last_3d, is_mixed)
8466*da0073e9SAndroid Build Coastguard Worker
8467*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
8468*da0073e9SAndroid Build Coastguard Worker    def test_GroupNorm_memory_format(self, device):
8469*da0073e9SAndroid Build Coastguard Worker        # Tests for regression reported in https://github.com/pytorch/pytorch/issues/92166
8470*da0073e9SAndroid Build Coastguard Worker
8471*da0073e9SAndroid Build Coastguard Worker        def helper(input_format, grad_format, B=2, C=4, W=4, H=4):
8472*da0073e9SAndroid Build Coastguard Worker            import copy
8473*da0073e9SAndroid Build Coastguard Worker            net_orig = torch.nn.GroupNorm(B, C).to(device=device)
8474*da0073e9SAndroid Build Coastguard Worker            net = copy.deepcopy(net_orig)
8475*da0073e9SAndroid Build Coastguard Worker            x_orig = torch.rand(B, C, W, H, device=device, requires_grad=True)
8476*da0073e9SAndroid Build Coastguard Worker            grad_orig = torch.rand(B, C, W, H, device=device)
8477*da0073e9SAndroid Build Coastguard Worker            x = x_orig.clone().detach().to(memory_format=input_format).requires_grad_(True)
8478*da0073e9SAndroid Build Coastguard Worker            grad = grad_orig.detach().to(memory_format=grad_format)
8479*da0073e9SAndroid Build Coastguard Worker
8480*da0073e9SAndroid Build Coastguard Worker            y = net(x)
8481*da0073e9SAndroid Build Coastguard Worker            y.backward(grad)
8482*da0073e9SAndroid Build Coastguard Worker
8483*da0073e9SAndroid Build Coastguard Worker            y_orig = net_orig(x_orig)
8484*da0073e9SAndroid Build Coastguard Worker            y_orig.backward(grad_orig)
8485*da0073e9SAndroid Build Coastguard Worker
8486*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y, y_orig)
8487*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, x_orig.grad)
8488*da0073e9SAndroid Build Coastguard Worker
8489*da0073e9SAndroid Build Coastguard Worker        for input_format in [torch.contiguous_format, torch.channels_last]:
8490*da0073e9SAndroid Build Coastguard Worker            for grad_format in [torch.contiguous_format, torch.channels_last]:
8491*da0073e9SAndroid Build Coastguard Worker                helper(input_format, grad_format)
8492*da0073e9SAndroid Build Coastguard Worker
8493*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
8494*da0073e9SAndroid Build Coastguard Worker    def test_GroupNorm_numeric(self, device):
8495*da0073e9SAndroid Build Coastguard Worker        def group_norm_ref(X, gamma, beta, groups, channels, eps):
8496*da0073e9SAndroid Build Coastguard Worker            batch_size = X.size()[0]
8497*da0073e9SAndroid Build Coastguard Worker            X_view = X.view(batch_size, groups, -1)
8498*da0073e9SAndroid Build Coastguard Worker            mean = X_view.mean(dim=-1, keepdim=True)
8499*da0073e9SAndroid Build Coastguard Worker            var = X_view.var(dim=-1, unbiased=False, keepdim=True)
8500*da0073e9SAndroid Build Coastguard Worker            Y = ((X_view - mean) / torch.sqrt(var + eps)).view(
8501*da0073e9SAndroid Build Coastguard Worker                batch_size, channels, -1)
8502*da0073e9SAndroid Build Coastguard Worker            Y = Y * gamma.view(channels, 1) + beta.view(channels, 1)
8503*da0073e9SAndroid Build Coastguard Worker            return Y.view(*X.size())
8504*da0073e9SAndroid Build Coastguard Worker
8505*da0073e9SAndroid Build Coastguard Worker        batch_size = 1
8506*da0073e9SAndroid Build Coastguard Worker        groups = 2
8507*da0073e9SAndroid Build Coastguard Worker        channels = 8
8508*da0073e9SAndroid Build Coastguard Worker        group_norm = nn.GroupNorm(groups, channels).float().to(device)
8509*da0073e9SAndroid Build Coastguard Worker        X = torch.rand(batch_size, channels, 256, 256, 72,
8510*da0073e9SAndroid Build Coastguard Worker                       dtype=torch.float32, device=device)
8511*da0073e9SAndroid Build Coastguard Worker
8512*da0073e9SAndroid Build Coastguard Worker        Y = group_norm(X)
8513*da0073e9SAndroid Build Coastguard Worker        Y_ref = group_norm_ref(
8514*da0073e9SAndroid Build Coastguard Worker            X, group_norm.weight.data, group_norm.bias.data, groups,
8515*da0073e9SAndroid Build Coastguard Worker            channels, group_norm.eps)
8516*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Y, Y_ref, rtol=0, atol=1e-5)
8517*da0073e9SAndroid Build Coastguard Worker
8518*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda':
8519*da0073e9SAndroid Build Coastguard Worker            group_norm.cpu()
8520*da0073e9SAndroid Build Coastguard Worker            Y_cpu = group_norm(X.cpu())
8521*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(Y_cpu, Y, rtol=0, atol=1e-5)
8522*da0073e9SAndroid Build Coastguard Worker
8523*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
8524*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float64, torch.complex128)
8525*da0073e9SAndroid Build Coastguard Worker    def test_pad(self, device, dtype):
8526*da0073e9SAndroid Build Coastguard Worker        # Assert that RuntimeErrors are raised for invalid circular padding values
8527*da0073e9SAndroid Build Coastguard Worker        inputs = torch.randn(1, 1, 4, device=device, dtype=dtype, requires_grad=True)
8528*da0073e9SAndroid Build Coastguard Worker        # Should raise error when trying to wrap around more than once
8529*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: F.pad(inputs, (5, 4), mode='circular'))
8530*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: F.pad(inputs, (3, 6), mode='circular'))
8531*da0073e9SAndroid Build Coastguard Worker        # Should raise error when negative padding results in negative output shape
8532*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: F.pad(inputs, (-3, -2), mode='circular'))
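        # For contrast, a minimal valid sketch: padding amounts within the input
        # length of 4 wrap around at most once and should succeed.
        F.pad(inputs, (3, 2), mode='circular')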
8533*da0073e9SAndroid Build Coastguard Worker
8534*da0073e9SAndroid Build Coastguard Worker        # assert that reflection padding errors when pad >= input size
8535*da0073e9SAndroid Build Coastguard Worker        expected_err_msg = r"Padding size should be less than the corresponding input dimension"
8536*da0073e9SAndroid Build Coastguard Worker        inputs = torch.randn(1, 1, 2, 3, device=device, dtype=dtype)
8537*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(RuntimeError, expected_err_msg,
8538*da0073e9SAndroid Build Coastguard Worker                               lambda: F.pad(inputs, (1, 1, 3, 0), mode='reflect'))
8539*da0073e9SAndroid Build Coastguard Worker        inputs = torch.randn(1, 1, 2, device=device, dtype=dtype)
8540*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(RuntimeError, expected_err_msg,
8541*da0073e9SAndroid Build Coastguard Worker                               lambda: F.pad(inputs, (2, 1), mode='reflect'))
8542*da0073e9SAndroid Build Coastguard Worker
8543*da0073e9SAndroid Build Coastguard Worker        inputs = torch.rand(1, 3, 4, 4, device=device, dtype=dtype)
8544*da0073e9SAndroid Build Coastguard Worker        # assert that pad doesn't return a view into the input tensor
8545*da0073e9SAndroid Build Coastguard Worker        for mode in 'constant', 'reflect', 'replicate', 'circular':
8546*da0073e9SAndroid Build Coastguard Worker            out = F.pad(inputs, (0, 0, 0, 0), mode=mode)
8547*da0073e9SAndroid Build Coastguard Worker            out.fill_(4)
8548*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.all(torch.abs(inputs) < 2))
8549*da0073e9SAndroid Build Coastguard Worker
8550*da0073e9SAndroid Build Coastguard Worker            out = F.pad(inputs, (0, 0, -1, -1), mode=mode)
8551*da0073e9SAndroid Build Coastguard Worker            out.fill_(4)
8552*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.all(torch.abs(inputs) < 2))
8553*da0073e9SAndroid Build Coastguard Worker
8554*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
8555*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float64, torch.complex128)
8556*da0073e9SAndroid Build Coastguard Worker    def test_ReplicationPad_empty(self, device, dtype):
8557*da0073e9SAndroid Build Coastguard Worker        for mod, inp in [
8558*da0073e9SAndroid Build Coastguard Worker                (torch.nn.ReplicationPad1d(3), torch.randn(0, 3, 10, device=device, dtype=dtype)),
8559*da0073e9SAndroid Build Coastguard Worker                (torch.nn.ReplicationPad2d(3), torch.randn(0, 3, 10, 10, device=device, dtype=dtype)),
8560*da0073e9SAndroid Build Coastguard Worker                (torch.nn.ReplicationPad3d(3), torch.randn(0, 3, 10, 10, 10, device=device, dtype=dtype))]:
8561*da0073e9SAndroid Build Coastguard Worker            _test_module_empty_input(self, mod, inp, check_size=False)
8562*da0073e9SAndroid Build Coastguard Worker
8563*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'Expected 2D or 3D'):
8564*da0073e9SAndroid Build Coastguard Worker            mod = torch.nn.ReplicationPad1d(2)
8565*da0073e9SAndroid Build Coastguard Worker            inp = torch.randn(3, 0, 10, device=device, dtype=dtype)
8566*da0073e9SAndroid Build Coastguard Worker            mod(inp)
8567*da0073e9SAndroid Build Coastguard Worker
8568*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'Expected 3D or 4D'):
8569*da0073e9SAndroid Build Coastguard Worker            mod = torch.nn.ReplicationPad2d((2, 2, 2, 2))
8570*da0073e9SAndroid Build Coastguard Worker            inp = torch.randn(43, 0, 10, 10, device=device, dtype=dtype)
8571*da0073e9SAndroid Build Coastguard Worker            mod(inp)
8572*da0073e9SAndroid Build Coastguard Worker
8573*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'Expected 4D or 5D'):
8574*da0073e9SAndroid Build Coastguard Worker            mod = torch.nn.ReplicationPad3d((2, 2, 2, 2, 2, 2))
8575*da0073e9SAndroid Build Coastguard Worker            inp = torch.randn(3, 0, 10, 10, 10, device=device, dtype=dtype)
8576*da0073e9SAndroid Build Coastguard Worker            mod(inp)
8577*da0073e9SAndroid Build Coastguard Worker
8578*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'padding size is expected to be 2'):
8579*da0073e9SAndroid Build Coastguard Worker            torch._C._nn.replication_pad1d(torch.randn([2]), padding=[])
8580*da0073e9SAndroid Build Coastguard Worker
8581*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'padding size is expected to be 4'):
8582*da0073e9SAndroid Build Coastguard Worker            torch._C._nn.replication_pad2d(torch.randn([2]), padding=[])
8583*da0073e9SAndroid Build Coastguard Worker
8584*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'padding size is expected to be 6'):
8585*da0073e9SAndroid Build Coastguard Worker            torch._C._nn.replication_pad3d(torch.randn([2]), padding=[])
8586*da0073e9SAndroid Build Coastguard Worker
8587*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # TODO(hvaara): Investigate as possible bug.
8588*da0073e9SAndroid Build Coastguard Worker    def test_ReplicationPad1d_large(self, device):
8589*da0073e9SAndroid Build Coastguard Worker        shapes = ([2, 65736, 4], [65736, 2, 4])
8590*da0073e9SAndroid Build Coastguard Worker        pl, pr = 3, 4
8591*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
8592*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(shape, device=device, requires_grad=True)
8593*da0073e9SAndroid Build Coastguard Worker            model = torch.nn.ReplicationPad1d((pl, pr))
8594*da0073e9SAndroid Build Coastguard Worker
8595*da0073e9SAndroid Build Coastguard Worker            # forward
8596*da0073e9SAndroid Build Coastguard Worker            out = model(x)
8597*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out[:, :, pl : -pr], x)
8598*da0073e9SAndroid Build Coastguard Worker
8599*da0073e9SAndroid Build Coastguard Worker            left_padding = out[:, :, : pl]
8600*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(left_padding, x[:, :, :1].expand_as(left_padding))
8601*da0073e9SAndroid Build Coastguard Worker            right_padding = out[:, :, -pr :]
8602*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(right_padding, x[:, :, -1:].expand_as(right_padding))
8603*da0073e9SAndroid Build Coastguard Worker
8604*da0073e9SAndroid Build Coastguard Worker            # backward
8605*da0073e9SAndroid Build Coastguard Worker            g = torch.randn_like(out)
8606*da0073e9SAndroid Build Coastguard Worker            out.backward(g)
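            # Interior elements map one-to-one onto the padded output, while each edge
            # element also accumulates the gradients of its replicated copies, hence
            # the .sum(-1) reductions over the padded slices below.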
8607*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad[:, :, 1 : -1], g[:, :, pl + 1 : -pr - 1])
8608*da0073e9SAndroid Build Coastguard Worker
8609*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad[:, :, 0], g[:, :, : pl + 1].sum(-1))
8610*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad[:, :, -1], g[:, :, -pr - 1:].sum(-1))
8611*da0073e9SAndroid Build Coastguard Worker
8612*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # TODO(hvaara): Investigate as possible bug.
8613*da0073e9SAndroid Build Coastguard Worker    def test_ReplicationPad2d_large(self, device):
8614*da0073e9SAndroid Build Coastguard Worker        shapes = ([2, 65736, 4, 4], [65736, 2, 4, 4])
8615*da0073e9SAndroid Build Coastguard Worker        pl, pr, pt, pb = 3, 4, 5, 6
8616*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
8617*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(shape, device=device, requires_grad=True)
8618*da0073e9SAndroid Build Coastguard Worker            model = torch.nn.ReplicationPad2d((pl, pr, pt, pb))
8619*da0073e9SAndroid Build Coastguard Worker
8620*da0073e9SAndroid Build Coastguard Worker            # forward center, edge
8621*da0073e9SAndroid Build Coastguard Worker            out = model(x)
8622*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out[:, :, pt : -pb, pl : -pr], x)
8623*da0073e9SAndroid Build Coastguard Worker
8624*da0073e9SAndroid Build Coastguard Worker            left_padding = out[:, :, pt : -pb, : pl]
8625*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(left_padding, x[:, :, :, :1].expand_as(left_padding))
8626*da0073e9SAndroid Build Coastguard Worker            right_padding = out[:, :, pt : -pb, -pr :]
8627*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(right_padding, x[:, :, :, -1:].expand_as(right_padding))
8628*da0073e9SAndroid Build Coastguard Worker            top_padding = out[:, :, : pt, pl : -pr]
8629*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(top_padding, x[:, :, :1, :].expand_as(top_padding))
8630*da0073e9SAndroid Build Coastguard Worker            bottom_padding = out[:, :, -pb : , pl : -pr]
8631*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(bottom_padding, x[:, :, -1:, :].expand_as(bottom_padding))
8632*da0073e9SAndroid Build Coastguard Worker
8633*da0073e9SAndroid Build Coastguard Worker            # forward corner
8634*da0073e9SAndroid Build Coastguard Worker            tl_padding = out[:, :, : pt + 1, : pl + 1]
8635*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tl_padding, x[:, :, :1, :1].expand_as(tl_padding))
8636*da0073e9SAndroid Build Coastguard Worker            tr_padding = out[:, :, : pt + 1, -pr - 1:]
8637*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tr_padding, x[:, :, :1, -1:].expand_as(tr_padding))
8638*da0073e9SAndroid Build Coastguard Worker            bl_padding = out[:, :, -pb - 1:, : pl + 1]
8639*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(bl_padding, x[:, :, -1:, :1].expand_as(bl_padding))
8640*da0073e9SAndroid Build Coastguard Worker            br_padding = out[:, :, -pb - 1:, -pr - 1:]
8641*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(br_padding, x[:, :, -1:, -1:].expand_as(br_padding))
8642*da0073e9SAndroid Build Coastguard Worker
8643*da0073e9SAndroid Build Coastguard Worker            # backward center, edge
8644*da0073e9SAndroid Build Coastguard Worker            g = torch.randn_like(out)
8645*da0073e9SAndroid Build Coastguard Worker            out.backward(g)
8646*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad[:, :, 1:-1, 1:-1], g[:, :, pt + 1 : -pb - 1, pl + 1 : -pr - 1])
8647*da0073e9SAndroid Build Coastguard Worker
8648*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad[:, :, 1:-1, 0], g[:, :, pt + 1 : -pb - 1, : pl + 1].sum(-1))
8649*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad[:, :, 1:-1, -1], g[:, :, pt + 1 : -pb - 1, -pr - 1 :].sum(-1))
8650*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad[:, :, 0, 1:-1], g[:, :, : pt + 1, pl + 1 : -pr - 1].sum(-2))
8651*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad[:, :, -1, 1:-1], g[:, :, -pb - 1 :, pl + 1 : -pr - 1].sum(-2))
8652*da0073e9SAndroid Build Coastguard Worker
8653*da0073e9SAndroid Build Coastguard Worker            # backward corner
8654*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad[:, :, 0, 0], g[:, :, : pt + 1, : pl + 1].sum((-2, -1)))
8655*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad[:, :, 0, -1], g[:, :, : pt + 1, -pr - 1 :].sum((-2, -1)))
8656*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad[:, :, -1, 0], g[:, :, -pb - 1 :, : pl + 1].sum((-2, -1)))
8657*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad[:, :, -1, -1], g[:, :, -pb - 1 :, -pr - 1 :].sum((-2, -1)))
8658*da0073e9SAndroid Build Coastguard Worker
8659*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest("6GB")
8660*da0073e9SAndroid Build Coastguard Worker    def test_ReplicationPad3d_large(self, device):
8661*da0073e9SAndroid Build Coastguard Worker        shapes = ([1, 65736, 2, 2, 2], [65736, 1, 2, 2, 2])
8662*da0073e9SAndroid Build Coastguard Worker        pl, pr, pt, pbt, pf, pbk = 3, 4, 5, 6, 7, 8
8663*da0073e9SAndroid Build Coastguard Worker
8664*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
8665*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(shape, device=device, requires_grad=True)
8666*da0073e9SAndroid Build Coastguard Worker            model = torch.nn.ReplicationPad3d((pl, pr, pt, pbt, pf, pbk))
8667*da0073e9SAndroid Build Coastguard Worker
8668*da0073e9SAndroid Build Coastguard Worker            # forward center
8669*da0073e9SAndroid Build Coastguard Worker            out = model(x)
8670*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out[:, :, pf : -pbk, pt : -pbt, pl : -pr], x)
8671*da0073e9SAndroid Build Coastguard Worker
8672*da0073e9SAndroid Build Coastguard Worker            # backward center
8673*da0073e9SAndroid Build Coastguard Worker            g = torch.randn_like(out)
8674*da0073e9SAndroid Build Coastguard Worker            out.backward(g)
8675*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad[:, :, 1:-1, 1:-1, 1:-1], g[:, :, pf + 1 : -pbk - 1, pt + 1 : -pbt - 1, pl + 1 : -pr - 1])
8676*da0073e9SAndroid Build Coastguard Worker
8677*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
8678*da0073e9SAndroid Build Coastguard Worker    def test_Bilinear_empty(self, device):
8679*da0073e9SAndroid Build Coastguard Worker        mod = torch.nn.Bilinear(20, 30, 40).to(device)
8680*da0073e9SAndroid Build Coastguard Worker        inp1 = torch.randn(0, 10, 20, requires_grad=True, device=device)
8681*da0073e9SAndroid Build Coastguard Worker        inp2 = torch.randn(0, 10, 30, requires_grad=True, device=device)
8682*da0073e9SAndroid Build Coastguard Worker
8683*da0073e9SAndroid Build Coastguard Worker        output = mod(inp1, inp2)
8684*da0073e9SAndroid Build Coastguard Worker        output.sum().backward()
8685*da0073e9SAndroid Build Coastguard Worker
8686*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(inp1, torch.zeros_like(inp1))
8687*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(inp2, torch.zeros_like(inp2))
8688*da0073e9SAndroid Build Coastguard Worker
8689*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(inp1.grad, torch.zeros_like(inp1))
8690*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(inp2.grad, torch.zeros_like(inp2))
8691*da0073e9SAndroid Build Coastguard Worker
8692*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMeta  # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
8693*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
8694*da0073e9SAndroid Build Coastguard Worker    def test_TransformerEncoderLayer_empty(self, device):
8695*da0073e9SAndroid Build Coastguard Worker        for training in (True, False):
8696*da0073e9SAndroid Build Coastguard Worker            for batch_first, input_shape in [(True, (0, 10, 512)),
8697*da0073e9SAndroid Build Coastguard Worker                                             (False, (10, 0, 512))]:
8698*da0073e9SAndroid Build Coastguard Worker                input = torch.rand(*input_shape, device=device, dtype=torch.double)
8699*da0073e9SAndroid Build Coastguard Worker                encoder_layer = nn.TransformerEncoderLayer(
8700*da0073e9SAndroid Build Coastguard Worker                    d_model=512, nhead=8, batch_first=batch_first, dtype=torch.double).to(device)
8701*da0073e9SAndroid Build Coastguard Worker                if not training:
8702*da0073e9SAndroid Build Coastguard Worker                    encoder_layer = encoder_layer.eval()
8703*da0073e9SAndroid Build Coastguard Worker                    with torch.no_grad():
8704*da0073e9SAndroid Build Coastguard Worker                        _test_module_empty_input(self, encoder_layer, input, check_size=False, inference=True)
8705*da0073e9SAndroid Build Coastguard Worker                    if batch_first and not TEST_WITH_CROSSREF:
8706*da0073e9SAndroid Build Coastguard Worker                        with torch.no_grad():
8707*da0073e9SAndroid Build Coastguard Worker                            # A NestedTensor with no tensors inside it doesn't have dim 3 (or dim
8708*da0073e9SAndroid Build Coastguard Worker                            # 2, for that matter) so it can't hit the fast path, nor can we give a
8709*da0073e9SAndroid Build Coastguard Worker                            # result.
8710*da0073e9SAndroid Build Coastguard Worker                            with self.assertRaisesRegex(
8711*da0073e9SAndroid Build Coastguard Worker                                    AssertionError, 'MultiheadAttention does not support NestedTensor outside'):
8712*da0073e9SAndroid Build Coastguard Worker                                nt = torch.nested.nested_tensor([], device=device)
8713*da0073e9SAndroid Build Coastguard Worker                                _test_module_empty_input(self, encoder_layer, nt, check_size=False, inference=True)
8714*da0073e9SAndroid Build Coastguard Worker
8715*da0073e9SAndroid Build Coastguard Worker                            nt = torch.nested.nested_tensor([torch.rand(0, 512, device=device, dtype=torch.double)], device=device)
8716*da0073e9SAndroid Build Coastguard Worker                            _test_module_empty_input(self, encoder_layer, nt, check_size=False, inference=True)
8717*da0073e9SAndroid Build Coastguard Worker                else:
8718*da0073e9SAndroid Build Coastguard Worker                    _test_module_empty_input(self, encoder_layer, input, check_size=False)
8719*da0073e9SAndroid Build Coastguard Worker
8720*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMeta  # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
8721*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
8722*da0073e9SAndroid Build Coastguard Worker    def test_TransformerEncoder_empty(self, device):
8723*da0073e9SAndroid Build Coastguard Worker        for batch_first, input_shape in [(True, (0, 10, 512)),
8724*da0073e9SAndroid Build Coastguard Worker                                         (False, (10, 0, 512))]:
8725*da0073e9SAndroid Build Coastguard Worker            input = torch.rand(*input_shape, device=device, dtype=torch.double)
8726*da0073e9SAndroid Build Coastguard Worker            encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=batch_first, dtype=torch.double).to(device)
8727*da0073e9SAndroid Build Coastguard Worker            transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6).to(device)
8728*da0073e9SAndroid Build Coastguard Worker            _test_module_empty_input(self, transformer_encoder, input, check_size=False)
8729*da0073e9SAndroid Build Coastguard Worker
8730*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMeta  # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
8731*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
8732*da0073e9SAndroid Build Coastguard Worker    def test_TransformerDecoderLayer_empty(self, device):
8733*da0073e9SAndroid Build Coastguard Worker        for batch_first, memory_shape, tgt_shape in [(True, (0, 10, 512), (0, 20, 512)),
8734*da0073e9SAndroid Build Coastguard Worker                                                     (False, (10, 0, 512), (20, 0, 512))]:
8735*da0073e9SAndroid Build Coastguard Worker            memory = torch.rand(*memory_shape, device=device, dtype=torch.double)
8736*da0073e9SAndroid Build Coastguard Worker            tgt = torch.rand(*tgt_shape, requires_grad=True, device=device, dtype=torch.double)
8737*da0073e9SAndroid Build Coastguard Worker            decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=batch_first, dtype=torch.double).to(device)
8738*da0073e9SAndroid Build Coastguard Worker            self._test_module_empty_inputs(decoder_layer, [tgt, memory])
8739*da0073e9SAndroid Build Coastguard Worker
8740*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMeta  # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
8741*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
8742*da0073e9SAndroid Build Coastguard Worker    def test_TransformerDecoder_empty(self, device):
8743*da0073e9SAndroid Build Coastguard Worker        for batch_first, memory_shape, tgt_shape in [(True, (0, 10, 512), (0, 20, 512)),
8744*da0073e9SAndroid Build Coastguard Worker                                                     (False, (10, 0, 512), (20, 0, 512))]:
8745*da0073e9SAndroid Build Coastguard Worker            memory = torch.rand(*memory_shape, device=device, dtype=torch.double)
8746*da0073e9SAndroid Build Coastguard Worker            tgt = torch.rand(*tgt_shape, requires_grad=True, device=device, dtype=torch.double)
8747*da0073e9SAndroid Build Coastguard Worker            decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=batch_first, dtype=torch.double).to(device)
8748*da0073e9SAndroid Build Coastguard Worker            transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6).to(device)
8749*da0073e9SAndroid Build Coastguard Worker            self._test_module_empty_inputs(transformer_decoder, [tgt, memory])
8750*da0073e9SAndroid Build Coastguard Worker
8751*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMeta  # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1]
8752*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
8753*da0073e9SAndroid Build Coastguard Worker    def test_Transformer_empty(self, device):
8754*da0073e9SAndroid Build Coastguard Worker        for batch_first, src_shape, tgt_shape in [(True, (10, 0, 512), (20, 0, 512))]:
8755*da0073e9SAndroid Build Coastguard Worker            transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12, dtype=torch.double).to(device)
8756*da0073e9SAndroid Build Coastguard Worker            src = torch.rand(*src_shape, requires_grad=True, device=device, dtype=torch.double)
8757*da0073e9SAndroid Build Coastguard Worker            tgt = torch.rand(*tgt_shape, requires_grad=True, device=device, dtype=torch.double)
8758*da0073e9SAndroid Build Coastguard Worker            self._test_module_empty_inputs(transformer_model, [src, tgt])
8759*da0073e9SAndroid Build Coastguard Worker
8760*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
8761*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32, torch.complex64)
8762*da0073e9SAndroid Build Coastguard Worker    def test_ReflectionPad_empty(self, device, dtype):
8763*da0073e9SAndroid Build Coastguard Worker        for mod, inp in [
8764*da0073e9SAndroid Build Coastguard Worker                (torch.nn.ReflectionPad1d(2), torch.randn(0, 3, 10, device=device, dtype=dtype)),
8765*da0073e9SAndroid Build Coastguard Worker                (torch.nn.ReflectionPad2d(2), torch.randn(0, 3, 10, 10, device=device, dtype=dtype)),
8766*da0073e9SAndroid Build Coastguard Worker                (torch.nn.ReflectionPad3d(3), torch.randn(0, 3, 10, 10, 10, device=device, dtype=dtype))]:
8767*da0073e9SAndroid Build Coastguard Worker            _test_module_empty_input(self, mod, inp, check_size=False)
8768*da0073e9SAndroid Build Coastguard Worker
8769*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, '2D or 3D'):
8770*da0073e9SAndroid Build Coastguard Worker            mod = torch.nn.ReflectionPad1d(2)
8771*da0073e9SAndroid Build Coastguard Worker            inp = torch.randn(3, 0, 10, device=device, dtype=dtype)
8772*da0073e9SAndroid Build Coastguard Worker            mod(inp)
8773*da0073e9SAndroid Build Coastguard Worker
8774*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, '3D or 4D'):
8775*da0073e9SAndroid Build Coastguard Worker            mod = torch.nn.ReflectionPad2d(2)
8776*da0073e9SAndroid Build Coastguard Worker            inp = torch.randn(3, 0, 10, 10, device=device, dtype=dtype)
8777*da0073e9SAndroid Build Coastguard Worker            mod(inp)
8778*da0073e9SAndroid Build Coastguard Worker
8779*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, '4D or 5D'):
8780*da0073e9SAndroid Build Coastguard Worker            mod = torch.nn.ReflectionPad3d(3)
8781*da0073e9SAndroid Build Coastguard Worker            inp = torch.randn(3, 0, 10, 10, 10, device=device, dtype=dtype)
8782*da0073e9SAndroid Build Coastguard Worker            mod(inp)
8783*da0073e9SAndroid Build Coastguard Worker
8784*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA   # Test that CPU and GPU results match
8785*da0073e9SAndroid Build Coastguard Worker    def test_ReflectionPad2d_large(self, device):
8786*da0073e9SAndroid Build Coastguard Worker        shapes = ([2, 65736, 6, 6], [65736, 2, 6, 6])
8787*da0073e9SAndroid Build Coastguard Worker        pad = (1, 2, 3, 4)
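        # 65736 exceeds 2**16, so each shape has more than 65535 batch*channel
        # planes; presumably this stresses the CUDA launch/indexing path, and the
        # result is checked against the CPU reference computed below.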
8788*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
8789*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(shape, device=device, requires_grad=True)
8790*da0073e9SAndroid Build Coastguard Worker            ref_x = x.detach().cpu().requires_grad_()
8791*da0073e9SAndroid Build Coastguard Worker
8792*da0073e9SAndroid Build Coastguard Worker            out = F.pad(x, pad, mode='reflect')
8793*da0073e9SAndroid Build Coastguard Worker            ref_out = F.pad(ref_x, pad, mode='reflect')
8794*da0073e9SAndroid Build Coastguard Worker
8795*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, ref_out)
8796*da0073e9SAndroid Build Coastguard Worker
8797*da0073e9SAndroid Build Coastguard Worker            g = torch.randn_like(out)
8798*da0073e9SAndroid Build Coastguard Worker            ref_g = g.cpu()
8799*da0073e9SAndroid Build Coastguard Worker
8800*da0073e9SAndroid Build Coastguard Worker            out.backward(g)
8801*da0073e9SAndroid Build Coastguard Worker            ref_out.backward(ref_g)
8802*da0073e9SAndroid Build Coastguard Worker
8803*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, ref_x.grad)
8804*da0073e9SAndroid Build Coastguard Worker
8805*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
8806*da0073e9SAndroid Build Coastguard Worker    def test_LocalResponseNorm_empty(self, device):
8807*da0073e9SAndroid Build Coastguard Worker        mod = torch.nn.LocalResponseNorm(2).to(device)
8808*da0073e9SAndroid Build Coastguard Worker        inp = torch.ones(0, 5, 24, 24, device=device)
8809*da0073e9SAndroid Build Coastguard Worker        _test_module_empty_input(self, mod, inp, check_size=False)
8810*da0073e9SAndroid Build Coastguard Worker
8811*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA   # Test that CPU and GPU results match
8812*da0073e9SAndroid Build Coastguard Worker    def test_ReflectionPad3d_large(self, device):
8813*da0073e9SAndroid Build Coastguard Worker        shapes = ([2, 1000, 7, 7, 7], [1000, 2, 7, 7, 7])
8814*da0073e9SAndroid Build Coastguard Worker        pad = (1, 2, 3, 4, 5, 6)
8815*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
8816*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(shape, device=device, requires_grad=True)
8817*da0073e9SAndroid Build Coastguard Worker            ref_x = x.detach().cpu().requires_grad_()
8818*da0073e9SAndroid Build Coastguard Worker
8819*da0073e9SAndroid Build Coastguard Worker            out = F.pad(x, pad, mode='reflect')
8820*da0073e9SAndroid Build Coastguard Worker            ref_out = F.pad(ref_x, pad, mode='reflect')
8821*da0073e9SAndroid Build Coastguard Worker
8822*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, ref_out)
8823*da0073e9SAndroid Build Coastguard Worker
8824*da0073e9SAndroid Build Coastguard Worker            g = torch.randn_like(out)
8825*da0073e9SAndroid Build Coastguard Worker            ref_g = g.cpu()
8826*da0073e9SAndroid Build Coastguard Worker
8827*da0073e9SAndroid Build Coastguard Worker            out.backward(g)
8828*da0073e9SAndroid Build Coastguard Worker            ref_out.backward(ref_g)
8829*da0073e9SAndroid Build Coastguard Worker
8830*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, ref_x.grad)
8831*da0073e9SAndroid Build Coastguard Worker
8832*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
8833*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
8834*da0073e9SAndroid Build Coastguard Worker    def test_MarginLoss_empty(self, device, dtype):
8835*da0073e9SAndroid Build Coastguard Worker        for mod, x, y in [
8836*da0073e9SAndroid Build Coastguard Worker                (torch.nn.MultiMarginLoss().to(device),
8837*da0073e9SAndroid Build Coastguard Worker                 torch.randn(0, 10, requires_grad=True, device=device, dtype=dtype),
8838*da0073e9SAndroid Build Coastguard Worker                 torch.ones(0, device=device).type(torch.long)),
8839*da0073e9SAndroid Build Coastguard Worker                (torch.nn.MultiLabelMarginLoss().to(device),
8840*da0073e9SAndroid Build Coastguard Worker                 torch.randn(0, 10, requires_grad=True, device=device, dtype=dtype),
8841*da0073e9SAndroid Build Coastguard Worker                 torch.ones(0, 10, device=device).type(torch.long))]:
8842*da0073e9SAndroid Build Coastguard Worker
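            # With a zero-sized batch there is nothing to reduce over, so backward
            # should leave an all-zero gradient on x, which the checks below verify.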
8843*da0073e9SAndroid Build Coastguard Worker            out = mod(x, y)
8844*da0073e9SAndroid Build Coastguard Worker            out.sum().backward()
8845*da0073e9SAndroid Build Coastguard Worker
8846*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x, torch.zeros_like(x))
8847*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, torch.zeros_like(x))
8848*da0073e9SAndroid Build Coastguard Worker
8849*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, 'Expected'):
8850*da0073e9SAndroid Build Coastguard Worker                x = torch.randn(0, requires_grad=True, device=device, dtype=dtype)
8851*da0073e9SAndroid Build Coastguard Worker                y = torch.ones(10, device=device).type(torch.long)
8852*da0073e9SAndroid Build Coastguard Worker                mod(x, y)
8853*da0073e9SAndroid Build Coastguard Worker
8854*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, 'Expected'):
8855*da0073e9SAndroid Build Coastguard Worker                x = torch.randn(10, 0, requires_grad=True, device=device, dtype=dtype)
8856*da0073e9SAndroid Build Coastguard Worker                y = torch.ones(10, 0, device=device).type(torch.long)
8857*da0073e9SAndroid Build Coastguard Worker                mod(x, y)
8858*da0073e9SAndroid Build Coastguard Worker
8859*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
8860*da0073e9SAndroid Build Coastguard Worker    def test_MarginLoss_warnings(self, device):
8861*da0073e9SAndroid Build Coastguard Worker        model = torch.nn.Linear(128, 22, device=device)
8862*da0073e9SAndroid Build Coastguard Worker        loss = torch.nn.MultiMarginLoss()
8863*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((56, 128), device=device)
8864*da0073e9SAndroid Build Coastguard Worker        targets = torch.randint(22, (56,), device=device)
8865*da0073e9SAndroid Build Coastguard Worker        f = io.StringIO()
8866*da0073e9SAndroid Build Coastguard Worker        with contextlib.redirect_stderr(f):
8867*da0073e9SAndroid Build Coastguard Worker            out = model(x)
8868*da0073e9SAndroid Build Coastguard Worker            loss_value = loss(out, targets)
8869*da0073e9SAndroid Build Coastguard Worker            loss_value.backward()
8870*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(f.getvalue()) == 0)
8871*da0073e9SAndroid Build Coastguard Worker
8872*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
8873*da0073e9SAndroid Build Coastguard Worker    def test_Unfold_empty(self, device):
8874*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn(0, 3, 3, 4, device=device)
8875*da0073e9SAndroid Build Coastguard Worker        unfold = torch.nn.Unfold(kernel_size=(2, 3)).to(device)
8876*da0073e9SAndroid Build Coastguard Worker        _test_module_empty_input(self, unfold, inp, check_size=False)
8877*da0073e9SAndroid Build Coastguard Worker
8878*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'Expected 3D or 4D'):
8879*da0073e9SAndroid Build Coastguard Worker            inp = torch.randn(3, 0, 3, 4, device=device)
8880*da0073e9SAndroid Build Coastguard Worker            unfold = torch.nn.Unfold(kernel_size=(2, 3)).to(device)
8881*da0073e9SAndroid Build Coastguard Worker            unfold(inp)
8882*da0073e9SAndroid Build Coastguard Worker
8883*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
8884*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
8885*da0073e9SAndroid Build Coastguard Worker    @tf32_on_and_off(0.005)
8886*da0073e9SAndroid Build Coastguard Worker    def test_rnn_fused(self, device, dtype):
8887*da0073e9SAndroid Build Coastguard Worker
8888*da0073e9SAndroid Build Coastguard Worker        def copy_rnn(rnn1, rnn2):
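            # Copy every weight/bias tensor of rnn2 into rnn1 so both RNNs start
            # from identical parameters before their outputs are compared.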
8889*da0073e9SAndroid Build Coastguard Worker            for x_layer, y_layer in zip(rnn1.all_weights, rnn2.all_weights):
8890*da0073e9SAndroid Build Coastguard Worker                for x, y in zip(x_layer, y_layer):
8891*da0073e9SAndroid Build Coastguard Worker                    x.data.copy_(y.data)
8892*da0073e9SAndroid Build Coastguard Worker
8893*da0073e9SAndroid Build Coastguard Worker        def check_rnn_grads(rnn1, rnn2):
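            # After backward, corresponding parameter gradients of the two RNNs
            # should agree to within a small absolute tolerance.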
8894*da0073e9SAndroid Build Coastguard Worker            for x_layer, y_layer in zip(rnn1.all_weights, rnn2.all_weights):
8895*da0073e9SAndroid Build Coastguard Worker                for x, y in zip(x_layer, y_layer):
8896*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(x.grad, y.grad, atol=5e-5, rtol=0)
8897*da0073e9SAndroid Build Coastguard Worker
8898*da0073e9SAndroid Build Coastguard Worker        input_size = 10
8899*da0073e9SAndroid Build Coastguard Worker        hidden_size = 6
8900*da0073e9SAndroid Build Coastguard Worker        num_layers = 2
8901*da0073e9SAndroid Build Coastguard Worker        seq_length = 7
8902*da0073e9SAndroid Build Coastguard Worker        batch = 6
8903*da0073e9SAndroid Build Coastguard Worker        input_val = torch.randn(seq_length, batch, input_size, dtype=dtype)
8904*da0073e9SAndroid Build Coastguard Worker        grad_output = torch.randn(seq_length, batch, hidden_size, dtype=dtype)
8905*da0073e9SAndroid Build Coastguard Worker        hx_val = torch.randn(num_layers, batch, hidden_size, dtype=dtype)
8906*da0073e9SAndroid Build Coastguard Worker        grad_hy = torch.randn(num_layers, batch, hidden_size, dtype=dtype)
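        # cuDNN is disabled below, so on CUDA the device-side RNN presumably runs
        # the native fused kernels, which are then compared against the CPU path.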
8907*da0073e9SAndroid Build Coastguard Worker        with torch.backends.cudnn.flags(enabled=False, allow_tf32=None):
8908*da0073e9SAndroid Build Coastguard Worker            for module in (nn.GRU, nn.LSTM):
8909*da0073e9SAndroid Build Coastguard Worker                for bias in (True, False):
8910*da0073e9SAndroid Build Coastguard Worker                    rnn = module(input_size, hidden_size, num_layers, bias=bias).to(dtype)
8911*da0073e9SAndroid Build Coastguard Worker                    rnn_device = module(input_size, hidden_size, num_layers, bias=bias).to(device, dtype)
8912*da0073e9SAndroid Build Coastguard Worker                    copy_rnn(rnn, rnn_device)
8913*da0073e9SAndroid Build Coastguard Worker
8914*da0073e9SAndroid Build Coastguard Worker                    is_lstm = isinstance(rnn, nn.LSTM)
8915*da0073e9SAndroid Build Coastguard Worker                    if is_lstm:
8916*da0073e9SAndroid Build Coastguard Worker                        hx = (hx_val.clone().requires_grad_(True),
8917*da0073e9SAndroid Build Coastguard Worker                              hx_val.clone().add(1).requires_grad_(True))
8918*da0073e9SAndroid Build Coastguard Worker                        hx_device = (hx_val.clone().to(device).requires_grad_(True),
8919*da0073e9SAndroid Build Coastguard Worker                                     hx_val.clone().to(device).add(1).requires_grad_(True))
8920*da0073e9SAndroid Build Coastguard Worker                    else:
8921*da0073e9SAndroid Build Coastguard Worker                        hx = hx_val.clone().requires_grad_(True)
8922*da0073e9SAndroid Build Coastguard Worker                        hx_device = hx_val.clone().to(device).requires_grad_(True)
8923*da0073e9SAndroid Build Coastguard Worker
8924*da0073e9SAndroid Build Coastguard Worker                    inp = input_val.clone().requires_grad_(True)
8925*da0073e9SAndroid Build Coastguard Worker                    inp_cu = input_val.clone().to(device).requires_grad_(True)
8926*da0073e9SAndroid Build Coastguard Worker                    output1, hy1 = rnn(inp, hx)
8927*da0073e9SAndroid Build Coastguard Worker                    output2, hy2 = rnn_device(inp_cu, hx_device)
8928*da0073e9SAndroid Build Coastguard Worker                    if is_lstm:
8929*da0073e9SAndroid Build Coastguard Worker                        torch.autograd.backward(
8930*da0073e9SAndroid Build Coastguard Worker                            [output1, hy1[0], hy1[1]], [grad_output, grad_hy, grad_hy + 1]
8931*da0073e9SAndroid Build Coastguard Worker                        )
8932*da0073e9SAndroid Build Coastguard Worker                        torch.autograd.backward(
8933*da0073e9SAndroid Build Coastguard Worker                            [output2, hy2[0], hy2[1]],
8934*da0073e9SAndroid Build Coastguard Worker                            [grad_output.to(device), grad_hy.to(device), (grad_hy + 1).to(device)]
8935*da0073e9SAndroid Build Coastguard Worker                        )
8936*da0073e9SAndroid Build Coastguard Worker                    else:
8937*da0073e9SAndroid Build Coastguard Worker                        torch.autograd.backward([output1, hy1], [grad_output, grad_hy])
8938*da0073e9SAndroid Build Coastguard Worker                        torch.autograd.backward([output2, hy2], [grad_output.to(device), grad_hy.to(device)])
8939*da0073e9SAndroid Build Coastguard Worker
8940*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(output1, output2)
8941*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(hy1, hy2)
8942*da0073e9SAndroid Build Coastguard Worker
8943*da0073e9SAndroid Build Coastguard Worker                    check_rnn_grads(rnn, rnn_device)
8944*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(inp.grad, inp_cu.grad)
8945*da0073e9SAndroid Build Coastguard Worker                    if is_lstm:
8946*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(hx[0].grad, hx_device[0].grad)
8947*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(hx[1].grad, hx_device[1].grad)
8948*da0073e9SAndroid Build Coastguard Worker                    else:
8949*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(hx.grad, hx_device.grad)
8950*da0073e9SAndroid Build Coastguard Worker
8951*da0073e9SAndroid Build Coastguard Worker    @dtypesIfMPS(torch.float)
8952*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
8953*da0073e9SAndroid Build Coastguard Worker    def test_BatchNorm_empty(self, device, dtype):
8954*da0073e9SAndroid Build Coastguard Worker        mod = torch.nn.BatchNorm2d(3).to(device)
8955*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn(0, 3, 2, 2, device=device, dtype=dtype)
8956*da0073e9SAndroid Build Coastguard Worker        _test_module_empty_input(self, mod, inp)
8957*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda' and self.has_cudnn():
8958*da0073e9SAndroid Build Coastguard Worker            with torch.backends.cudnn.flags(enabled=False):
8959*da0073e9SAndroid Build Coastguard Worker                _test_module_empty_input(self, mod, inp)
8960*da0073e9SAndroid Build Coastguard Worker
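        # An empty batch contributes no statistics, so the running estimates should
        # keep their initial values (mean 0, var 1) and the affine parameters should
        # receive zero gradients, as asserted below.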
8961*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(mod.running_mean, torch.tensor([0., 0, 0], device=device))
8962*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(mod.running_var, torch.tensor([1., 1, 1], device=device))
8963*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(mod.weight.grad, torch.tensor([0., 0, 0], device=device))
8964*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(mod.bias.grad, torch.tensor([0., 0, 0], device=device))
8965*da0073e9SAndroid Build Coastguard Worker
8966*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
8967*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest('16GB')
8968*da0073e9SAndroid Build Coastguard Worker    def test_prelu_backward_32bit_indexing(self, device):
8969*da0073e9SAndroid Build Coastguard Worker        m = torch.nn.PReLU().cuda().half()
8970*da0073e9SAndroid Build Coastguard Worker        input_ = torch.ones((1024, 1024, 1024, 2), dtype=torch.half, device=device)
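        # 1024 * 1024 * 1024 * 2 = 2**31 elements, one more than the int32 maximum,
        # so the backward kernel must use 64-bit indexing.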
8971*da0073e9SAndroid Build Coastguard Worker        output = m(input_)
8972*da0073e9SAndroid Build Coastguard Worker        output.backward(input_)
8973*da0073e9SAndroid Build Coastguard Worker
8974*da0073e9SAndroid Build Coastguard Worker    def test_linear_empty(self, device):
8975*da0073e9SAndroid Build Coastguard Worker        mod = torch.nn.Linear(7, 7).to(device)
8976*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn(0, 7, device=device)
8977*da0073e9SAndroid Build Coastguard Worker        _test_module_empty_input(self, mod, inp)
8978*da0073e9SAndroid Build Coastguard Worker
8979*da0073e9SAndroid Build Coastguard Worker    def test_one_hot(self, device):
8980*da0073e9SAndroid Build Coastguard Worker        # cuda throws device assert for invalid data
8981*da0073e9SAndroid Build Coastguard Worker        # CUDA triggers a device-side assert for invalid data
8982*da0073e9SAndroid Build Coastguard Worker        # XLA ignores out-of-bound indices
8983*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(RuntimeError):
8984*da0073e9SAndroid Build Coastguard Worker                torch.nn.functional.one_hot(torch.tensor([3, 4, -1, 0], device=device), -1)
8985*da0073e9SAndroid Build Coastguard Worker
8986*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(RuntimeError):
8987*da0073e9SAndroid Build Coastguard Worker                torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 3)
8988*da0073e9SAndroid Build Coastguard Worker
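        # With num_classes omitted or set to -1, one_hot infers the class count as
        # max(input) + 1, i.e. 5 for the inputs below.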
8989*da0073e9SAndroid Build Coastguard Worker        t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device))
8990*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor([[0, 0, 0, 1, 0],
8991*da0073e9SAndroid Build Coastguard Worker                                 [0, 0, 0, 0, 1],
8992*da0073e9SAndroid Build Coastguard Worker                                 [0, 1, 0, 0, 0],
8993*da0073e9SAndroid Build Coastguard Worker                                 [1, 0, 0, 0, 0]], device=device)
8994*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t, expected)
8995*da0073e9SAndroid Build Coastguard Worker
8996*da0073e9SAndroid Build Coastguard Worker        t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -1)
8997*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor([[0, 0, 0, 1, 0],
8998*da0073e9SAndroid Build Coastguard Worker                                 [0, 0, 0, 0, 1],
8999*da0073e9SAndroid Build Coastguard Worker                                 [0, 1, 0, 0, 0],
9000*da0073e9SAndroid Build Coastguard Worker                                 [1, 0, 0, 0, 0]], device=device)
9001*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t, expected)
9002*da0073e9SAndroid Build Coastguard Worker
9003*da0073e9SAndroid Build Coastguard Worker        t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 6)
9004*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor([[0, 0, 0, 1, 0, 0],
9005*da0073e9SAndroid Build Coastguard Worker                                 [0, 0, 0, 0, 1, 0],
9006*da0073e9SAndroid Build Coastguard Worker                                 [0, 1, 0, 0, 0, 0],
9007*da0073e9SAndroid Build Coastguard Worker                                 [1, 0, 0, 0, 0, 0]], device=device)
9008*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t, expected)
9009*da0073e9SAndroid Build Coastguard Worker
9010*da0073e9SAndroid Build Coastguard Worker        t = torch.nn.functional.one_hot(torch.tensor([[3, 4], [1, 0]], device=device))
9011*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor([[[0, 0, 0, 1, 0],
9012*da0073e9SAndroid Build Coastguard Worker                                  [0, 0, 0, 0, 1]],
9013*da0073e9SAndroid Build Coastguard Worker                                 [[0, 1, 0, 0, 0],
9014*da0073e9SAndroid Build Coastguard Worker                                  [1, 0, 0, 0, 0]]], device=device)
9015*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t, expected)
9016*da0073e9SAndroid Build Coastguard Worker
9017*da0073e9SAndroid Build Coastguard Worker        t = torch.nn.functional.one_hot(torch.tensor(4, device=device))
9018*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor([0, 0, 0, 0, 1], device=device)
9019*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t, expected)
9020*da0073e9SAndroid Build Coastguard Worker
9021*da0073e9SAndroid Build Coastguard Worker        t = torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device), 100)
9022*da0073e9SAndroid Build Coastguard Worker        expected = torch.empty([4, 0, 100], dtype=torch.long)
9023*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t, expected)
9024*da0073e9SAndroid Build Coastguard Worker
9025*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
9026*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device))
9027*da0073e9SAndroid Build Coastguard Worker
9028*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
9029*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -2)
9030*da0073e9SAndroid Build Coastguard Worker
9031*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # NotImplementedError: aten::rrelu_with_noise https://github.com/pytorch/pytorch/issues/77764
9032*da0073e9SAndroid Build Coastguard Worker    def test_nn_empty(self, device):
9033*da0073e9SAndroid Build Coastguard Worker        # One-off tests to ensure scalars from nn.yaml are properly applied
9034*da0073e9SAndroid Build Coastguard Worker        def verify_scalars(input, output):
9035*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input.shape, output.shape)
9036*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(0, output.numel())
9037*da0073e9SAndroid Build Coastguard Worker
9038*da0073e9SAndroid Build Coastguard Worker        for input_shape in [(0,), (0, 2)]:
9039*da0073e9SAndroid Build Coastguard Worker            for module in [torch.nn.ELU, torch.nn.Hardtanh, torch.nn.LeakyReLU, torch.nn.LogSigmoid,
9040*da0073e9SAndroid Build Coastguard Worker                           torch.nn.RReLU, torch.nn.Softshrink, torch.nn.Softplus, torch.nn.Sigmoid,
9041*da0073e9SAndroid Build Coastguard Worker                           torch.nn.Tanh]:
9042*da0073e9SAndroid Build Coastguard Worker                input = torch.randn(input_shape, device=device, requires_grad=True)
9043*da0073e9SAndroid Build Coastguard Worker                m = module()
9044*da0073e9SAndroid Build Coastguard Worker                output = m(input)
9045*da0073e9SAndroid Build Coastguard Worker                verify_scalars(input, output)
9046*da0073e9SAndroid Build Coastguard Worker
9047*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # NotImplementedError: aten::rrelu_with_noise https://github.com/pytorch/pytorch/issues/77764
9048*da0073e9SAndroid Build Coastguard Worker    def test_nn_scalars(self, device):
9049*da0073e9SAndroid Build Coastguard Worker        # One-off tests to ensure scalars from nn.yaml are properly applied
9050*da0073e9SAndroid Build Coastguard Worker        def verify_scalars(input, output):
9051*da0073e9SAndroid Build Coastguard Worker            if input.dim() == 0:
9052*da0073e9SAndroid Build Coastguard Worker                self.assertEqual((), output.shape)
9053*da0073e9SAndroid Build Coastguard Worker            else:
9054*da0073e9SAndroid Build Coastguard Worker                self.assertNotEqual((), output.shape)
9055*da0073e9SAndroid Build Coastguard Worker            output.sum().backward()
9056*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input.shape, input.grad.shape)
9057*da0073e9SAndroid Build Coastguard Worker
9058*da0073e9SAndroid Build Coastguard Worker        for input_shape in [(5, 6), ()]:
9059*da0073e9SAndroid Build Coastguard Worker            for module in [torch.nn.ELU, torch.nn.Hardtanh, torch.nn.LeakyReLU, torch.nn.LogSigmoid,
9060*da0073e9SAndroid Build Coastguard Worker                           torch.nn.RReLU, torch.nn.Softshrink, torch.nn.Softplus, torch.nn.Sigmoid,
9061*da0073e9SAndroid Build Coastguard Worker                           torch.nn.Tanh]:
9062*da0073e9SAndroid Build Coastguard Worker                input = torch.randn(input_shape, device=device, requires_grad=True)
9063*da0073e9SAndroid Build Coastguard Worker                m = module()
9064*da0073e9SAndroid Build Coastguard Worker                output = m(input)
9065*da0073e9SAndroid Build Coastguard Worker                verify_scalars(input, output)
9066*da0073e9SAndroid Build Coastguard Worker
9067*da0073e9SAndroid Build Coastguard Worker    def test_nn_scalars_reductions(self, device):
9068*da0073e9SAndroid Build Coastguard Worker        # One-off tests to ensure scalars from nn.yaml are properly applied
9069*da0073e9SAndroid Build Coastguard Worker        def verify_reduction_scalars(input, reduction, output):
9070*da0073e9SAndroid Build Coastguard Worker            if reduction != 'none' or input.dim() == 0:
9071*da0073e9SAndroid Build Coastguard Worker                self.assertEqual((), output.shape)
9072*da0073e9SAndroid Build Coastguard Worker            else:
9073*da0073e9SAndroid Build Coastguard Worker                self.assertNotEqual((), output.shape)
9074*da0073e9SAndroid Build Coastguard Worker            output.sum().backward()
9075*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input.shape, input.grad.shape)
9076*da0073e9SAndroid Build Coastguard Worker
9077*da0073e9SAndroid Build Coastguard Worker        for input_shape in [(5, 6), ()]:
9078*da0073e9SAndroid Build Coastguard Worker            for reduction in ['none', 'mean', 'sum']:
9079*da0073e9SAndroid Build Coastguard Worker                for module in [torch.nn.BCELoss, torch.nn.L1Loss, torch.nn.MSELoss,
9080*da0073e9SAndroid Build Coastguard Worker                               torch.nn.SmoothL1Loss, torch.nn.SoftMarginLoss]:
9081*da0073e9SAndroid Build Coastguard Worker                    input = torch.randn(input_shape, device=device, requires_grad=True)
9082*da0073e9SAndroid Build Coastguard Worker                    target = torch.empty(input_shape, device=device).random_(2)
9083*da0073e9SAndroid Build Coastguard Worker                    sigmoid = nn.Sigmoid()
9084*da0073e9SAndroid Build Coastguard Worker
9085*da0073e9SAndroid Build Coastguard Worker                    input = torch.randn(input_shape, device=device, requires_grad=True)
9086*da0073e9SAndroid Build Coastguard Worker                    m = module(reduction=reduction)
9087*da0073e9SAndroid Build Coastguard Worker                    output = m(sigmoid(input), target)
9088*da0073e9SAndroid Build Coastguard Worker                    verify_reduction_scalars(input, reduction, output)
9089*da0073e9SAndroid Build Coastguard Worker
9090*da0073e9SAndroid Build Coastguard Worker    # verify that bogus reduction strings are errors
9091*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
9092*da0073e9SAndroid Build Coastguard Worker    def test_invalid_reduction_strings(self, device):
9093*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3, 5, requires_grad=True, device=device)
9094*da0073e9SAndroid Build Coastguard Worker        cinput = torch.randn(3, 5, requires_grad=True, device=device, dtype=torch.cfloat)
9095*da0073e9SAndroid Build Coastguard Worker        target = torch.tensor([1, 0, 4], device=device)
9096*da0073e9SAndroid Build Coastguard Worker        var = torch.ones(size=input.size(), requires_grad=True, device=device)
9097*da0073e9SAndroid Build Coastguard Worker
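        # 'none' is a valid reduction and must run cleanly; 'invalid' must make each
        # functional below raise a ValueError.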
9098*da0073e9SAndroid Build Coastguard Worker        for reduction in ['none', 'invalid']:
9099*da0073e9SAndroid Build Coastguard Worker            def v(fn):
9100*da0073e9SAndroid Build Coastguard Worker                if reduction == 'invalid':
9101*da0073e9SAndroid Build Coastguard Worker                    self.assertRaises(ValueError, lambda: fn())
9102*da0073e9SAndroid Build Coastguard Worker                else:
9103*da0073e9SAndroid Build Coastguard Worker                    fn()
9104*da0073e9SAndroid Build Coastguard Worker
9105*da0073e9SAndroid Build Coastguard Worker            v(lambda: F.nll_loss(input, target, reduction=reduction))
9106*da0073e9SAndroid Build Coastguard Worker            v(lambda: F.cross_entropy(input, target, reduction=reduction))
9107*da0073e9SAndroid Build Coastguard Worker
9108*da0073e9SAndroid Build Coastguard Worker            v(lambda: F.kl_div(input, input, reduction=reduction))
9109*da0073e9SAndroid Build Coastguard Worker            v(lambda: F.huber_loss(input, input, reduction=reduction))
9110*da0073e9SAndroid Build Coastguard Worker            v(lambda: F.smooth_l1_loss(input, input, reduction=reduction))
9111*da0073e9SAndroid Build Coastguard Worker            v(lambda: F.l1_loss(input, input, reduction=reduction))
9112*da0073e9SAndroid Build Coastguard Worker            v(lambda: F.l1_loss(cinput, cinput, reduction=reduction))
9113*da0073e9SAndroid Build Coastguard Worker            v(lambda: F.mse_loss(input, input, reduction=reduction))
9114*da0073e9SAndroid Build Coastguard Worker            v(lambda: F.hinge_embedding_loss(input, input, reduction=reduction))
9115*da0073e9SAndroid Build Coastguard Worker            v(lambda: F.poisson_nll_loss(input, input, reduction=reduction))
9116*da0073e9SAndroid Build Coastguard Worker            v(lambda: F.gaussian_nll_loss(input, input, var, reduction=reduction))
9117*da0073e9SAndroid Build Coastguard Worker            v(lambda: F.binary_cross_entropy(torch.sigmoid(input), input.gt(0).to(torch.get_default_dtype()), reduction=reduction))
9118*da0073e9SAndroid Build Coastguard Worker            v(lambda: F.binary_cross_entropy_with_logits(input, input, reduction=reduction))
9119*da0073e9SAndroid Build Coastguard Worker
9120*da0073e9SAndroid Build Coastguard Worker            zeros = torch.zeros_like(input).to(torch.int64)
9121*da0073e9SAndroid Build Coastguard Worker            v(lambda: F.multilabel_soft_margin_loss(input, zeros, reduction=reduction))
9122*da0073e9SAndroid Build Coastguard Worker
9123*da0073e9SAndroid Build Coastguard Worker            v(lambda: F.triplet_margin_loss(input, input, input, reduction=reduction))
9124*da0073e9SAndroid Build Coastguard Worker            v(lambda: F.triplet_margin_with_distance_loss(input, input, input, reduction=reduction))
9125*da0073e9SAndroid Build Coastguard Worker            v(lambda: F.margin_ranking_loss(input, input, input.sign(), reduction=reduction))
9126*da0073e9SAndroid Build Coastguard Worker            v(lambda: F.cosine_embedding_loss(input, input, input[:, 0].sign(), reduction=reduction))
9127*da0073e9SAndroid Build Coastguard Worker
9128*da0073e9SAndroid Build Coastguard Worker            log_probs = torch.randn(50, 16, 20, requires_grad=True, device=device).log_softmax(2)
9129*da0073e9SAndroid Build Coastguard Worker            targets = torch.randint(1, 20, (16, 30), dtype=torch.long, device=device)
9130*da0073e9SAndroid Build Coastguard Worker            input_lengths = torch.full((16,), 50, dtype=torch.long, device=device)
9131*da0073e9SAndroid Build Coastguard Worker            target_lengths = torch.randint(10, 30, (16,), dtype=torch.long, device=device)
9132*da0073e9SAndroid Build Coastguard Worker            v(lambda: F.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction=reduction))
9133*da0073e9SAndroid Build Coastguard Worker
9134*da0073e9SAndroid Build Coastguard Worker            # FIXME: should we allow derivatives on these?
9135*da0073e9SAndroid Build Coastguard Worker            v(lambda: F.soft_margin_loss(input, input.sign().detach(), reduction=reduction))
9136*da0073e9SAndroid Build Coastguard Worker
9137*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
9138*da0073e9SAndroid Build Coastguard Worker    def test_smooth_l1_loss_vs_huber_loss(self, device):
9139*da0073e9SAndroid Build Coastguard Worker        def _make_test_tensor(shape, contiguous=True):
9140*da0073e9SAndroid Build Coastguard Worker            if contiguous:
9141*da0073e9SAndroid Build Coastguard Worker                test_tensor = torch.randn(shape, device=device)
9142*da0073e9SAndroid Build Coastguard Worker            else:
9143*da0073e9SAndroid Build Coastguard Worker                # Select every other element in the innermost dimension to
9144*da0073e9SAndroid Build Coastguard Worker                # make it non-contiguous.
9145*da0073e9SAndroid Build Coastguard Worker                doubled_shape = list(shape)
9146*da0073e9SAndroid Build Coastguard Worker                doubled_shape[-1] *= 2
9147*da0073e9SAndroid Build Coastguard Worker                test_tensor = torch.randn(doubled_shape, device=device)
9148*da0073e9SAndroid Build Coastguard Worker                test_tensor = test_tensor[..., ::2]
9149*da0073e9SAndroid Build Coastguard Worker            return test_tensor
9150*da0073e9SAndroid Build Coastguard Worker
9151*da0073e9SAndroid Build Coastguard Worker        def _test_smooth_l1_loss_vs_huber_loss_helper(input, target, beta, require_equal):
9152*da0073e9SAndroid Build Coastguard Worker            for reduction in ['mean', 'sum', 'none']:
9153*da0073e9SAndroid Build Coastguard Worker                smooth_l1 = torch.nn.SmoothL1Loss(beta=beta, reduction=reduction)
9154*da0073e9SAndroid Build Coastguard Worker                # beta hyper-parameter is called delta for Huber
9155*da0073e9SAndroid Build Coastguard Worker                huber = torch.nn.HuberLoss(delta=beta, reduction=reduction)
9156*da0073e9SAndroid Build Coastguard Worker                smooth_l1_loss = smooth_l1(input, target)
9157*da0073e9SAndroid Build Coastguard Worker                huber_loss = huber(input, target)
9158*da0073e9SAndroid Build Coastguard Worker
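                # With delta == beta, Huber uses 0.5 * d**2 (|d| <= beta) or
                # beta * (|d| - 0.5 * beta), while smooth L1 uses 0.5 * d**2 / beta or
                # |d| - 0.5 * beta, so elementwise huber == beta * smooth_l1 and the
                # two losses coincide exactly when beta == 1.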
9159*da0073e9SAndroid Build Coastguard Worker                if require_equal:
9160*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(smooth_l1_loss, huber_loss)
9161*da0073e9SAndroid Build Coastguard Worker                else:
9162*da0073e9SAndroid Build Coastguard Worker                    # Huber loss equals smooth L1 loss scaled by a factor of beta.
9163*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(smooth_l1_loss * beta, huber_loss)
9164*da0073e9SAndroid Build Coastguard Worker
9165*da0073e9SAndroid Build Coastguard Worker        def _test_smooth_l1_loss_vs_huber_loss_multi_input_helper(beta, require_equal):
9166*da0073e9SAndroid Build Coastguard Worker            # Test the non-vectorized case.
9167*da0073e9SAndroid Build Coastguard Worker            shape = (2, 2)
9168*da0073e9SAndroid Build Coastguard Worker            _test_smooth_l1_loss_vs_huber_loss_helper(input=_make_test_tensor(shape),
9169*da0073e9SAndroid Build Coastguard Worker                                                      target=_make_test_tensor(shape),
9170*da0073e9SAndroid Build Coastguard Worker                                                      beta=beta,
9171*da0073e9SAndroid Build Coastguard Worker                                                      require_equal=require_equal)
9172*da0073e9SAndroid Build Coastguard Worker
9173*da0073e9SAndroid Build Coastguard Worker            # Test the vectorized case (innermost dim > 32).
9174*da0073e9SAndroid Build Coastguard Worker            shape = (64, 64)
9175*da0073e9SAndroid Build Coastguard Worker            _test_smooth_l1_loss_vs_huber_loss_helper(input=_make_test_tensor(shape),
9176*da0073e9SAndroid Build Coastguard Worker                                                      target=_make_test_tensor(shape),
9177*da0073e9SAndroid Build Coastguard Worker                                                      beta=beta,
9178*da0073e9SAndroid Build Coastguard Worker                                                      require_equal=require_equal)
9179*da0073e9SAndroid Build Coastguard Worker
9180*da0073e9SAndroid Build Coastguard Worker            # Test the non-contiguous case.
9181*da0073e9SAndroid Build Coastguard Worker            _test_smooth_l1_loss_vs_huber_loss_helper(input=_make_test_tensor(shape, contiguous=False),
9182*da0073e9SAndroid Build Coastguard Worker                                                      target=_make_test_tensor(shape, contiguous=False),
9183*da0073e9SAndroid Build Coastguard Worker                                                      beta=beta,
9184*da0073e9SAndroid Build Coastguard Worker                                                      require_equal=require_equal)
9185*da0073e9SAndroid Build Coastguard Worker
9186*da0073e9SAndroid Build Coastguard Worker        def test_equal_when_beta_is_one():
9187*da0073e9SAndroid Build Coastguard Worker            _test_smooth_l1_loss_vs_huber_loss_multi_input_helper(beta=1.0, require_equal=True)
9188*da0073e9SAndroid Build Coastguard Worker
9189*da0073e9SAndroid Build Coastguard Worker        def test_unequal_when_beta_is_less_than_one():
9190*da0073e9SAndroid Build Coastguard Worker            _test_smooth_l1_loss_vs_huber_loss_multi_input_helper(beta=0.5, require_equal=False)
9191*da0073e9SAndroid Build Coastguard Worker
9192*da0073e9SAndroid Build Coastguard Worker        def test_unequal_when_beta_is_greater_than_one():
9193*da0073e9SAndroid Build Coastguard Worker            _test_smooth_l1_loss_vs_huber_loss_multi_input_helper(beta=1.5, require_equal=False)
9194*da0073e9SAndroid Build Coastguard Worker
9195*da0073e9SAndroid Build Coastguard Worker        test_equal_when_beta_is_one()
9196*da0073e9SAndroid Build Coastguard Worker        test_unequal_when_beta_is_less_than_one()
9197*da0073e9SAndroid Build Coastguard Worker        test_unequal_when_beta_is_greater_than_one()
9198*da0073e9SAndroid Build Coastguard Worker
9199*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
9200*da0073e9SAndroid Build Coastguard Worker    def test_smooth_l1_loss_bfloat16(self, device):
9201*da0073e9SAndroid Build Coastguard Worker        def test_dtype(fn, input, target, dtype):
9202*da0073e9SAndroid Build Coastguard Worker            input = input.detach().clone().to(dtype=dtype).requires_grad_(True)
9203*da0073e9SAndroid Build Coastguard Worker            input2 = input.detach().clone().float().requires_grad_(True)
9204*da0073e9SAndroid Build Coastguard Worker            target = target.detach().clone().to(dtype=dtype)
9205*da0073e9SAndroid Build Coastguard Worker            target2 = target.detach().clone().float()
9206*da0073e9SAndroid Build Coastguard Worker            out = fn(input, target)
9207*da0073e9SAndroid Build Coastguard Worker            out.sum().backward()
9208*da0073e9SAndroid Build Coastguard Worker            out2 = fn(input2, target2)
9209*da0073e9SAndroid Build Coastguard Worker            out2.sum().backward()
9210*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out.dtype, dtype)
9211*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input.grad.dtype, dtype)
9212*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, out2, exact_dtype=False)
9213*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input.grad, input2.grad, exact_dtype=False)
9214*da0073e9SAndroid Build Coastguard Worker
9215*da0073e9SAndroid Build Coastguard Worker        def func(device):
9216*da0073e9SAndroid Build Coastguard Worker            return nn.SmoothL1Loss().to(device=device)
9217*da0073e9SAndroid Build Coastguard Worker
9218*da0073e9SAndroid Build Coastguard Worker        shapes = [[1, 3, 1, 6], [1, 3, 1, 128], [1, 3, 128, 128]]
9219*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
9220*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(shape, device=device, requires_grad=True)
9221*da0073e9SAndroid Build Coastguard Worker            t = torch.randn(shape, device=device)
9222*da0073e9SAndroid Build Coastguard Worker            test_dtype(func(device), x, t, torch.bfloat16)
9223*da0073e9SAndroid Build Coastguard Worker
9224*da0073e9SAndroid Build Coastguard Worker    # We don't want to make propagating NaN a hard requirement on ops, but for
9225*da0073e9SAndroid Build Coastguard Worker    # these easy ones, we should make them do so.
9226*da0073e9SAndroid Build Coastguard Worker    # MPS: NotImplementedError: aten::rrelu_with_noise_ https://github.com/pytorch/pytorch/issues/77764
9227*da0073e9SAndroid Build Coastguard Worker    # MPS: NotImplementedError: aten::hardshrink.out https://github.com/pytorch/pytorch/issues/77764
9228*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS
9229*da0073e9SAndroid Build Coastguard Worker    def test_nonlinearity_propagate_nan(self, device):
9230*da0073e9SAndroid Build Coastguard Worker        def test(nonlinearity, *args, **kwargs):
9231*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor([nan], device=device)
9232*da0073e9SAndroid Build Coastguard Worker            fn = getattr(F, nonlinearity)
9233*da0073e9SAndroid Build Coastguard Worker            try:
9234*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(math.isnan(fn(x, *args, **kwargs).item()))
9235*da0073e9SAndroid Build Coastguard Worker            except Exception as e:
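                # Some backends may not implement a given op for this input; tolerate
                # "not implemented" errors and only re-raise genuine failures.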
9236*da0073e9SAndroid Build Coastguard Worker                if 'not implemented' not in str(e):
9237*da0073e9SAndroid Build Coastguard Worker                    raise
9238*da0073e9SAndroid Build Coastguard Worker
9239*da0073e9SAndroid Build Coastguard Worker        test('relu')
9240*da0073e9SAndroid Build Coastguard Worker        test('relu', inplace=True)
9241*da0073e9SAndroid Build Coastguard Worker        test('relu6')
9242*da0073e9SAndroid Build Coastguard Worker        test('elu')
9243*da0073e9SAndroid Build Coastguard Worker        test('selu')
9244*da0073e9SAndroid Build Coastguard Worker        test('celu')
9245*da0073e9SAndroid Build Coastguard Worker        test('rrelu')
9246*da0073e9SAndroid Build Coastguard Worker        test('rrelu', inplace=True)
9247*da0073e9SAndroid Build Coastguard Worker        test('hardtanh')
9248*da0073e9SAndroid Build Coastguard Worker        test('tanh')
9249*da0073e9SAndroid Build Coastguard Worker        test('sigmoid')
9250*da0073e9SAndroid Build Coastguard Worker        test('logsigmoid')
9251*da0073e9SAndroid Build Coastguard Worker        test('hardshrink')
9252*da0073e9SAndroid Build Coastguard Worker        test('tanhshrink')
9253*da0073e9SAndroid Build Coastguard Worker        test('softsign')
9254*da0073e9SAndroid Build Coastguard Worker        test('softmin', 0)
9255*da0073e9SAndroid Build Coastguard Worker        test('softmax', 0)
9256*da0073e9SAndroid Build Coastguard Worker        test('log_softmax', 0)
9257*da0073e9SAndroid Build Coastguard Worker        test('leaky_relu', 0.2)
9258*da0073e9SAndroid Build Coastguard Worker        test('threshold', 3, 2)
9259*da0073e9SAndroid Build Coastguard Worker        test('threshold', 3, 2, inplace=True)
9260*da0073e9SAndroid Build Coastguard Worker
9261*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
9262*da0073e9SAndroid Build Coastguard Worker    @parametrize_test("mode", ["nearest-exact", "nearest"])
9263*da0073e9SAndroid Build Coastguard Worker    def test_upsamplingNearest1d(self, device, mode):
9264*da0073e9SAndroid Build Coastguard Worker        # Forward AD does not support XLA because XLA tensors don't have storage
9265*da0073e9SAndroid Build Coastguard Worker        check_forward_ad = torch.device(device).type != 'xla'
9266*da0073e9SAndroid Build Coastguard Worker
9267*da0073e9SAndroid Build Coastguard Worker        m = nn.Upsample(size=4, mode=mode)
9268*da0073e9SAndroid Build Coastguard Worker        in_t = torch.ones(1, 1, 2, device=device, dtype=torch.double)
9269*da0073e9SAndroid Build Coastguard Worker        in_uint8_t = torch.ones(1, 1, 2, dtype=torch.uint8, device=device)
9270*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
9271*da0073e9SAndroid Build Coastguard Worker            out_t = m(in_t)
9272*da0073e9SAndroid Build Coastguard Worker            out_uint8_t = m(in_uint8_t)
9273*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.ones(1, 1, 4, device=device, dtype=torch.double), out_t.data)
9274*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.ones(1, 1, 4, dtype=torch.uint8, device=device), out_uint8_t.data)
9275*da0073e9SAndroid Build Coastguard Worker
9276*da0073e9SAndroid Build Coastguard Worker        # Checks upsampling
9277*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(1, 1, 2, requires_grad=True, device=device, dtype=torch.double)
9278*da0073e9SAndroid Build Coastguard Worker        gradcheck(lambda x: F.interpolate(x, 4, mode=mode), [input], check_forward_ad=check_forward_ad)
9279*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(lambda x: F.interpolate(x, 4, mode=mode), [input], check_fwd_over_rev=check_forward_ad)
9280*da0073e9SAndroid Build Coastguard Worker
9281*da0073e9SAndroid Build Coastguard Worker        # Checks downsampling
9282*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(1, 1, 20, requires_grad=True, device=device, dtype=torch.double)
9283*da0073e9SAndroid Build Coastguard Worker        gradcheck(lambda x: F.interpolate(x, 11, mode=mode), [input], check_forward_ad=check_forward_ad)
9284*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(lambda x: F.interpolate(x, 4, mode=mode), [input], check_fwd_over_rev=check_forward_ad)
9285*da0073e9SAndroid Build Coastguard Worker
9286*da0073e9SAndroid Build Coastguard Worker        # consistency CUDA/CPU check
9287*da0073e9SAndroid Build Coastguard Worker        if torch.device(device).type == 'cuda':
9288*da0073e9SAndroid Build Coastguard Worker            input_cuda = torch.randn(1, 1, 20, device=device, dtype=torch.double)
9289*da0073e9SAndroid Build Coastguard Worker            input_cpu = input_cuda.cpu()
9290*da0073e9SAndroid Build Coastguard Worker            output_cuda = F.interpolate(input_cuda, 4, mode=mode)
9291*da0073e9SAndroid Build Coastguard Worker            output_cpu = F.interpolate(input_cpu, 4, mode=mode)
9292*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output_cuda.cpu(), output_cpu)
9293*da0073e9SAndroid Build Coastguard Worker
9294*da0073e9SAndroid Build Coastguard Worker            output_cuda = F.interpolate(input_cuda, 24, mode=mode)
9295*da0073e9SAndroid Build Coastguard Worker            output_cpu = F.interpolate(input_cpu, 24, mode=mode)
9296*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output_cuda.cpu(), output_cpu)
9297*da0073e9SAndroid Build Coastguard Worker
9298*da0073e9SAndroid Build Coastguard Worker    @parametrize_test("isize, osize", [(20, 11), (10, 15)])
9299*da0073e9SAndroid Build Coastguard Worker    def test_upsamplingNearest1d_correctness(self, device, isize, osize):
9300*da0073e9SAndroid Build Coastguard Worker        # Here we check if output matches OpenCV's INTER_NEAREST-like result
9301*da0073e9SAndroid Build Coastguard Worker        in_t = torch.arange(isize, dtype=torch.float, device=device).unsqueeze(0).unsqueeze(0)
9302*da0073e9SAndroid Build Coastguard Worker        out_t = F.interpolate(
9303*da0073e9SAndroid Build Coastguard Worker            in_t, size=(osize, ), recompute_scale_factor=False, mode="nearest"
9304*da0073e9SAndroid Build Coastguard Worker        )
9305*da0073e9SAndroid Build Coastguard Worker        # compute expected output as OpenCV
9306*da0073e9SAndroid Build Coastguard Worker        # compute the expected output the way OpenCV does
9307*da0073e9SAndroid Build Coastguard Worker        scale = 1.0 * isize / osize
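        # Plain 'nearest' truncates o * scale; e.g. with isize=20, osize=11 the
        # scale is 20/11 ~= 1.818, so output index 6 reads input index
        # int(6 * 20 / 11) = int(10.9...) = 10.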
9308*da0073e9SAndroid Build Coastguard Worker        for o in range(osize):
9309*da0073e9SAndroid Build Coastguard Worker            i_f32 = o * scale
9310*da0073e9SAndroid Build Coastguard Worker            i = int(i_f32)
9311*da0073e9SAndroid Build Coastguard Worker            expected_out[0, 0, o] = in_t[0, 0, i]
9312*da0073e9SAndroid Build Coastguard Worker        expected_out = expected_out.to(device=device)
9313*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_t, expected_out)
9314*da0073e9SAndroid Build Coastguard Worker
9315*da0073e9SAndroid Build Coastguard Worker    def test_upsamplingNearestExact1d_rescale(self, device):
9316*da0073e9SAndroid Build Coastguard Worker        # Checks https://github.com/pytorch/pytorch/issues/62237
9317*da0073e9SAndroid Build Coastguard Worker        isize = 20
9318*da0073e9SAndroid Build Coastguard Worker        in_t = torch.arange(isize, dtype=torch.float, device=device).unsqueeze(0).unsqueeze(0)
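        # With a scale factor marginally above 1, floor(isize * s) is still isize,
        # and nearest-exact maps each output index back onto itself, so the output
        # should reproduce the input unchanged.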
9319*da0073e9SAndroid Build Coastguard Worker        # for s in [1.00001, 0.99999]:  # 0.9999 case is broken
9320*da0073e9SAndroid Build Coastguard Worker        # for s in [1.00001, 0.99999]:  # the 0.99999 case is broken
9321*da0073e9SAndroid Build Coastguard Worker        for s in [1.00001, ]:
9322*da0073e9SAndroid Build Coastguard Worker            out_t = F.interpolate(
9323*da0073e9SAndroid Build Coastguard Worker                in_t, scale_factor=s, recompute_scale_factor=False, mode="nearest-exact"
9324*da0073e9SAndroid Build Coastguard Worker            )
9325*da0073e9SAndroid Build Coastguard Worker            expected_out = in_t
9326*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out_t, expected_out, msg=f"scale: {s}")
9327*da0073e9SAndroid Build Coastguard Worker
9328*da0073e9SAndroid Build Coastguard Worker        # checks data duplication if output_size == 2 * input_size
9329*da0073e9SAndroid Build Coastguard Worker        # for s in [2.00001, 1.99999]:  # 1.99999 case is broken
9330*da0073e9SAndroid Build Coastguard Worker        # See issue: https://github.com/pytorch/pytorch/issues/62396
9331*da0073e9SAndroid Build Coastguard Worker        for s in [2.00001, ]:
9332*da0073e9SAndroid Build Coastguard Worker            out_t = F.interpolate(
9333*da0073e9SAndroid Build Coastguard Worker                in_t, scale_factor=s, recompute_scale_factor=False, mode="nearest-exact"
9334*da0073e9SAndroid Build Coastguard Worker            )
9335*da0073e9SAndroid Build Coastguard Worker            # input is [[[0, 1, 2, 3, ..., 9]]]
9336*da0073e9SAndroid Build Coastguard Worker            # input is [[[0, 1, 2, 3, ..., 19]]]
9337*da0073e9SAndroid Build Coastguard Worker            # expected out is [[[0, 0, 1, 1, 2, 2, ..., 19, 19]]]
9338*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out_t, expected_out)
9339*da0073e9SAndroid Build Coastguard Worker
9340*da0073e9SAndroid Build Coastguard Worker    @skipIfMps  # Partially passes https://github.com/pytorch/pytorch/issues/134430
9341*da0073e9SAndroid Build Coastguard Worker    @parametrize_test("isize, osize", [(20, 11), (10, 15)])
9342*da0073e9SAndroid Build Coastguard Worker    def test_upsamplingNearestExact1d_correctness(self, device, isize, osize):
9343*da0073e9SAndroid Build Coastguard Worker        # Here we check if output matches Scikit-Image/Scipy-like result
9344*da0073e9SAndroid Build Coastguard Worker        # Checks https://github.com/pytorch/pytorch/issues/34808
9345*da0073e9SAndroid Build Coastguard Worker        in_t = torch.arange(isize, dtype=torch.float, device=device).unsqueeze(0).unsqueeze(0)
9346*da0073e9SAndroid Build Coastguard Worker        out_t = F.interpolate(
9347*da0073e9SAndroid Build Coastguard Worker            in_t, size=(osize, ), recompute_scale_factor=False, mode="nearest-exact"
9348*da0073e9SAndroid Build Coastguard Worker        )
9349*da0073e9SAndroid Build Coastguard Worker        # compute expected output as scikit-image/scipy
9350*da0073e9SAndroid Build Coastguard Worker        # compute the expected output the way scikit-image/scipy do
9351*da0073e9SAndroid Build Coastguard Worker        scale = 1.0 * isize / osize
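        # 'nearest-exact' samples at the half-pixel centre (o + 0.5) * scale before
        # truncating; e.g. with isize=20, osize=11, output index 1 reads input index
        # int(1.5 * 20 / 11) = int(2.7...) = 2, where plain 'nearest' would read 1.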
9352*da0073e9SAndroid Build Coastguard Worker        for o in range(osize):
9353*da0073e9SAndroid Build Coastguard Worker            i_f32 = (o + 0.5) * scale
9354*da0073e9SAndroid Build Coastguard Worker            i = int(i_f32)
9355*da0073e9SAndroid Build Coastguard Worker            expected_out[0, 0, o] = in_t[0, 0, i]
9356*da0073e9SAndroid Build Coastguard Worker        expected_out = expected_out.to(device=device)
9357*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_t, expected_out)
9358*da0073e9SAndroid Build Coastguard Worker
9359*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
9360*da0073e9SAndroid Build Coastguard Worker    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
9361*da0073e9SAndroid Build Coastguard Worker    @parametrize_test("mode", ["nearest", "nearest-exact"])
9362*da0073e9SAndroid Build Coastguard Worker    def test_upsamplingNearest2d(self, device, memory_format, mode):
9363*da0073e9SAndroid Build Coastguard Worker        # Forward AD does not support XLA because XLA tensors don't have storage
9364*da0073e9SAndroid Build Coastguard Worker        check_forward_ad = torch.device(device).type != 'xla'
9365*da0073e9SAndroid Build Coastguard Worker
9366*da0073e9SAndroid Build Coastguard Worker        in_t = torch.ones(1, 2, 2, 2, device=device, dtype=torch.double).contiguous(memory_format=memory_format)
9367*da0073e9SAndroid Build Coastguard Worker        in_uint8_t = torch.ones(1, 2, 2, 2, dtype=torch.uint8, device=device).contiguous(memory_format=memory_format)
9368*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
9369*da0073e9SAndroid Build Coastguard Worker            out_t = F.interpolate(in_t, size=4, mode=mode)
9370*da0073e9SAndroid Build Coastguard Worker            out_uint8_t = F.interpolate(in_uint8_t, size=4, mode=mode)
9371*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 0)
9372*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.ones(1, 2, 4, 4, device=device, dtype=torch.double), out_t)
9373*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.ones(1, 2, 4, 4, dtype=torch.uint8, device=device), out_uint8_t)
9374*da0073e9SAndroid Build Coastguard Worker        # Assert that memory format is carried through to the output
9375*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(out_t.is_contiguous(memory_format=memory_format))
9376*da0073e9SAndroid Build Coastguard Worker
9377*da0073e9SAndroid Build Coastguard Worker        # test forward when the input's height is not the same as its width
9378*da0073e9SAndroid Build Coastguard Worker        in_t = torch.ones(1, 2, 2, 1, device=device, dtype=torch.double).contiguous(memory_format=memory_format).requires_grad_()
9379*da0073e9SAndroid Build Coastguard Worker        out_t = F.interpolate(in_t, size=(4, 2), mode=mode)
9380*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.ones(1, 2, 4, 2, device=device, dtype=torch.double), out_t)
9381*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(out_t.is_contiguous(memory_format=memory_format))
9382*da0073e9SAndroid Build Coastguard Worker
9383*da0073e9SAndroid Build Coastguard Worker        out_t.backward(torch.randn_like(out_t))
9384*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(in_t.grad.is_contiguous(memory_format=memory_format))
9385*da0073e9SAndroid Build Coastguard Worker
9386*da0073e9SAndroid Build Coastguard Worker        # test backward when the input's height is not the same as its width
9387*da0073e9SAndroid Build Coastguard Worker        input = torch.ones(
9388*da0073e9SAndroid Build Coastguard Worker            1, 2, 2, 1, requires_grad=True, device=device,
9389*da0073e9SAndroid Build Coastguard Worker            dtype=torch.double).contiguous(memory_format=memory_format)
9390*da0073e9SAndroid Build Coastguard Worker        gradcheck(lambda x: F.interpolate(x, size=(4, 2), mode=mode), [input], check_forward_ad=check_forward_ad)
9391*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(lambda x: F.interpolate(x, size=(4, 2), mode=mode), [input], check_fwd_over_rev=check_forward_ad)
9392*da0073e9SAndroid Build Coastguard Worker
9393*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(
9394*da0073e9SAndroid Build Coastguard Worker            1, 2, 2, 2, requires_grad=True, device=device,
9395*da0073e9SAndroid Build Coastguard Worker            dtype=torch.double).contiguous(memory_format=memory_format)
9396*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
9397*da0073e9SAndroid Build Coastguard Worker            F.interpolate(input, 4, mode=mode),
9398*da0073e9SAndroid Build Coastguard Worker            F.interpolate(input, scale_factor=2, mode=mode))
9399*da0073e9SAndroid Build Coastguard Worker        gradcheck(lambda x: F.interpolate(x, 4, mode=mode), [input], check_forward_ad=check_forward_ad)
9400*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(lambda x: F.interpolate(x, 4, mode=mode), [input], check_fwd_over_rev=check_forward_ad)
9401*da0073e9SAndroid Build Coastguard Worker
9402*da0073e9SAndroid Build Coastguard Worker        # Assert that cpu and cuda handle channels_last memory format in the same way
9403*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/54590
9404*da0073e9SAndroid Build Coastguard Worker        if torch.device(device).type == 'cuda':
9405*da0073e9SAndroid Build Coastguard Worker            for shapes, scale_factor in product([
9406*da0073e9SAndroid Build Coastguard Worker                (2, 2, 3, 4), (2, 3, 4, 5), (3, 1, 2, 2), (1, 5, 3, 2)
9407*da0073e9SAndroid Build Coastguard Worker            ], [0.5, 1.5, 2]):
9408*da0073e9SAndroid Build Coastguard Worker                a_cuda = torch.randn(
9409*da0073e9SAndroid Build Coastguard Worker                    *shapes, device=device,
9410*da0073e9SAndroid Build Coastguard Worker                    dtype=torch.double).contiguous(memory_format=memory_format).requires_grad_()
9411*da0073e9SAndroid Build Coastguard Worker                a_cpu = a_cuda.detach().cpu().requires_grad_()
9412*da0073e9SAndroid Build Coastguard Worker
9413*da0073e9SAndroid Build Coastguard Worker                out_cuda = F.interpolate(a_cuda, scale_factor=scale_factor, mode=mode)
9414*da0073e9SAndroid Build Coastguard Worker                out_cpu = F.interpolate(a_cpu, scale_factor=scale_factor, mode=mode)
9415*da0073e9SAndroid Build Coastguard Worker
9416*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out_cpu.cuda(), out_cuda)
9417*da0073e9SAndroid Build Coastguard Worker
9418*da0073e9SAndroid Build Coastguard Worker                g_cuda = torch.randn_like(out_cuda)
9419*da0073e9SAndroid Build Coastguard Worker                g_cpu = g_cuda.cpu()
9420*da0073e9SAndroid Build Coastguard Worker
9421*da0073e9SAndroid Build Coastguard Worker                out_cuda.backward(g_cuda)
9422*da0073e9SAndroid Build Coastguard Worker                out_cpu.backward(g_cpu)
9423*da0073e9SAndroid Build Coastguard Worker
9424*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a_cuda.grad, a_cpu.grad)
9425*da0073e9SAndroid Build Coastguard Worker
    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
    @parametrize_test("isize, osize", [(20, 11), (10, 15)])
    def test_upsamplingNearest2d_correctness(self, device, memory_format, isize, osize):
        # Here we check that the output matches an OpenCV INTER_NEAREST-like result
        in_t = torch.arange(isize * isize, dtype=torch.float, device=device).reshape(1, 1, isize, isize)
        in_t = in_t.contiguous(memory_format=memory_format)
        out_t = F.interpolate(
            in_t, size=(osize, osize), recompute_scale_factor=False, mode="nearest"
        )
        # compute the expected output the same way OpenCV does
        expected_out = torch.zeros(1, 1, osize, osize, dtype=torch.float)
        scale = 1.0 * isize / osize
        for o1 in range(osize):
            i1_f32 = o1 * scale
            i1 = int(i1_f32)
            for o2 in range(osize):
                i2_f32 = o2 * scale
                i2 = int(i2_f32)
                expected_out[0, 0, o1, o2] = in_t[0, 0, i1, i2]
        expected_out = expected_out.to(device=device)
        self.assertEqual(out_t, expected_out)

    @skipIfMps  # Partially passes https://github.com/pytorch/pytorch/issues/134430
    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
    @parametrize_test("isize, osize", [(20, 11), (10, 15)])
    def test_upsamplingNearestExact2d_correctness(self, device, memory_format, isize, osize):
        # Here we check that the output matches a Scikit-Image/SciPy-like result
        # Checks https://github.com/pytorch/pytorch/issues/34808
        in_t = torch.arange(isize * isize, dtype=torch.float, device=device).reshape(1, 1, isize, isize)
        in_t = in_t.contiguous(memory_format=memory_format)
        out_t = F.interpolate(
            in_t, size=(osize, osize), recompute_scale_factor=False, mode="nearest-exact"
        )
        # compute the expected output the same way Scikit-Image/SciPy do
        expected_out = torch.zeros(1, 1, osize, osize, dtype=torch.float)
        scale = 1.0 * isize / osize
        for o1 in range(osize):
            i1_f32 = (o1 + 0.5) * scale
            i1 = int(i1_f32)
            for o2 in range(osize):
                i2_f32 = (o2 + 0.5) * scale
                i2 = int(i2_f32)
                expected_out[0, 0, o1, o2] = in_t[0, 0, i1, i2]
        expected_out = expected_out.to(device=device)
        self.assertEqual(out_t, expected_out)

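    # The loops in the two correctness tests above implement the source-index mapping
    # that "nearest" and "nearest-exact" use along each spatial dimension (the 3d
    # correctness tests below apply the same mapping per dimension). A minimal
    # standalone sketch of that mapping; the helper name is hypothetical and the
    # test suite does not call it.
    @staticmethod
    def _reference_nearest_source_indices(isize, osize, exact=False):
        # "nearest" maps output index o to floor(o * isize / osize);
        # "nearest-exact" adds a half-pixel offset: floor((o + 0.5) * isize / osize)
        scale = isize / osize
        offset = 0.5 if exact else 0.0
        return [int((o + offset) * scale) for o in range(osize)]
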
    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last_3d])
    @parametrize_test("mode", ["nearest", "nearest-exact"])
    def test_upsamplingNearest3d(self, device, memory_format, mode):
        # Forward AD does not support XLA because XLA tensors don't have storage
        check_forward_ad = torch.device(device).type != 'xla'

        m = nn.Upsample(size=4, mode=mode)
        in_t = torch.ones(1, 2, 2, 2, 2, device=device, dtype=torch.double).contiguous(memory_format=memory_format).requires_grad_()
        in_uint8_t = torch.ones(
            1, 2, 2, 2, 2, dtype=torch.uint8, device=device
        ).contiguous(memory_format=memory_format)
        with warnings.catch_warnings(record=True) as w:
            out_t = m(in_t)
            out_uint8_t = m(in_uint8_t)
        expected_output = torch.ones(1, 2, 4, 4, 4, device=device, dtype=torch.double)
        self.assertEqual(expected_output, out_t)
        self.assertEqual(expected_output.to(torch.uint8), out_uint8_t)
        # Assert that memory format is carried through to the output
        self.assertTrue(out_t.is_contiguous(memory_format=memory_format))
        out_t.backward(torch.randn_like(out_t))
        self.assertTrue(in_t.grad.is_contiguous(memory_format=memory_format))

        input = torch.randn(
            1, 2, 2, 2, 2, requires_grad=True, device=device, dtype=torch.double
        ).contiguous(memory_format=memory_format)
        gradcheck(lambda x: F.interpolate(x, 4, mode=mode), [input], check_forward_ad=check_forward_ad)
        gradgradcheck(lambda x: F.interpolate(x, 4, mode=mode), [input], check_fwd_over_rev=check_forward_ad)

        # Assert that cpu and cuda handle channels_last memory format in the same way
        # https://github.com/pytorch/pytorch/issues/54590
        if torch.device(device).type == 'cuda':
            a = torch.ones(
                2, 2, 2, 3, 4, device=device, requires_grad=True, dtype=torch.double
            ).contiguous(memory_format=torch.channels_last_3d)
            # make the data asymmetric; ensure that cuda/cpu handle channels_last appropriately.
            a[1][1][1][2][2] = a[1][1][1][2][3] = 0

            out_cuda = torch.nn.functional.interpolate(a, scale_factor=2, mode=mode)
            out_cpu = torch.nn.functional.interpolate(a.to('cpu'), scale_factor=2, mode=mode)
            self.assertEqual(out_cpu, out_cuda.to('cpu'))

            gradcheck(lambda x: F.interpolate(x, 4, mode=mode), [a], check_forward_ad=check_forward_ad)
            gradgradcheck(lambda x: F.interpolate(x, 4, mode=mode), [a], check_fwd_over_rev=check_forward_ad)

            gradcheck(lambda x: F.interpolate(x, 4, mode=mode), [a.to('cuda')], check_forward_ad=check_forward_ad)
            gradgradcheck(lambda x: F.interpolate(x, 4, mode=mode), [a.to('cuda')], check_fwd_over_rev=check_forward_ad)

    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last_3d])
    @parametrize_test("isize, osize", [(20, 11), (10, 15)])
    def test_upsamplingNearest3d_correctness(self, device, memory_format, isize, osize):
        # Here we check that the output matches an OpenCV INTER_NEAREST-like result
        in_t = torch.arange(isize * isize * isize, dtype=torch.float, device=device)
        in_t = in_t.reshape(1, 1, isize, isize, isize)
        in_t = in_t.contiguous(memory_format=memory_format)
        out_t = F.interpolate(
            in_t, size=(osize, osize, osize), recompute_scale_factor=False, mode="nearest"
        )
        # compute the expected output the same way OpenCV does
        expected_out = torch.zeros(1, 1, osize, osize, osize, dtype=torch.float)
        scale = 1.0 * isize / osize
        for o1 in range(osize):
            i1_f32 = o1 * scale
            i1 = int(i1_f32)
            for o2 in range(osize):
                i2_f32 = o2 * scale
                i2 = int(i2_f32)
                for o3 in range(osize):
                    i3_f32 = o3 * scale
                    i3 = int(i3_f32)
                    expected_out[0, 0, o1, o2, o3] = in_t[0, 0, i1, i2, i3]
        expected_out = expected_out.to(device=device)
        self.assertEqual(out_t, expected_out)

    @expectedFailureMPS  # NotImplementedError: aten::_upsample_nearest_exact3d.out https://github.com/pytorch/pytorch/issues/77764
    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last_3d])
    @parametrize_test("isize, osize", [(20, 11), (10, 15)])
    def test_upsamplingNearestExact3d_correctness(self, device, memory_format, isize, osize):
        # Here we check that the output matches a Scikit-Image/SciPy-like result
        # Checks https://github.com/pytorch/pytorch/issues/34808
        in_t = torch.arange(isize * isize * isize, dtype=torch.float, device=device)
        in_t = in_t.reshape(1, 1, isize, isize, isize)
        in_t = in_t.contiguous(memory_format=memory_format)
        out_t = F.interpolate(
            in_t, size=(osize, osize, osize), recompute_scale_factor=False, mode="nearest-exact"
        )
        # compute the expected output the same way Scikit-Image/SciPy do
        expected_out = torch.zeros(1, 1, osize, osize, osize, dtype=torch.float)
        scale = 1.0 * isize / osize
        for o1 in range(osize):
            i1_f32 = (o1 + 0.5) * scale
            i1 = int(i1_f32)
            for o2 in range(osize):
                i2_f32 = (o2 + 0.5) * scale
                i2 = int(i2_f32)
                for o3 in range(osize):
                    i3_f32 = (o3 + 0.5) * scale
                    i3 = int(i3_f32)
                    expected_out[0, 0, o1, o2, o3] = in_t[0, 0, i1, i2, i3]
        expected_out = expected_out.to(device=device)
        self.assertEqual(out_t, expected_out)

    @parametrize_test("antialias", [True, False])
    @parametrize_test("align_corners", [True, False])
    @parametrize_test("mode", ["bilinear", "bicubic"])
    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
    @onlyNativeDeviceTypes
    def test_upsamplingBiMode2d(self, device, antialias, align_corners, mode, memory_format):
        # Forward AD does not support XLA because XLA tensors don't have storage
        check_forward_ad = torch.device(device).type != 'xla'

        kwargs = dict(mode=mode, align_corners=align_corners, antialias=antialias)
        # test float scale factor up & downsampling
        for scale_factor in [0.5, 1.5, 2]:
            in_t = torch.ones(
                2, 3, 8, 8, device=device,
                dtype=torch.double).contiguous(memory_format=memory_format).requires_grad_()
            out_size = int(math.floor(in_t.shape[-1] * scale_factor))
            with warnings.catch_warnings(record=True) as w:
                out_t = F.interpolate(in_t, scale_factor=scale_factor, **kwargs)
            expected_out = torch.ones(2, 3, out_size, out_size, device=device, dtype=torch.double)
            self.assertEqual(expected_out, out_t)
            # Assert that memory format is carried through to the output
            self.assertTrue(out_t.is_contiguous(memory_format=memory_format))
            out_t.backward(torch.randn_like(out_t))
            self.assertTrue(in_t.grad.is_contiguous(memory_format=memory_format))

            if torch.device(device).type == 'cuda':
                # Bilinear backward is nondeterministic because of atomicAdd usage
                nondet_tol = 1e-5
            else:
                nondet_tol = 0.0

            input = torch.randn(
                2, 3, 8, 8, device=device,
                dtype=torch.double).contiguous(memory_format=memory_format).requires_grad_()
            gradcheck(
                lambda x: F.interpolate(x, out_size, **kwargs),
                [input],
                check_forward_ad=check_forward_ad, nondet_tol=nondet_tol
            )
            gradgradcheck(
                lambda x: F.interpolate(x, out_size, **kwargs),
                [input],
                check_fwd_over_rev=check_forward_ad, nondet_tol=nondet_tol
            )

            # Assert that cpu and cuda give the same results
            if torch.device(device).type == 'cuda':
                for shapes in [
                    (2, 2, 3, 4), (2, 3, 4, 5), (3, 1, 2, 2), (1, 5, 3, 2)
                ]:
                    a_cuda = torch.randn(
                        *shapes, device=device, dtype=torch.double
                    ).contiguous(memory_format=memory_format).requires_grad_()
                    a_cpu = a_cuda.detach().cpu().requires_grad_()

                    with warnings.catch_warnings(record=True):
                        out_cuda = F.interpolate(a_cuda, scale_factor=scale_factor, **kwargs)
                        out_cpu = F.interpolate(a_cpu, scale_factor=scale_factor, **kwargs)

                    self.assertEqual(out_cpu, out_cuda.cpu())

                    g_cuda = torch.randn_like(out_cuda)
                    g_cpu = g_cuda.cpu()

                    out_cuda.backward(g_cuda)
                    out_cpu.backward(g_cpu)

                    self.assertEqual(a_cuda.grad, a_cpu.grad)

    @parametrize_test("antialias", [True, False])
    @parametrize_test("num_channels", [3, 5])
    @parametrize_test("mode", ["nearest", "nearest-exact", "bilinear", "bicubic"])
    @parametrize_test("dtype", integral_types() + floating_types())
    @onlyNativeDeviceTypes
    def test_upsamplingBiMode2d_nonsupported_dtypes(self, device, antialias, num_channels, mode, dtype):
        x = torch.ones(1, num_channels, 32, 32, dtype=dtype, device=device)

        should_raise_runtime_error = True

        if "nearest" in mode:
            if antialias:
                raise SkipTest("Nearest mode does not have antialiasing")
            if dtype in (torch.uint8, ) + floating_types():
                should_raise_runtime_error = False

        elif mode in ("bilinear", "bicubic"):
            if dtype in floating_types() or (device == "cpu" and dtype == torch.uint8):
                should_raise_runtime_error = False

        if should_raise_runtime_error:
            with self.assertRaisesRegex(RuntimeError, "not implemented for"):
                F.interpolate(x, (12, 12), mode=mode, antialias=antialias)
        else:
            _ = F.interpolate(x, (12, 12), mode=mode, antialias=antialias)

    @expectedFailureMPS  # NotImplementedError: aten::_upsample_bilinear2d_aa.out https://github.com/pytorch/pytorch/issues/77764
    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
    def test_upsamplingBilinear2d_aa_correctness(self, device, memory_format):
        # NOTE: We expand the batch dim such that `b*c` (23000 * 3 = 69000) is above
        # the maximum size of the CUDA grid z-dimension (2**16)
        shape = [23000, 3, 8, 8]
        t_in = torch.arange(3 * 8 * 8, dtype=torch.float, device=device).reshape(1, *shape[1:])
        t_in = t_in.expand(shape)
        t_in = t_in.contiguous(memory_format=memory_format)
        # This expected result is obtained using PIL.Image.resize
        # for c in range(3):
        #   a_in = t_in.numpy()[0, c, ...]
        #   pil_in = Image.fromarray(a_in)
        #   pil_out = pil_in.resize((2, 2), resample=Image.LINEAR)
        expected_out = torch.tensor([
            17.035713, 20.25, 42.75, 45.964287, 81.03572, 84.25,
            106.75, 109.96428, 145.0357, 148.25, 170.75, 173.9643
        ], device=device, dtype=t_in.dtype).reshape(1, 3, 2, 2)
        t_out = F.interpolate(t_in, size=(2, 2), mode="bilinear", align_corners=False, antialias=True)
        self.assertEqual(expected_out.expand([*shape[:2], 2, 2]), t_out)

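    # A self-contained sketch (assuming NumPy and Pillow are installed) of how the
    # PIL reference values used by the antialiased correctness tests above and in
    # test_upsamplingBicubic2d_aa_correctness below can be regenerated, e.g. with
    # resample=Image.BILINEAR or resample=Image.BICUBIC. The helper name is
    # hypothetical and the test suite does not call it.
    @staticmethod
    def _pil_aa_reference(resample):
        import numpy as np
        from PIL import Image
        t_in = torch.arange(3 * 8 * 8, dtype=torch.float).reshape(1, 3, 8, 8)
        channels = []
        for c in range(3):
            a_in = t_in.numpy()[0, c, ...]
            pil_in = Image.fromarray(a_in)  # float32 array -> mode "F" image
            pil_out = pil_in.resize((2, 2), resample=resample)
            channels.append(torch.from_numpy(np.asarray(pil_out)))
        return torch.stack(channels).reshape(1, 3, 2, 2)
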
    # Partially passes. NotImplementedError: aten::upsample_bicubic2d.out https://github.com/pytorch/pytorch/issues/77764
    @skipIfMps
    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
    @parametrize_test("mode", ["bilinear", "bicubic"])
    @parametrize_test("antialias", [True, False])
    @parametrize_test("align_corners", [True, False])
    @parametrize_test("num_channels", [3, 5])
    @parametrize_test("output_size", [32, 600])
    @parametrize_test("check_as_unsqueezed_3d_tensor", [True, False])
    @parametrize_test("non_contig", [False, "sliced", "restrided"])
    @parametrize_test("batch_size", [1, 5])
    def test_upsamplingBiMode2d_consistency(
        self,
        device,
        memory_format,
        mode,
        antialias,
        align_corners,
        num_channels,
        output_size,
        check_as_unsqueezed_3d_tensor,
        non_contig,
        batch_size,
    ):
        # Check output value consistency between the resized uint8 input and the resized float32 input
        if torch.device(device).type == "cuda":
            raise SkipTest("CUDA implementation is not yet supporting uint8")

        torch.manual_seed(0)

        # - the input range is set to [30, 220] for bicubic mode, because the bicubic kernel may create
        #   [intermediate] values outside of the [0, 255] range, which need
        #   to be clipped in the uint8 path, but not in the float path. This isn't
        #   an issue with the bilinear kernel.
        input_range = (30, 220) if mode == "bicubic" else (0, 256)
        input_ui8 = torch.randint(*input_range, size=(batch_size, num_channels, 400, 400), dtype=torch.uint8, device=device)
        input_ui8 = input_ui8.contiguous(memory_format=memory_format)

        if non_contig == "sliced":
            input_ui8 = input_ui8[:, :, 10:-10, 10:-10]
        elif non_contig == "restrided":
            input_ui8 = input_ui8[:, :, ::2, ::2]

        if batch_size == 1 and check_as_unsqueezed_3d_tensor:
            input_ui8 = input_ui8[0, ...]
            input_ui8 = input_ui8[None, ...]

        input_f32 = input_ui8.float()

        output_f32 = F.interpolate(
            input_f32, size=(output_size, output_size), mode=mode, align_corners=align_corners, antialias=antialias
        ).round().clip(0, 255)
        output_ui8 = F.interpolate(
            input_ui8, size=(output_size, output_size), mode=mode, align_corners=align_corners, antialias=antialias
        )

        if non_contig is False:
            self.assertTrue(input_ui8.is_contiguous(memory_format=memory_format))

        # FIXME: the if-clause below captures the current behaviour, which is definitely unexpected.
        # Ideally we want to fix it such that both the ui8 and f32 outputs are also channels_last
        # See for more details: https://github.com/pytorch/pytorch/pull/100373
        if batch_size == 1 and check_as_unsqueezed_3d_tensor and memory_format == torch.channels_last:
            self.assertTrue(output_ui8.is_contiguous())
            self.assertTrue(output_f32.is_contiguous())
        else:
            self.assertTrue(output_ui8.is_contiguous(memory_format=memory_format))
            self.assertTrue(output_f32.is_contiguous(memory_format=memory_format))

        if mode == "bilinear":
            torch.testing.assert_close(output_f32, output_ui8.float(), rtol=0, atol=1)
        else:
            diff = (output_f32 - output_ui8.float()).abs()
            self.assertLess(diff.max(), 15)

            threshold = 2
            percent = 3
            self.assertLess((diff > threshold).float().mean(), percent / 100)

            threshold = 5
            percent = 1
            self.assertLess((diff > threshold).float().mean(), percent / 100)

            self.assertLess(diff.mean(), 0.4)

    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
    @parametrize_test("align_corners", [True, False])
    @parametrize_test("input_size, output_size", [(399, 437), (403, 377)])
    def test_upsamplingBiLinear2d_consistency_interp_size_bug(self, device, memory_format, align_corners, input_size, output_size):
        # Non-regression test for https://github.com/pytorch/pytorch/pull/101403

        if torch.device(device).type == "cuda":
            raise SkipTest("CUDA implementation is not yet supporting uint8")

        mode = "bilinear"
        input_ui8 = torch.randint(0, 256, size=(1, 3, input_size, input_size), dtype=torch.uint8, device=device)
        input_ui8 = input_ui8.contiguous(memory_format=memory_format)
        input_f32 = input_ui8.float()

        output_f32 = F.interpolate(
            input_f32, size=(output_size, output_size), mode=mode, align_corners=align_corners, antialias=False
        ).round().to(torch.uint8)
        output_ui8 = F.interpolate(
            input_ui8, size=(output_size, output_size), mode=mode, align_corners=align_corners, antialias=False
        )
        torch.testing.assert_close(output_f32, output_ui8, atol=1, rtol=0)

    @expectedFailureMPS  # NotImplementedError: aten::upsample_bicubic2d.out https://github.com/pytorch/pytorch/issues/77764
    def test_upsamplingBicubic2d_correctness(self, device):
        # test output against a known input: the align_corners=False result must match OpenCV
        in_t = torch.arange(8., device=device).view(1, 2, 2, 2)
        expected_out_t = torch.tensor(
            [[[[-0.31641, 0.01562, 0.56250, 0.89453],
              [0.34766, 0.67969, 1.22656, 1.55859],
              [1.44141, 1.77344, 2.32031, 2.65234],
              [2.10547, 2.43750, 2.98438, 3.31641]],

             [[3.68359, 4.01562, 4.56250, 4.89453],
              [4.34766, 4.67969, 5.22656, 5.55859],
              [5.44141, 5.77344, 6.32031, 6.65234],
              [6.10547, 6.43750, 6.98438, 7.31641]]]], device=device)
        out_t = F.interpolate(in_t, scale_factor=2, mode='bicubic', align_corners=False)
        torch.set_printoptions(precision=5)
        self.assertEqual(out_t, expected_out_t, atol=1e-5, rtol=0)

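    # A hedged sketch (assuming NumPy and OpenCV are installed) of how the OpenCV
    # reference values used in the bicubic correctness test above can be
    # regenerated with cv2.resize; the helper name is hypothetical and the test
    # suite does not call it.
    @staticmethod
    def _cv2_bicubic_reference():
        import numpy as np
        import cv2
        in_t = torch.arange(8.).view(1, 2, 2, 2)
        channels = []
        for c in range(2):
            a_in = in_t.numpy()[0, c, ...].astype(np.float32)
            # cv2.resize takes (width, height); INTER_CUBIC corresponds to bicubic with align_corners=False
            a_out = cv2.resize(a_in, (4, 4), interpolation=cv2.INTER_CUBIC)
            channels.append(torch.from_numpy(a_out))
        return torch.stack(channels).unsqueeze(0)  # shape (1, 2, 4, 4)
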
    @expectedFailureMPS  # NotImplementedError: aten::_upsample_bicubic2d_aa.out https://github.com/pytorch/pytorch/issues/77764
    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
    def test_upsamplingBicubic2d_aa_correctness(self, device, memory_format):
        t_in = torch.arange(3 * 8 * 8, dtype=torch.float, device=device).reshape(1, 3, 8, 8)
        t_in = t_in.contiguous(memory_format=memory_format)
        # This expected result is obtained using PIL.Image.resize
        # for c in range(3):
        #   a_in = t_in.numpy()[0, c, ...]
        #   pil_in = Image.fromarray(a_in)
        #   pil_out = pil_in.resize((2, 2), resample=Image.BICUBIC)
        expected_out = torch.tensor([
            15.1205635, 18.760439, 44.23956, 47.879436, 79.12056, 82.76044,
            108.23956, 111.87944, 143.12057, 146.76044, 172.23956, 175.87943
        ], device=device, dtype=t_in.dtype).reshape(1, 3, 2, 2)
        t_out = F.interpolate(t_in, size=(2, 2), mode="bicubic", align_corners=False, antialias=True)
        self.assertEqual(expected_out, t_out)

    @expectedFailureMPS  # NotImplementedError: aten::upsample_trilinear3d.out https://github.com/pytorch/pytorch/issues/77764
    @parametrize_test("align_corners", [True, False])
    @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last_3d])
    def test_upsamplingTrilinear3d(self, device, align_corners, memory_format):
        kwargs = dict(mode='trilinear', align_corners=align_corners)

        # test float scale factor up & downsampling
        for scale_factor in [0.5, 1.5, 2]:
            m = nn.Upsample(scale_factor=scale_factor, **kwargs)
            in_t = torch.ones(1, 2, 4, 4, 4, device=device, dtype=torch.double)
            in_t = in_t.contiguous(memory_format=memory_format).requires_grad_()
            out_size = int(math.floor(in_t.shape[-1] * scale_factor))
            with warnings.catch_warnings(record=True) as w:
                out_t = m(in_t)
            expected_out = torch.ones(1, 2, out_size, out_size, out_size, device=device, dtype=torch.double)
            self.assertEqual(expected_out, out_t)
            # Assert that memory format is carried through to the output
            self.assertTrue(out_t.is_contiguous(memory_format=memory_format))

            grad_out = torch.randn_like(out_t).contiguous(memory_format=memory_format)
            in_t.grad = None
            out_t.backward(grad_out)
            grad_in = in_t.grad
            self.assertTrue(grad_in.is_contiguous(memory_format=memory_format))

            if memory_format == torch.channels_last_3d:
                # check that the input grads computed from channels_last and contiguous grad outputs match
                in_t.grad = None
                out_t.backward(grad_out.contiguous())
                self.assertEqual(in_t.grad, grad_in)

            input = torch.randn(1, 2, 4, 4, 4, requires_grad=True, dtype=torch.double)
            self.assertEqual(
                F.interpolate(input, (out_size, out_size, out_size), **kwargs),
                F.interpolate(input, scale_factor=scale_factor, **kwargs))
            gradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [input])
            gradgradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [input])

    @onlyCUDA
    @dtypes(torch.half)
    @largeTensorTest('40GB')
    def test_upsampling_64bit_indexing_channels_last(self, device, dtype):
        x = torch.rand((32, 64, 512, 512), dtype=dtype, device=device)
        out = torch.nn.functional.interpolate(x.to(memory_format=torch.channels_last), scale_factor=2, mode='nearest')
        out_ref = torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')
        del x
        self.assertTrue(torch.allclose(out, out_ref))

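    # Back-of-the-envelope arithmetic (a sketch, not called by the tests) for why the
    # large-tensor test above exercises 64-bit indexing: the upsampled output holds
    # 32 * 64 * 1024 * 1024 = 2**31 elements, which no longer fits in a signed
    # 32-bit index.
    @staticmethod
    def _upsampled_output_numel_exceeds_int32():
        numel = 32 * 64 * (512 * 2) * (512 * 2)
        return numel > torch.iinfo(torch.int32).max  # 2**31 > 2**31 - 1
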
    @onlyCUDA
    @dtypes(torch.half)
    @largeTensorTest('40GB')
    def test_replicatepad_64bit_indexing(self, device, dtype):
        conv = torch.nn.Conv1d(128, 128, 3, 1, 1, padding_mode="replicate", device=device, dtype=dtype)
        x = torch.randn(size=(256 * 448 * 2, 128, 96), dtype=dtype, device=device)
        y = conv(x)
        torch.mean(y).backward()

    @onlyCUDA
    @dtypes(torch.half)
    @largeTensorTest('40GB')
    def test_upsamplingnearest2d_backward_64bit_indexing(self, device, dtype):
        x = torch.randn(size=(36, 128, 512, 512), device=device, dtype=dtype).requires_grad_()
        y = F.interpolate(x, scale_factor=2, mode="nearest")
        y.backward(torch.randn_like(y))

    def _slow_masked_softmax(self, input, mask):
        exp = torch.exp(input)
        exp = exp * mask
        s = exp.sum(dim=3, keepdim=True).expand(exp.size())
        return exp / s

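    # An equivalent formulation (a sketch, not used by the tests): the slow reference
    # above keeps only entries where mask == 1; up to the fully masked-out rows, which
    # are undefined in both versions, the same result can be expressed by filling
    # masked-out positions with -inf before a regular softmax.
    def _masked_softmax_via_fill(self, input, mask):
        # mask: 1 where the entry participates, 0 where it is masked out
        filled = input.masked_fill(mask == 0, float('-inf'))
        return F.softmax(filled, dim=3)
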
    def test_masked_softmax_mask_types(self, device):
        # Test that mask type 0 (LxL attention mask), mask type 1 (BxL padding mask),
        # and mask type 2 (generic BxHxLxL mask) are processed correctly on the
        # fast path and the results match the explicit slow calculation.
        sizes = [(1, 1, 32), (3, 16, 310), (12, 4, 1024), (4, 2, 1200)]

        for (B, num_heads, L) in sizes:

            # mask_type == 0 => attention mask of shape LxL
            src_mask_orig = torch.randint(0, 2, (L, L)).bool()
            src_mask = src_mask_orig.reshape(1, 1, L, L).expand(B, num_heads, L, L).bool()

            # mask_type == 1 => padding mask of shape BxL
            src_key_padding_mask_orig = torch.randint(0, 2, (B, L)).bool()
            src_key_padding_mask = src_key_padding_mask_orig.reshape(B, 1, 1, L).expand(B, num_heads, L, L).bool()

            # mask_type == 2 => generic mask of shape BxHxLxL
            generic_mask = torch.randint(0, 2, (B, num_heads, L, L)).bool()
            masks = [(src_mask_orig, src_mask, 0),
                     (src_key_padding_mask_orig, src_key_padding_mask, 1),
                     (generic_mask, generic_mask, 2)
                     ]
            for dim in [0, 3]:
                for mask_orig, mask, mask_type in masks:
                    if (self.device_type == "cuda") and (num_heads % 2) and (mask_type == 1):
                        # CUDA path doesn't support padding mask when the number of heads is odd
                        continue
                    input = torch.randn((B, num_heads, L, L))
                    if (self.device_type == "cuda"):
                        input = input.cuda()
                        mask = mask.cuda()
                        mask_orig = mask_orig.cuda()
                    native_res = torch._masked_softmax(input, mask_orig, dim, mask_type)
                    mask = ~mask

                    def slow_masked_softmax(input, mask):
                        exp = torch.exp(input)
                        exp = exp * mask
                        s = exp.sum(dim=dim, keepdim=True).expand(exp.size())
                        return exp / s

                    pt_res = slow_masked_softmax(input, mask)
                    pt_res = torch.nan_to_num(pt_res)

                    mask_not = mask.logical_not()
                    # Only compare rows that are not entirely masked out: fully masked
                    # rows are non-deterministic (they *may* be 0), so both results have
                    # those rows filled with 0 before the comparison.
                    mask_out = mask_not.all(dim, keepdim=True).expand(mask_not.shape)
                    self.assertEqual(
                        pt_res.masked_fill(mask_out, 0),
                        native_res.masked_fill(mask_out, 0),
                        exact_dtype=True
                    )

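    # A hedged sketch (hypothetical helper, not called by the tests) of how the three
    # mask_type values accepted by torch._masked_softmax map onto the full BxHxLxL
    # shape used for the slow reference computation above.
    @staticmethod
    def _expand_mask_to_BHLL(mask, mask_type, B, num_heads, L):
        if mask_type == 0:    # LxL attention mask, shared across batch and heads
            return mask.reshape(1, 1, L, L).expand(B, num_heads, L, L)
        elif mask_type == 1:  # BxL key-padding mask, shared across heads and query positions
            return mask.reshape(B, 1, 1, L).expand(B, num_heads, L, L)
        else:                 # mask_type == 2: already a full BxHxLxL mask
            return mask
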
9957*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
9958*da0073e9SAndroid Build Coastguard Worker    @gcIfJetson
9959*da0073e9SAndroid Build Coastguard Worker    def test_masked_softmax_devices_parity(self):
9960*da0073e9SAndroid Build Coastguard Worker        # Test that softmax with mask type 0 (LxL attention mask), mask type 1 (BxL padding mask),
9961*da0073e9SAndroid Build Coastguard Worker        # and mask type 2 (BxHxLxL generic mask) gives the same result on CPU and on CUDA.
9962*da0073e9SAndroid Build Coastguard Worker
9963*da0073e9SAndroid Build Coastguard Worker        sizes = [(1, 1, 32), (3, 16, 310), (12, 4, 1024), (4, 2, 1200)]
9964*da0073e9SAndroid Build Coastguard Worker        for (B, num_heads, L) in sizes:
9965*da0073e9SAndroid Build Coastguard Worker            # mask_type == 0 => attention mask of shape LxL
9966*da0073e9SAndroid Build Coastguard Worker            src_mask = torch.randint(0, 2, (L, L)).bool()
9967*da0073e9SAndroid Build Coastguard Worker            # mask_type == 1 => padding mask of shape BxL
9968*da0073e9SAndroid Build Coastguard Worker            src_key_padding_mask = torch.randint(0, 2, (B, L)).bool()
9969*da0073e9SAndroid Build Coastguard Worker            # mask_type == 2 => generic mask of shape BxHxLxL
9970*da0073e9SAndroid Build Coastguard Worker            generic_mask = torch.randint(0, 2, (B, num_heads, L, L)).bool()
9971*da0073e9SAndroid Build Coastguard Worker            masks = [(src_mask, 0), (src_key_padding_mask, 1), (generic_mask, 2)]
9972*da0073e9SAndroid Build Coastguard Worker            input = torch.randn((B, num_heads, L, L))
9973*da0073e9SAndroid Build Coastguard Worker            for dim in [0, 3]:
9974*da0073e9SAndroid Build Coastguard Worker                for mask, mask_type in masks:
9975*da0073e9SAndroid Build Coastguard Worker                    if (num_heads % 2) and (mask_type == 1):
9976*da0073e9SAndroid Build Coastguard Worker                        # CUDA path doesn't support padding mask when the number of heads is odd
9977*da0073e9SAndroid Build Coastguard Worker                        continue
9978*da0073e9SAndroid Build Coastguard Worker
9979*da0073e9SAndroid Build Coastguard Worker                    def softmax_on_device(mask, input, device):
9980*da0073e9SAndroid Build Coastguard Worker                        # Compute softmax on a given device
9981*da0073e9SAndroid Build Coastguard Worker                        input_device = input.to(device)
9982*da0073e9SAndroid Build Coastguard Worker                        mask_device = mask.to(device)
9983*da0073e9SAndroid Build Coastguard Worker                        softmax_res = torch._masked_softmax(input_device, mask_device, dim, mask_type)
9984*da0073e9SAndroid Build Coastguard Worker                        if mask_type == 0:
9985*da0073e9SAndroid Build Coastguard Worker                            mask_expanded = mask_device.reshape(1, 1, L, L).expand(B, num_heads, L, L).bool()
9986*da0073e9SAndroid Build Coastguard Worker                        elif mask_type == 1:
9987*da0073e9SAndroid Build Coastguard Worker                            mask_expanded = mask_device.reshape(B, 1, 1, L).expand(B, num_heads, L, L).bool()
9988*da0073e9SAndroid Build Coastguard Worker                        else:
9989*da0073e9SAndroid Build Coastguard Worker                            mask_expanded = mask_device
9990*da0073e9SAndroid Build Coastguard Worker                        # In the result, only zero out rows that are entirely masked out, since their values are non-deterministic (they *may* be 0)
9991*da0073e9SAndroid Build Coastguard Worker                        # mask_out is True exactly for those fully masked rows along `dim`; fill them with 0
9992*da0073e9SAndroid Build Coastguard Worker                        mask_out = mask_expanded.all(dim, keepdim=True).expand(mask_expanded.shape)
9993*da0073e9SAndroid Build Coastguard Worker                        softmax_res = softmax_res.masked_fill(mask_out, 0)
9994*da0073e9SAndroid Build Coastguard Worker                        return softmax_res
9995*da0073e9SAndroid Build Coastguard Worker
9996*da0073e9SAndroid Build Coastguard Worker                    cpu_res = softmax_on_device(mask, input, "cpu")
9997*da0073e9SAndroid Build Coastguard Worker                    cuda_res = softmax_on_device(mask, input, "cuda")
9998*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(cpu_res, cuda_res, exact_dtype=True)
9999*da0073e9SAndroid Build Coastguard Worker
10000*da0073e9SAndroid Build Coastguard Worker    def test_masked_softmax(self, device):
10001*da0073e9SAndroid Build Coastguard Worker        sizes = [(1, 1, 32), (3, 16, 310), (12, 4, 1024), (4, 2, 1200)]
10002*da0073e9SAndroid Build Coastguard Worker        for (B, num_heads, L) in sizes:
10003*da0073e9SAndroid Build Coastguard Worker            for dim in [0, 3]:
10004*da0073e9SAndroid Build Coastguard Worker                input = torch.randn((B, num_heads, L, L))
10005*da0073e9SAndroid Build Coastguard Worker                mask = torch.randint(0, 2, (B, L))
10006*da0073e9SAndroid Build Coastguard Worker                mask = mask.reshape(B, 1, 1, L).expand(B, num_heads, L, L).bool()
10007*da0073e9SAndroid Build Coastguard Worker                mask_type = 1   # BxL => src_key_padding_mask
10008*da0073e9SAndroid Build Coastguard Worker                if (self.device_type == "cuda"):
10009*da0073e9SAndroid Build Coastguard Worker                    input = input.cuda()
10010*da0073e9SAndroid Build Coastguard Worker                    mask = mask.cuda()
10011*da0073e9SAndroid Build Coastguard Worker                native_res = torch._masked_softmax(input, mask, dim, mask_type)
10012*da0073e9SAndroid Build Coastguard Worker                mask = ~mask
10013*da0073e9SAndroid Build Coastguard Worker
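                # torch._masked_softmax treats True as "masked out" (equivalent to
                # masked_fill with -inf, cf. _test_masked_softmax_helper), while the
                # dense reference below multiplies by a keep-mask, hence the inversion.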
10014*da0073e9SAndroid Build Coastguard Worker                def slow_masked_softmax(input, mask):
10015*da0073e9SAndroid Build Coastguard Worker                    exp = torch.exp(input)
10016*da0073e9SAndroid Build Coastguard Worker                    exp = exp * mask
10017*da0073e9SAndroid Build Coastguard Worker                    s = exp.sum(dim=dim, keepdim=True).expand(exp.size())
10018*da0073e9SAndroid Build Coastguard Worker                    return exp / s
10019*da0073e9SAndroid Build Coastguard Worker
10020*da0073e9SAndroid Build Coastguard Worker                pt_res = slow_masked_softmax(input, mask)
10021*da0073e9SAndroid Build Coastguard Worker                pt_res = torch.nan_to_num(pt_res)
10022*da0073e9SAndroid Build Coastguard Worker
10023*da0073e9SAndroid Build Coastguard Worker                mask_not = mask.logical_not()
10024*da0073e9SAndroid Build Coastguard Worker                # In the result, only zero out rows that are entirely masked out, since their values are non-deterministic (they *may* be 0)
10025*da0073e9SAndroid Build Coastguard Worker                # mask_out is True exactly for those fully masked rows along `dim`
10026*da0073e9SAndroid Build Coastguard Worker                mask_out = mask_not.all(dim, keepdim=True).expand(mask_not.shape)
10027*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
10028*da0073e9SAndroid Build Coastguard Worker                    pt_res.masked_fill(mask_out, 0),
10029*da0073e9SAndroid Build Coastguard Worker                    native_res.masked_fill(mask_out, 0),
10030*da0073e9SAndroid Build Coastguard Worker                    exact_dtype=True
10031*da0073e9SAndroid Build Coastguard Worker                )
10032*da0073e9SAndroid Build Coastguard Worker
10033*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bfloat16, torch.half)
10034*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.bfloat16: 2e-2, torch.half: 3e-3})
10035*da0073e9SAndroid Build Coastguard Worker    def test_masked_softmax_lowp(self, dtype):
10036*da0073e9SAndroid Build Coastguard Worker        sizes = [(1, 1, 32), (3, 16, 310), (12, 4, 1024), (4, 2, 1200)]
10037*da0073e9SAndroid Build Coastguard Worker        for (B, num_heads, L) in sizes:
10038*da0073e9SAndroid Build Coastguard Worker            for dim in [0, 3]:
10039*da0073e9SAndroid Build Coastguard Worker                input_lowp = torch.randn((B, num_heads, L, L), dtype=dtype).requires_grad_()
10040*da0073e9SAndroid Build Coastguard Worker                input_ref = input_lowp.float().detach().requires_grad_()
10041*da0073e9SAndroid Build Coastguard Worker                mask = torch.randint(0, 2, (B, L))
10042*da0073e9SAndroid Build Coastguard Worker                mask = mask.reshape(B, 1, 1, L).expand(B, num_heads, L, L).bool()
10043*da0073e9SAndroid Build Coastguard Worker
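                # Compare the low-precision kernel (forward and backward) against a
                # float32 reference for both the padding (1) and generic (2) mask
                # types; tolerances come from the precisionOverride above.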
10044*da0073e9SAndroid Build Coastguard Worker                for mask_type in [1, 2]:
10045*da0073e9SAndroid Build Coastguard Worker                    res_ref = torch._masked_softmax(input_ref, mask, dim, mask_type)
10046*da0073e9SAndroid Build Coastguard Worker                    res = torch._masked_softmax(input_lowp, mask, dim, mask_type)
10047*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(res_ref.to(dtype), res)
10048*da0073e9SAndroid Build Coastguard Worker
10049*da0073e9SAndroid Build Coastguard Worker                    grad_lowp = torch.randn_like(res_ref).to(dtype=dtype)
10050*da0073e9SAndroid Build Coastguard Worker                    grad_ref = grad_lowp.float()
10051*da0073e9SAndroid Build Coastguard Worker
10052*da0073e9SAndroid Build Coastguard Worker                    res_ref.backward(grad_ref)
10053*da0073e9SAndroid Build Coastguard Worker                    res.backward(grad_lowp)
10054*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(input_ref.grad.to(dtype), input_lowp.grad)
10055*da0073e9SAndroid Build Coastguard Worker
10056*da0073e9SAndroid Build Coastguard Worker    def _test_masked_softmax_helper(self, input, dim, mask, mask_type):
10057*da0073e9SAndroid Build Coastguard Worker        input_ref = input.detach().clone().requires_grad_()
10058*da0073e9SAndroid Build Coastguard Worker        result = torch._masked_softmax(input, mask, dim, mask_type)
10059*da0073e9SAndroid Build Coastguard Worker
10060*da0073e9SAndroid Build Coastguard Worker        expected = torch._softmax(input_ref.masked_fill(mask, float('-inf')), dim, False)
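        # The reference fills masked positions with -inf and applies a regular
        # softmax; fully masked rows therefore come out as NaN and are zeroed
        # out (or nan_to_num'ed) before the comparisons further down.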
10061*da0073e9SAndroid Build Coastguard Worker        grad = torch.randn_like(expected).to(dtype=expected.dtype)
10062*da0073e9SAndroid Build Coastguard Worker
10063*da0073e9SAndroid Build Coastguard Worker        result.backward(grad)
10064*da0073e9SAndroid Build Coastguard Worker        expected.backward(grad)
10065*da0073e9SAndroid Build Coastguard Worker
10066*da0073e9SAndroid Build Coastguard Worker        # Make sure the optional dim argument (dim=None) works as well
10067*da0073e9SAndroid Build Coastguard Worker        if dim == input.dim() - 1:
10068*da0073e9SAndroid Build Coastguard Worker            input_ref_default = input.detach().clone().requires_grad_()
10069*da0073e9SAndroid Build Coastguard Worker            result_default = torch._masked_softmax(input_ref_default, mask, None, mask_type)
10070*da0073e9SAndroid Build Coastguard Worker            result_default.backward(grad)
10071*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, result_default)
10072*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input.grad, input_ref_default.grad)
10073*da0073e9SAndroid Build Coastguard Worker
10074*da0073e9SAndroid Build Coastguard Worker        # In the result, only zero out rows that are entirely masked out, since their values are non-deterministic (they *may* be 0)
10075*da0073e9SAndroid Build Coastguard Worker        # mask_out is True exactly for those fully masked rows along `dim`
10076*da0073e9SAndroid Build Coastguard Worker        mask_out = mask.all(dim, keepdim=True).expand(mask.shape)
10077*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result.masked_fill(mask_out, 0), expected.masked_fill(mask_out, 0))
10078*da0073e9SAndroid Build Coastguard Worker
10079*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(input.grad, torch.nan_to_num(input_ref.grad))
10080*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(input.grad, input.grad.masked_fill(mask, 0.0))
10081*da0073e9SAndroid Build Coastguard Worker
10082*da0073e9SAndroid Build Coastguard Worker    def test_masked_softmax_grad(self, device):
10083*da0073e9SAndroid Build Coastguard Worker        shapes = [(1, 1, 32), (3, 16, 310), (12, 4, 1024), (4, 2, 1200)]
10084*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
10085*da0073e9SAndroid Build Coastguard Worker            dims = [0, len(shape) - 1] if len(shape) > 0 else [0]
10086*da0073e9SAndroid Build Coastguard Worker            for dim in dims:
10087*da0073e9SAndroid Build Coastguard Worker                for mask_type in [1, 2]:  # 1 = BxL => src_key_padding_mask
10088*da0073e9SAndroid Build Coastguard Worker                    input = torch.randn(shape, requires_grad=True)
10089*da0073e9SAndroid Build Coastguard Worker                    mask = torch.randint(0, 2, shape).bool()
10090*da0073e9SAndroid Build Coastguard Worker                    if (self.device_type == "cuda"):
10091*da0073e9SAndroid Build Coastguard Worker                        input = input.cuda().detach().requires_grad_()
10092*da0073e9SAndroid Build Coastguard Worker                        mask = mask.cuda()
10093*da0073e9SAndroid Build Coastguard Worker                    self._test_masked_softmax_helper(input, dim, mask, mask_type)
10094*da0073e9SAndroid Build Coastguard Worker
10095*da0073e9SAndroid Build Coastguard Worker    # In this test the forward pass is expected to produce NaNs because, with dim=0, some softmax slices contain only masked (unspecified) values
10096*da0073e9SAndroid Build Coastguard Worker    def test_masked_softmax_forward_with_nans(self, device):
10097*da0073e9SAndroid Build Coastguard Worker        dim = 0
10098*da0073e9SAndroid Build Coastguard Worker        shapes = [(4, 5), (50, 100), (1500, 1200)]
10099*da0073e9SAndroid Build Coastguard Worker        for (x, y) in shapes:
10100*da0073e9SAndroid Build Coastguard Worker            for mask_type in [1, 2]:  # 1 = BxL => src_key_padding_mask
10101*da0073e9SAndroid Build Coastguard Worker                input = torch.randn((x, y), requires_grad=True)
10102*da0073e9SAndroid Build Coastguard Worker                mask = torch.tensor([i % 2 for i in range(y)]).expand((x, y)).bool()
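                # The mask alternates by column, so along dim=0 every masked column
                # is masked in its entirety and its softmax slice is all NaN.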
10103*da0073e9SAndroid Build Coastguard Worker                if (self.device_type == "cuda"):
10104*da0073e9SAndroid Build Coastguard Worker                    input = input.cuda().detach().requires_grad_()
10105*da0073e9SAndroid Build Coastguard Worker                    mask = mask.cuda()
10106*da0073e9SAndroid Build Coastguard Worker                self._test_masked_softmax_helper(input, dim, mask, mask_type)
10107*da0073e9SAndroid Build Coastguard Worker
10108*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
10109*da0073e9SAndroid Build Coastguard Worker    def test_masked_softmax_transformer_layout(self, device):
10110*da0073e9SAndroid Build Coastguard Worker        B = 211
10111*da0073e9SAndroid Build Coastguard Worker        num_heads = 16
10112*da0073e9SAndroid Build Coastguard Worker        L = 42
10113*da0073e9SAndroid Build Coastguard Worker        input = torch.randn((B, num_heads, L, L))
10114*da0073e9SAndroid Build Coastguard Worker        dim = input.dim() - 1
10115*da0073e9SAndroid Build Coastguard Worker        mask = torch.randint(0, 2, (B, L))
10116*da0073e9SAndroid Build Coastguard Worker        mask_type = 1   # BxL => src_key_padding_mask
10117*da0073e9SAndroid Build Coastguard Worker        if (self.device_type == "cuda"):
10118*da0073e9SAndroid Build Coastguard Worker            input = input.cuda()
10119*da0073e9SAndroid Build Coastguard Worker            mask = mask.cuda()
10120*da0073e9SAndroid Build Coastguard Worker        mask = mask.bool()
10121*da0073e9SAndroid Build Coastguard Worker        native_res = torch._masked_softmax(input, mask, dim, mask_type)
10122*da0073e9SAndroid Build Coastguard Worker        mask = mask.reshape(B, 1, 1, L).expand(B, num_heads, L, L)
10123*da0073e9SAndroid Build Coastguard Worker        mask = ~mask
10124*da0073e9SAndroid Build Coastguard Worker        mask = mask.float()
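        # The BxL padding mask was expanded to BxHxLxL and inverted into a float
        # keep-mask so the dense reference below matches the native kernel's layout.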
10125*da0073e9SAndroid Build Coastguard Worker
10126*da0073e9SAndroid Build Coastguard Worker        pt_res = self._slow_masked_softmax(input, mask)
10127*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(pt_res, native_res, exact_dtype=True)
10128*da0073e9SAndroid Build Coastguard Worker
10129*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
10130*da0073e9SAndroid Build Coastguard Worker    def test_masked_softmax_TxT_layout(self, device):
10131*da0073e9SAndroid Build Coastguard Worker        B = 211
10132*da0073e9SAndroid Build Coastguard Worker        num_heads = 16
10133*da0073e9SAndroid Build Coastguard Worker        L = 42
10134*da0073e9SAndroid Build Coastguard Worker        input = torch.randn((B, num_heads, L, L))
10135*da0073e9SAndroid Build Coastguard Worker        dim = input.dim() - 1
10136*da0073e9SAndroid Build Coastguard Worker        mask = torch.randint(0, 2, (L, L))
10137*da0073e9SAndroid Build Coastguard Worker        mask_type = 0   # LxL => src_mask
10138*da0073e9SAndroid Build Coastguard Worker        if (self.device_type == "cuda"):
10139*da0073e9SAndroid Build Coastguard Worker            input = input.cuda()
10140*da0073e9SAndroid Build Coastguard Worker            mask = mask.cuda()
10141*da0073e9SAndroid Build Coastguard Worker        mask = mask.bool()
10142*da0073e9SAndroid Build Coastguard Worker        native_res = torch._masked_softmax(input, mask, dim, mask_type)
10143*da0073e9SAndroid Build Coastguard Worker        mask = mask.expand(B, num_heads, L, L)
10144*da0073e9SAndroid Build Coastguard Worker        mask = ~mask
10145*da0073e9SAndroid Build Coastguard Worker        mask = mask.float()
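        # Same idea as above: the LxL attention mask is broadcast over batch and
        # heads and inverted into a float keep-mask for the dense reference.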
10146*da0073e9SAndroid Build Coastguard Worker
10147*da0073e9SAndroid Build Coastguard Worker        pt_res = self._slow_masked_softmax(input, mask)
10148*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(pt_res, native_res, exact_dtype=True)
10149*da0073e9SAndroid Build Coastguard Worker
10150*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
10151*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bfloat16, torch.half)
10152*da0073e9SAndroid Build Coastguard Worker    def test_log_softmax_cpu(self, device, dtype):
10153*da0073e9SAndroid Build Coastguard Worker        for dim in [0, 1]:
10154*da0073e9SAndroid Build Coastguard Worker            inputf = torch.rand(200, 200, device=device, dtype=torch.float, requires_grad=True)
10155*da0073e9SAndroid Build Coastguard Worker            input = inputf.to(dtype).detach().requires_grad_(True)
10156*da0073e9SAndroid Build Coastguard Worker            outf = F.log_softmax(inputf, dim=dim)
10157*da0073e9SAndroid Build Coastguard Worker            out = F.log_softmax(input, dim=dim)
10158*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, outf.to(dtype=dtype), atol=0.1, rtol=0)
10159*da0073e9SAndroid Build Coastguard Worker
10160*da0073e9SAndroid Build Coastguard Worker            out.sum().backward()
10161*da0073e9SAndroid Build Coastguard Worker            outf.sum().backward()
10162*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input.grad, inputf.grad.to(dtype), atol=0.1, rtol=0)
10163*da0073e9SAndroid Build Coastguard Worker
10164*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
10165*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bfloat16, torch.half)
10166*da0073e9SAndroid Build Coastguard Worker    def test_softmax_cpu(self, device, dtype):
10167*da0073e9SAndroid Build Coastguard Worker        for dim in [0, 1]:
10168*da0073e9SAndroid Build Coastguard Worker            inputf = torch.rand(200, 200, device=device, dtype=torch.float, requires_grad=True)
10169*da0073e9SAndroid Build Coastguard Worker            input = inputf.to(dtype).detach().requires_grad_(True)
10170*da0073e9SAndroid Build Coastguard Worker            outf = F.softmax(inputf, dim=dim)
10171*da0073e9SAndroid Build Coastguard Worker            out = F.softmax(input, dim=dim)
10172*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, outf.to(dtype), atol=1e-3, rtol=0)
10173*da0073e9SAndroid Build Coastguard Worker
10174*da0073e9SAndroid Build Coastguard Worker            out.sum().backward()
10175*da0073e9SAndroid Build Coastguard Worker            outf.sum().backward()
10176*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input.grad, inputf.grad.to(dtype), atol=1e-3, rtol=0)
10177*da0073e9SAndroid Build Coastguard Worker
10178*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.half, torch.float)
10179*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
10180*da0073e9SAndroid Build Coastguard Worker    def test_softmax_results(self, device, dtype):
10181*da0073e9SAndroid Build Coastguard Worker        # Non-even sizes and non-zero shifts test fallback paths in the vectorized kernel
10182*da0073e9SAndroid Build Coastguard Worker        # Note: dim1 > 1024 is needed to exercise the vectorized (non-persistent) path; (16, 30576) is BERT-esque
10183*da0073e9SAndroid Build Coastguard Worker        sizes = [(0, 10), (32, 20), (10, 0), (31, 20), (32, 21), (31, 23), (32, 1536), (31, 2048), (33, 2049), (16, 30576)]
10184*da0073e9SAndroid Build Coastguard Worker        shifts = [(0, 0), (1, 0), (0, 1), (1, 1)]
10185*da0073e9SAndroid Build Coastguard Worker        for fn in [F.softmax, F.log_softmax]:
10186*da0073e9SAndroid Build Coastguard Worker            for size in sizes:
10187*da0073e9SAndroid Build Coastguard Worker                for shift in shifts:
10188*da0073e9SAndroid Build Coastguard Worker                    input = torch.rand(size, device=device, dtype=dtype)
10189*da0073e9SAndroid Build Coastguard Worker                    # Note: with the largest tests we can hit the upper limit of fp16 when we
10190*da0073e9SAndroid Build Coastguard Worker                    # sum, so scale the input down to stay in a nicer range.
10191*da0073e9SAndroid Build Coastguard Worker                    if dtype == torch.float16:
10192*da0073e9SAndroid Build Coastguard Worker                        input = input / 100.
10193*da0073e9SAndroid Build Coastguard Worker                    input = input[shift[0]:, shift[1]:]
10194*da0073e9SAndroid Build Coastguard Worker                    # Note: we don't want to backprop through the slice op
10195*da0073e9SAndroid Build Coastguard Worker                    input = input.detach().requires_grad_(True)
10196*da0073e9SAndroid Build Coastguard Worker                    ref_input = input.clone().cpu().detach().requires_grad_(True)
10197*da0073e9SAndroid Build Coastguard Worker                    for dim in [0, 1]:
10198*da0073e9SAndroid Build Coastguard Worker                        ref_output = fn(ref_input, dtype=torch.float, dim=dim)
10199*da0073e9SAndroid Build Coastguard Worker                        output = fn(input, dtype=torch.float, dim=dim)
10200*da0073e9SAndroid Build Coastguard Worker                        grad_output = torch.rand(size, device=device, dtype=dtype)
10201*da0073e9SAndroid Build Coastguard Worker                        grad_output = grad_output[shift[0]:, shift[1]:]
10202*da0073e9SAndroid Build Coastguard Worker                        ref_grad_output = grad_output.clone().cpu().detach()
10203*da0073e9SAndroid Build Coastguard Worker                        grad_input, = torch.autograd.grad(output, input, grad_outputs=(grad_output), create_graph=True)
10204*da0073e9SAndroid Build Coastguard Worker                        ref_grad_input, = torch.autograd.grad(ref_output, ref_input,
10205*da0073e9SAndroid Build Coastguard Worker                                                              grad_outputs=(ref_grad_output), create_graph=True)
10206*da0073e9SAndroid Build Coastguard Worker                        grad_input.sum().backward()
10207*da0073e9SAndroid Build Coastguard Worker                        ref_grad_input.sum().backward()
10208*da0073e9SAndroid Build Coastguard Worker
10209*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(output, ref_output)
10210*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(grad_input, ref_grad_input)
10211*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(input.grad, ref_input.grad)
10212*da0073e9SAndroid Build Coastguard Worker
10213*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
10214*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.half)
10215*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest("20GB")
10216*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest("64GB", "cpu")
10217*da0073e9SAndroid Build Coastguard Worker    def test_warp_softmax_64bit_indexing(self, device, dtype):
10218*da0073e9SAndroid Build Coastguard Worker        def run_test(*shape):
10219*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(shape, device="cuda", dtype=torch.float16, requires_grad=True)
10220*da0073e9SAndroid Build Coastguard Worker            y = F.log_softmax(x, dim=-1, dtype=dtype)
10221*da0073e9SAndroid Build Coastguard Worker            y.backward(y)
10222*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
10223*da0073e9SAndroid Build Coastguard Worker                xx = x.cpu().requires_grad_()
10224*da0073e9SAndroid Build Coastguard Worker            yy = F.log_softmax(xx.float(), dim=-1).to(dtype)
10225*da0073e9SAndroid Build Coastguard Worker            yy.backward(yy)
10226*da0073e9SAndroid Build Coastguard Worker            # workaround to reduce memory usage vs. self.assertEqual, see #84944
10227*da0073e9SAndroid Build Coastguard Worker            rtol, atol = torch.testing._comparison.get_tolerances(dtype, rtol=None, atol=None)
10228*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.allclose(y.cpu(), yy, rtol=rtol, atol=atol))
10229*da0073e9SAndroid Build Coastguard Worker            # x is half
10230*da0073e9SAndroid Build Coastguard Worker            rtol, _ = torch.testing._comparison.get_tolerances(torch.half, rtol=None, atol=None)
10231*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.allclose(x.grad.cpu(), xx.grad, rtol=rtol, atol=1e-3))
10232*da0073e9SAndroid Build Coastguard Worker
10233*da0073e9SAndroid Build Coastguard Worker        run_test(1100000000, 2)  # Illegal memory access https://github.com/pytorch/pytorch/issues/52715
10234*da0073e9SAndroid Build Coastguard Worker        run_test(2200000000, 1)  # invalid configuration argument https://github.com/pytorch/pytorch/issues/52716
10235*da0073e9SAndroid Build Coastguard Worker
10236*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
10237*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.half)
10238*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest("20GB")
10239*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest("2GB", "cpu")
10240*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.half: 0.001})
10241*da0073e9SAndroid Build Coastguard Worker    def test_softmax_64bit_indexing(self, device, dtype):
10242*da0073e9SAndroid Build Coastguard Worker        def run_test(*shape):
10243*da0073e9SAndroid Build Coastguard Worker            x = torch.ones(shape, device=device, dtype=dtype, requires_grad=True)
10244*da0073e9SAndroid Build Coastguard Worker            y = F.log_softmax(x, dim=-1, dtype=dtype)
10245*da0073e9SAndroid Build Coastguard Worker            y.backward(y)
10246*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y[0], y[-1])
10247*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad[0], x.grad[-1])
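            # The input is all ones, so every softmax slice is identical; the first
            # and last rows (and their gradients) must match even when the kernel
            # has to use 64-bit indexing.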
10248*da0073e9SAndroid Build Coastguard Worker
10249*da0073e9SAndroid Build Coastguard Worker        run_test(1024 * 256 + 1, 8192)  # https://github.com/pytorch/pytorch/issues/84144
10250*da0073e9SAndroid Build Coastguard Worker
10251*da0073e9SAndroid Build Coastguard Worker
10252*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
10253*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.float, torch.half)
10254*da0073e9SAndroid Build Coastguard Worker    def test_log_softmax_big(self, device, dtype):
10255*da0073e9SAndroid Build Coastguard Worker        def _test_helper(shape):
10256*da0073e9SAndroid Build Coastguard Worker            # Generate a tensor with big numbers that are exactly representable in dtype
10257*da0073e9SAndroid Build Coastguard Worker            # and sit at a constant offset from a tensor with small numbers;
10258*da0073e9SAndroid Build Coastguard Worker            # the log_softmax of the small and the big tensor should be equal.
10259*da0073e9SAndroid Build Coastguard Worker            x_small = torch.randint(100, shape, dtype=dtype, device=device)
10260*da0073e9SAndroid Build Coastguard Worker            offset = 1.5e3 if dtype == torch.half else 1e7
10261*da0073e9SAndroid Build Coastguard Worker            x_big = x_small + offset
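            # log_softmax is shift-invariant: log_softmax(x + c) == log_softmax(x),
            # so adding a constant offset that is exactly representable in dtype
            # must not change the result.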
10262*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(F.log_softmax(x_small, -1), F.log_softmax(x_big, -1))
10263*da0073e9SAndroid Build Coastguard Worker        _test_helper((16, 4))
10264*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda':
10265*da0073e9SAndroid Build Coastguard Worker            # test non-persistent softmax kernel
10266*da0073e9SAndroid Build Coastguard Worker            _test_helper((4, 1536))
10267*da0073e9SAndroid Build Coastguard Worker
10268*da0073e9SAndroid Build Coastguard Worker    def test_save_lstm_compatibility(self, device):
10269*da0073e9SAndroid Build Coastguard Worker        # Test that an LSTM saved in PyTorch 1.7 and older can still be
10270*da0073e9SAndroid Build Coastguard Worker        # loaded in newer versions of PyTorch.
10271*da0073e9SAndroid Build Coastguard Worker        model = nn.LSTM(2, 3)
10272*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(32, 5, 2)
10273*da0073e9SAndroid Build Coastguard Worker        expected = model(x)
10274*da0073e9SAndroid Build Coastguard Worker
10275*da0073e9SAndroid Build Coastguard Worker        # Get a state dict for PyTorch 1.7 LSTM. Before PyTorch 1.8, proj_size
10276*da0073e9SAndroid Build Coastguard Worker        # didn't exist.
10277*da0073e9SAndroid Build Coastguard Worker        assert model.proj_size == 0
10278*da0073e9SAndroid Build Coastguard Worker        state_dict = model.__dict__
10279*da0073e9SAndroid Build Coastguard Worker        del state_dict['proj_size']
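        # The module's __dict__ minus proj_size stands in for the pickled state of a
        # pre-1.8 LSTM; __setstate__ below has to tolerate the missing key.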
10280*da0073e9SAndroid Build Coastguard Worker
10281*da0073e9SAndroid Build Coastguard Worker        # load a model
10282*da0073e9SAndroid Build Coastguard Worker        loaded_model = nn.LSTM(2, 3)
10283*da0073e9SAndroid Build Coastguard Worker        loaded_model.__setstate__(state_dict)
10284*da0073e9SAndroid Build Coastguard Worker        result = loaded_model(x)
10285*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, expected)
10286*da0073e9SAndroid Build Coastguard Worker
10287*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
10288*da0073e9SAndroid Build Coastguard Worker    @tf32_on_and_off(0.005)
10289*da0073e9SAndroid Build Coastguard Worker    def test_grid_sample_large(self, device):
10290*da0073e9SAndroid Build Coastguard Worker        def issue_35202():
10291*da0073e9SAndroid Build Coastguard Worker            input_tensor = torch.rand(1, 1, 480, 640, dtype=torch.float, device=device, requires_grad=True)
10292*da0073e9SAndroid Build Coastguard Worker            coords = torch.tensor([[-10059144, 67680944], [67680944, 67680944]], dtype=torch.float, device=device)
10293*da0073e9SAndroid Build Coastguard Worker            coords = coords.unsqueeze(0).unsqueeze(0).repeat(1, 1, 1, 1)
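            # The grid coordinates are far outside [-1, 1]; with the default 'zeros'
            # padding the sample must come back as 0 instead of reading out of bounds.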
10294*da0073e9SAndroid Build Coastguard Worker            result = torch.nn.functional.grid_sample(input_tensor, coords)
10295*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, torch.tensor([[[[0., 0.]]]], dtype=torch.float, device=device))
10296*da0073e9SAndroid Build Coastguard Worker            result.backward(torch.ones_like(result))
10297*da0073e9SAndroid Build Coastguard Worker            torch.cuda.synchronize()
10298*da0073e9SAndroid Build Coastguard Worker        issue_35202()
10299*da0073e9SAndroid Build Coastguard Worker
10300*da0073e9SAndroid Build Coastguard Worker        def issue_24823_1(dtype):
10301*da0073e9SAndroid Build Coastguard Worker            image = torch.arange(27, 0, -1, dtype=dtype, device=device).view(1, 1, 3, 3, 3)
10302*da0073e9SAndroid Build Coastguard Worker            image.requires_grad_()
10303*da0073e9SAndroid Build Coastguard Worker            grid = torch.nn.functional.affine_grid(
10304*da0073e9SAndroid Build Coastguard Worker                torch.tensor([[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]], dtype=dtype, device=device),
10305*da0073e9SAndroid Build Coastguard Worker                (1, 1, 3, 3, 3))
10306*da0073e9SAndroid Build Coastguard Worker            grid[:, 1, 1, 1, 0] = float('inf')
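            # A non-finite grid coordinate must yield a zero output sample and a zero
            # gradient for the corresponding input voxel rather than propagating NaNs.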
10307*da0073e9SAndroid Build Coastguard Worker            result = torch.nn.functional.grid_sample(image, grid, padding_mode='zeros')
10308*da0073e9SAndroid Build Coastguard Worker            tol_override = {'atol': 0.005, 'rtol': 0} if dtype == torch.half else {}
10309*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, torch.tensor([[[[[27., 26., 25.], [24., 23., 22.], [21., 20., 19.]],
10310*da0073e9SAndroid Build Coastguard Worker                                                     [[18., 17., 16.], [15., 0., 13.], [12., 11., 10.]],
10311*da0073e9SAndroid Build Coastguard Worker                                                     [[9., 8., 7.], [6., 5., 4.], [3., 2., 1.]]]]],
10312*da0073e9SAndroid Build Coastguard Worker                                                  device=device, dtype=dtype), **tol_override)
10313*da0073e9SAndroid Build Coastguard Worker            result.backward(torch.ones_like(result))
10314*da0073e9SAndroid Build Coastguard Worker            expected_grad = torch.ones_like(image)
10315*da0073e9SAndroid Build Coastguard Worker            expected_grad[0, 0, 1, 1, 1] = 0
10316*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(image.grad, expected_grad, atol=0.005, rtol=0)
10317*da0073e9SAndroid Build Coastguard Worker        issue_24823_1(torch.half)
10318*da0073e9SAndroid Build Coastguard Worker        issue_24823_1(torch.float)
10319*da0073e9SAndroid Build Coastguard Worker        issue_24823_1(torch.double)
10320*da0073e9SAndroid Build Coastguard Worker
10321*da0073e9SAndroid Build Coastguard Worker        def issue_24823_2():
10322*da0073e9SAndroid Build Coastguard Worker            param = torch.tensor([[[-1.0e+20, 0.0, 0.0], [0.0, -1.0e+20, 0.0]]], dtype=torch.float, device=device)
10323*da0073e9SAndroid Build Coastguard Worker            img = torch.zeros((1, 1, 4, 4), dtype=torch.float, device=device, requires_grad=True)
10324*da0073e9SAndroid Build Coastguard Worker            grid = torch.nn.functional.affine_grid(param, img.size())
10325*da0073e9SAndroid Build Coastguard Worker            result = torch.nn.functional.grid_sample(img, grid)
10326*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, torch.zeros(1, 1, 4, 4, device=device, dtype=torch.float))
10327*da0073e9SAndroid Build Coastguard Worker            result.backward(torch.ones_like(result))
10328*da0073e9SAndroid Build Coastguard Worker            torch.cuda.synchronize()
10329*da0073e9SAndroid Build Coastguard Worker        issue_24823_2()
10330*da0073e9SAndroid Build Coastguard Worker
10331*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
10332*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest(lambda self, device, dtype:
10333*da0073e9SAndroid Build Coastguard Worker                     # Compute sum of the large tensor sizes:
10334*da0073e9SAndroid Build Coastguard Worker                     # (im.numel() + small_image.numel() + small_image.grad.numel() +
10335*da0073e9SAndroid Build Coastguard Worker                     #   large_view.grad.numel()) * sizeof(dtype)
10336*da0073e9SAndroid Build Coastguard Worker                     32769 * (65536 + 3 * 65536 / 128) *
10337*da0073e9SAndroid Build Coastguard Worker                     torch.tensor([], dtype=dtype).element_size())
10338*da0073e9SAndroid Build Coastguard Worker    def test_grid_sample_large_index_2d(self, device, dtype):
10339*da0073e9SAndroid Build Coastguard Worker        # Test 64-bit indexing with grid_sample (gh-41656)
10340*da0073e9SAndroid Build Coastguard Worker        # Try accessing the corners, there should be no segfault
10341*da0073e9SAndroid Build Coastguard Worker        coords = torch.tensor([[[-1., -1.],
10342*da0073e9SAndroid Build Coastguard Worker                                [+1., -1.]],
10343*da0073e9SAndroid Build Coastguard Worker
10344*da0073e9SAndroid Build Coastguard Worker                               [[-1., +1.],
10345*da0073e9SAndroid Build Coastguard Worker                                [+1., +1.]]], device=device, dtype=dtype)
10346*da0073e9SAndroid Build Coastguard Worker        coords = coords.expand(1, 2, 2, 2)
10347*da0073e9SAndroid Build Coastguard Worker        im = torch.zeros([1, 1, 32769, 65536], device=device, dtype=dtype)
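        # Sampling the corner coordinates (as the 3D variant of this test does)
        # exercises the "no segfault" claim above; the image is all zeros, so the
        # result should be all zeros as well.
        result = F.grid_sample(im, coords, align_corners=False)
        self.assertEqual(result, torch.zeros((1, 1, 2, 2), device=device, dtype=dtype))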
10348*da0073e9SAndroid Build Coastguard Worker
10349*da0073e9SAndroid Build Coastguard Worker        # Compare sampling with large strides to the same op on a contiguous tensor
10350*da0073e9SAndroid Build Coastguard Worker        coords = torch.rand(1, 4, 4, 2, device=device, dtype=dtype)
10351*da0073e9SAndroid Build Coastguard Worker        large_view = im[..., 127::128]
10352*da0073e9SAndroid Build Coastguard Worker        small_image = torch.rand_like(large_view)
10353*da0073e9SAndroid Build Coastguard Worker        large_view[...] = small_image
10354*da0073e9SAndroid Build Coastguard Worker        large_view.requires_grad, small_image.requires_grad = True, True
10355*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
10356*da0073e9SAndroid Build Coastguard Worker            sum(i * s for i, s in zip(large_view.size(), large_view.stride())) >= 2 ** 31,
10357*da0073e9SAndroid Build Coastguard Worker            msg="View must use 64-bit indexing")
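        # Slicing with [..., 127::128] keeps the huge strides of `im`, so the largest
        # linear index of the view exceeds 2**31 and the kernel must use 64-bit indexing.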
10358*da0073e9SAndroid Build Coastguard Worker        for mode, padding_mode, align_corners in itertools.product(
10359*da0073e9SAndroid Build Coastguard Worker                ('nearest', 'bilinear', 'bicubic'), ('zeros', 'border', 'reflection'), (True, False)):
10360*da0073e9SAndroid Build Coastguard Worker            a = F.grid_sample(
10361*da0073e9SAndroid Build Coastguard Worker                small_image, coords, mode=mode,
10362*da0073e9SAndroid Build Coastguard Worker                padding_mode=padding_mode, align_corners=align_corners)
10363*da0073e9SAndroid Build Coastguard Worker            a.sum().backward()
10364*da0073e9SAndroid Build Coastguard Worker
10365*da0073e9SAndroid Build Coastguard Worker            b = F.grid_sample(
10366*da0073e9SAndroid Build Coastguard Worker                large_view, coords, mode=mode,
10367*da0073e9SAndroid Build Coastguard Worker                padding_mode=padding_mode, align_corners=align_corners)
10368*da0073e9SAndroid Build Coastguard Worker            b.sum().backward()
10369*da0073e9SAndroid Build Coastguard Worker
10370*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a, b)
10371*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(small_image.grad, large_view.grad)
10372*da0073e9SAndroid Build Coastguard Worker
10373*da0073e9SAndroid Build Coastguard Worker            small_image.grad.zero_()
10374*da0073e9SAndroid Build Coastguard Worker            large_view.grad.zero_()
10375*da0073e9SAndroid Build Coastguard Worker
10376*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
10377*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest(lambda self, device, dtype:
10378*da0073e9SAndroid Build Coastguard Worker                     # Compute sum of the large tensor sizes:
10379*da0073e9SAndroid Build Coastguard Worker                     # (im.numel() + small_image.numel() + small_image.grad.numel() +
10380*da0073e9SAndroid Build Coastguard Worker                     #   large_view.grad.numel()) * sizeof(dtype)
10381*da0073e9SAndroid Build Coastguard Worker                     2 * 32769 * (32768 + 3 * 32768 / 128) *
10382*da0073e9SAndroid Build Coastguard Worker                     torch.tensor([], dtype=dtype).element_size())
10383*da0073e9SAndroid Build Coastguard Worker    def test_grid_sample_large_index_3d(self, device, dtype):
10384*da0073e9SAndroid Build Coastguard Worker        # Test 64-bit indexing with grid_sample (gh-41656)
10385*da0073e9SAndroid Build Coastguard Worker        # Try accessing the corners, there should be no segfault
10386*da0073e9SAndroid Build Coastguard Worker        coords = torch.full((1, 2, 2, 2, 3), 1., device=device, dtype=dtype)
10387*da0073e9SAndroid Build Coastguard Worker        im = torch.zeros([1, 1, 2, 32769, 32768], device=device, dtype=dtype)
10388*da0073e9SAndroid Build Coastguard Worker
10389*da0073e9SAndroid Build Coastguard Worker        result = F.grid_sample(im, coords, align_corners=False)
10390*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, torch.zeros((1, 1, 2, 2, 2), device=device, dtype=dtype))
10391*da0073e9SAndroid Build Coastguard Worker
10392*da0073e9SAndroid Build Coastguard Worker        # Compare sampling with large strides to the same op on a contiguous tensor
10393*da0073e9SAndroid Build Coastguard Worker        coords = torch.rand(1, 1, 4, 4, 3, device=device, dtype=dtype)
10394*da0073e9SAndroid Build Coastguard Worker        large_view = im[..., 127::128]
10395*da0073e9SAndroid Build Coastguard Worker        small_image = torch.rand_like(large_view)
10396*da0073e9SAndroid Build Coastguard Worker        large_view[...] = small_image
10397*da0073e9SAndroid Build Coastguard Worker        small_image.requires_grad, large_view.requires_grad = True, True
10398*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
10399*da0073e9SAndroid Build Coastguard Worker            sum(i * s for i, s in zip(large_view.size(), large_view.stride())) >= 2 ** 31,
10400*da0073e9SAndroid Build Coastguard Worker            msg="View must use 64-bit indexing")
10401*da0073e9SAndroid Build Coastguard Worker        for mode, padding_mode, align_corners in itertools.product(
10402*da0073e9SAndroid Build Coastguard Worker                ('nearest', 'bilinear'), ('zeros', 'border', 'reflection'), (True, False)):
10403*da0073e9SAndroid Build Coastguard Worker            a = F.grid_sample(
10404*da0073e9SAndroid Build Coastguard Worker                small_image, coords, mode=mode,
10405*da0073e9SAndroid Build Coastguard Worker                padding_mode=padding_mode, align_corners=align_corners)
10406*da0073e9SAndroid Build Coastguard Worker            a.sum().backward()
10407*da0073e9SAndroid Build Coastguard Worker
10408*da0073e9SAndroid Build Coastguard Worker            b = F.grid_sample(
10409*da0073e9SAndroid Build Coastguard Worker                large_view, coords, mode=mode,
10410*da0073e9SAndroid Build Coastguard Worker                padding_mode=padding_mode, align_corners=align_corners)
10411*da0073e9SAndroid Build Coastguard Worker            b.sum().backward()
10412*da0073e9SAndroid Build Coastguard Worker
10413*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a, b)
10414*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(small_image.grad, large_view.grad)
10415*da0073e9SAndroid Build Coastguard Worker
10416*da0073e9SAndroid Build Coastguard Worker            small_image.grad.zero_()
10417*da0073e9SAndroid Build Coastguard Worker            large_view.grad.zero_()
10418*da0073e9SAndroid Build Coastguard Worker
10419*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
10420*da0073e9SAndroid Build Coastguard Worker    def test_grid_sample_half_precision(self):
10421*da0073e9SAndroid Build Coastguard Worker        def helper(shape_in, shape_out, align_corners):
10422*da0073e9SAndroid Build Coastguard Worker            for mode in ('bilinear', 'nearest', 'bicubic'):
10423*da0073e9SAndroid Build Coastguard Worker                if len(shape_in) != 4 and mode == 'bicubic':
10424*da0073e9SAndroid Build Coastguard Worker                    continue
10425*da0073e9SAndroid Build Coastguard Worker                data = torch.randn(shape_in, device='cuda', dtype=torch.half)
10426*da0073e9SAndroid Build Coastguard Worker                grid = torch.rand(shape_out, device='cuda', dtype=torch.half) * 2.0 - 1.0
10427*da0073e9SAndroid Build Coastguard Worker
10428*da0073e9SAndroid Build Coastguard Worker                out_half = F.grid_sample(data, grid, mode=mode, padding_mode='zeros', align_corners=align_corners)
10429*da0073e9SAndroid Build Coastguard Worker                out_double = F.grid_sample(data.double(), grid.double(), mode=mode, padding_mode='zeros',
10430*da0073e9SAndroid Build Coastguard Worker                                           align_corners=align_corners)
10431*da0073e9SAndroid Build Coastguard Worker
10432*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out_half, out_double.half(), msg=f"grid_sample with mode = {mode} doesn't match")
10433*da0073e9SAndroid Build Coastguard Worker
10434*da0073e9SAndroid Build Coastguard Worker        helper((32, 64, 16, 16), (32, 8, 8, 2), True)
10435*da0073e9SAndroid Build Coastguard Worker        helper((32, 64, 16, 16, 16), (32, 8, 8, 8, 3), True)
10436*da0073e9SAndroid Build Coastguard Worker        helper((32, 64, 16, 16), (32, 8, 8, 2), False)
10437*da0073e9SAndroid Build Coastguard Worker        helper((32, 64, 16, 16, 16), (32, 8, 8, 8, 3), False)
10438*da0073e9SAndroid Build Coastguard Worker
10439*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
10440*da0073e9SAndroid Build Coastguard Worker    def test_grid_sample_bfloat16_precision(self):
10441*da0073e9SAndroid Build Coastguard Worker        def helper(shape_in, shape_out, align_corners):
10442*da0073e9SAndroid Build Coastguard Worker            for mode in ('bilinear', 'nearest', 'bicubic'):
10443*da0073e9SAndroid Build Coastguard Worker                if len(shape_in) != 4 and mode == 'bicubic':
10444*da0073e9SAndroid Build Coastguard Worker                    continue
10445*da0073e9SAndroid Build Coastguard Worker                data = torch.randn(shape_in, device='cuda', dtype=torch.bfloat16)
10446*da0073e9SAndroid Build Coastguard Worker                grid = torch.rand(shape_out, device='cuda', dtype=torch.bfloat16) * 2.0 - 1.0
10447*da0073e9SAndroid Build Coastguard Worker
10448*da0073e9SAndroid Build Coastguard Worker                out_bf16 = F.grid_sample(data, grid, mode=mode, padding_mode='zeros', align_corners=align_corners)
10449*da0073e9SAndroid Build Coastguard Worker                out_double = F.grid_sample(data.double(), grid.double(), mode=mode, padding_mode='zeros',
10450*da0073e9SAndroid Build Coastguard Worker                                           align_corners=align_corners)
10451*da0073e9SAndroid Build Coastguard Worker
10452*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out_bf16, out_double.bfloat16(), msg=f"grid_sample with mode = {mode} doesn't match")
10453*da0073e9SAndroid Build Coastguard Worker
10454*da0073e9SAndroid Build Coastguard Worker        helper((32, 64, 16, 16), (32, 8, 8, 2), True)
10455*da0073e9SAndroid Build Coastguard Worker        helper((32, 64, 16, 16, 16), (32, 8, 8, 8, 3), True)
10456*da0073e9SAndroid Build Coastguard Worker        helper((32, 64, 16, 16), (32, 8, 8, 2), False)
10457*da0073e9SAndroid Build Coastguard Worker        helper((32, 64, 16, 16, 16), (32, 8, 8, 8, 3), False)
10458*da0073e9SAndroid Build Coastguard Worker
10459*da0073e9SAndroid Build Coastguard Worker    def _test_gumbel_softmax_st_shapes(self, device, dtype, shape, dim, count_expected):
10460*da0073e9SAndroid Build Coastguard Worker        logits = torch.randn(shape, dtype=torch.float, device=device)
10461*da0073e9SAndroid Build Coastguard Worker        logits = logits.to(dtype)
10462*da0073e9SAndroid Build Coastguard Worker
10463*da0073e9SAndroid Build Coastguard Worker        y_draw = F.gumbel_softmax(logits, hard=True, dim=dim)
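        # With hard=True each softmax slice along `dim` is a one-hot vector, so the
        # total sum equals the number of slices, i.e. count_expected.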
10464*da0073e9SAndroid Build Coastguard Worker
10465*da0073e9SAndroid Build Coastguard Worker        # All values positive
10466*da0073e9SAndroid Build Coastguard Worker        self.assertGreaterEqual(y_draw.min(), 0)
10467*da0073e9SAndroid Build Coastguard Worker        # Shape unchanged
10468*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(y_draw.shape == logits.shape)
10469*da0073e9SAndroid Build Coastguard Worker        # One choice per draw
10470*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y_draw.sum(), count_expected, atol=torch.finfo(y_draw.dtype).eps, rtol=0)
10471*da0073e9SAndroid Build Coastguard Worker
10472*da0073e9SAndroid Build Coastguard Worker    def _test_gumbel_softmax_straight_through(self, device, dtype):
10473*da0073e9SAndroid Build Coastguard Worker        num_draws = 100
10474*da0073e9SAndroid Build Coastguard Worker
10475*da0073e9SAndroid Build Coastguard Worker        logits = torch.tensor([[0.2, 0.8, 0.1]], device=device)
10476*da0073e9SAndroid Build Coastguard Worker        logits = logits.reshape([1, 3])
10477*da0073e9SAndroid Build Coastguard Worker        logits = logits.to(dtype).requires_grad_()
10478*da0073e9SAndroid Build Coastguard Worker        probs = logits.softmax(dim=-1)
10479*da0073e9SAndroid Build Coastguard Worker
10480*da0073e9SAndroid Build Coastguard Worker        counts = torch.zeros_like(logits)
10481*da0073e9SAndroid Build Coastguard Worker        for _ in range(num_draws):
10482*da0073e9SAndroid Build Coastguard Worker            y_draw = F.gumbel_softmax(logits, hard=True)
10483*da0073e9SAndroid Build Coastguard Worker            counts = counts + y_draw
10484*da0073e9SAndroid Build Coastguard Worker
10485*da0073e9SAndroid Build Coastguard Worker        # All values positive
10486*da0073e9SAndroid Build Coastguard Worker        self.assertGreaterEqual(y_draw.min(), 0)
10487*da0073e9SAndroid Build Coastguard Worker        # Each experiment should result in 1 draw.
10488*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counts.sum(), num_draws, atol=torch.finfo(counts.dtype).eps, rtol=0)
10489*da0073e9SAndroid Build Coastguard Worker
10490*da0073e9SAndroid Build Coastguard Worker        # Check that the observed counts are asymptotically as expected.
10491*da0073e9SAndroid Build Coastguard Worker        expected = probs * num_draws
10492*da0073e9SAndroid Build Coastguard Worker        # z is approximately N(0, 1) if the counts are unbiased
10493*da0073e9SAndroid Build Coastguard Worker        z = (counts - expected) / (expected * (1 - probs)).sqrt()
10494*da0073e9SAndroid Build Coastguard Worker        # A (lazy) approximate two-sided test at the ~99% level:
10495*da0073e9SAndroid Build Coastguard Worker        # |z| > 2.58 occurs with probability alpha ~ 0.01 if the draws are unbiased
10496*da0073e9SAndroid Build Coastguard Worker        self.assertLess(z.abs().max().item(), 2.58)
10497*da0073e9SAndroid Build Coastguard Worker
10498*da0073e9SAndroid Build Coastguard Worker    def _test_gumbel_softmax_grad(self, device, dtype):
10499*da0073e9SAndroid Build Coastguard Worker        # "hard" and "not hard" should propagate same gradient.
10500*da0073e9SAndroid Build Coastguard Worker        logits_soft = torch.zeros(10, 10, dtype=dtype, device=device, requires_grad=True)
10501*da0073e9SAndroid Build Coastguard Worker        logits_hard = torch.zeros(10, 10, dtype=dtype, device=device, requires_grad=True)
10502*da0073e9SAndroid Build Coastguard Worker
10503*da0073e9SAndroid Build Coastguard Worker        seed = torch.random.get_rng_state()
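        # Restore the RNG state before the second draw so both calls see identical
        # Gumbel noise; only the hard/soft path should differ.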
10504*da0073e9SAndroid Build Coastguard Worker        y_soft = F.gumbel_softmax(logits_soft, hard=False)
10505*da0073e9SAndroid Build Coastguard Worker        torch.random.set_rng_state(seed)
10506*da0073e9SAndroid Build Coastguard Worker        y_hard = F.gumbel_softmax(logits_hard, hard=True)
10507*da0073e9SAndroid Build Coastguard Worker
10508*da0073e9SAndroid Build Coastguard Worker        y_soft.sum().backward()
10509*da0073e9SAndroid Build Coastguard Worker        y_hard.sum().backward()
10510*da0073e9SAndroid Build Coastguard Worker
10511*da0073e9SAndroid Build Coastguard Worker        # 2eps = 1x addition + 1x subtraction.
10512*da0073e9SAndroid Build Coastguard Worker        tol = 2 * torch.finfo(dtype).eps
10513*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(logits_soft.grad, logits_hard.grad, atol=tol, rtol=0)
10514*da0073e9SAndroid Build Coastguard Worker
10515*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.half, torch.float, torch.double)
10516*da0073e9SAndroid Build Coastguard Worker    @dtypesIfMPS(torch.float)
10517*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
10518*da0073e9SAndroid Build Coastguard Worker    def test_gumbel_softmax(self, device, dtype):
10519*da0073e9SAndroid Build Coastguard Worker        self._test_gumbel_softmax_st_shapes(device, dtype, shape=[5], dim=0, count_expected=1)
10520*da0073e9SAndroid Build Coastguard Worker        self._test_gumbel_softmax_st_shapes(device, dtype, shape=[5], dim=-1, count_expected=1)
10521*da0073e9SAndroid Build Coastguard Worker        self._test_gumbel_softmax_st_shapes(device, dtype, shape=[5, 4], dim=1, count_expected=5)
10522*da0073e9SAndroid Build Coastguard Worker        self._test_gumbel_softmax_st_shapes(device, dtype, shape=[5, 4, 3], dim=1, count_expected=5 * 3)
10523*da0073e9SAndroid Build Coastguard Worker        self._test_gumbel_softmax_st_shapes(device, dtype, shape=[5, 4, 3], dim=-1, count_expected=5 * 4)
10524*da0073e9SAndroid Build Coastguard Worker        self._test_gumbel_softmax_straight_through(device, dtype)
10525*da0073e9SAndroid Build Coastguard Worker        self._test_gumbel_softmax_grad(device, dtype)
10526*da0073e9SAndroid Build Coastguard Worker
10527*da0073e9SAndroid Build Coastguard Worker    def _test_rnn_retain_variables(self, device, dtype):
10528*da0073e9SAndroid Build Coastguard Worker        rnns = [nn.LSTM(10, 20, num_layers=2).to(device, dtype),
10529*da0073e9SAndroid Build Coastguard Worker                nn.GRU(10, 20, num_layers=2).to(device, dtype),
10530*da0073e9SAndroid Build Coastguard Worker                nn.RNN(10, 20, num_layers=2).to(device, dtype)]
10531*da0073e9SAndroid Build Coastguard Worker        for rnn in rnns:
10532*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(5, 6, 10, device=device, dtype=dtype, requires_grad=True)
10533*da0073e9SAndroid Build Coastguard Worker            output = rnn(input)
10534*da0073e9SAndroid Build Coastguard Worker            output[0].sum().backward(retain_graph=True)
10535*da0073e9SAndroid Build Coastguard Worker            grads = [input.grad.data.clone()] + [p.grad.data.clone() for p in rnn.parameters()]
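            # Re-run backward through the retained graph several times; the gradients
            # must be reproducible across runs.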
10536*da0073e9SAndroid Build Coastguard Worker            for _ in range(4):
10537*da0073e9SAndroid Build Coastguard Worker                rnn.zero_grad()
10538*da0073e9SAndroid Build Coastguard Worker                input.grad.data.zero_()
10539*da0073e9SAndroid Build Coastguard Worker                output[0].sum().backward(retain_graph=True)
10540*da0073e9SAndroid Build Coastguard Worker                grads2 = [input.grad.data] + [p.grad.data for p in rnn.parameters()]
10541*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(grads, grads2)
10542*da0073e9SAndroid Build Coastguard Worker
10543*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.half, torch.float, torch.double)
10544*da0073e9SAndroid Build Coastguard Worker    @dtypesIfMPS(torch.half, torch.float)
10545*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
10546*da0073e9SAndroid Build Coastguard Worker    def test_rnn_retain_variables(self, device, dtype):
10547*da0073e9SAndroid Build Coastguard Worker        self._test_rnn_retain_variables(device, dtype)
10548*da0073e9SAndroid Build Coastguard Worker
10549*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda' and self.has_cudnn():
10550*da0073e9SAndroid Build Coastguard Worker            with torch.backends.cudnn.flags(enabled=False):
10551*da0073e9SAndroid Build Coastguard Worker                self._test_rnn_retain_variables(device, dtype)
10552*da0073e9SAndroid Build Coastguard Worker
10553*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
10554*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
10555*da0073e9SAndroid Build Coastguard Worker    def test_lstmcell_backward_only_one_output_grad(self, device, dtype):
10556*da0073e9SAndroid Build Coastguard Worker        # Check that undefined gradients don't hamper the backward pass
10557*da0073e9SAndroid Build Coastguard Worker        # see #11872
10558*da0073e9SAndroid Build Coastguard Worker        l = torch.nn.LSTMCell(2, 3).to(device).to(dtype=dtype)
10559*da0073e9SAndroid Build Coastguard Worker        s = torch.randn(1, 2, device=device, dtype=dtype, requires_grad=True)
10560*da0073e9SAndroid Build Coastguard Worker        for i in range(2):
10561*da0073e9SAndroid Build Coastguard Worker            out = l(s)[i]
10562*da0073e9SAndroid Build Coastguard Worker            out.sum().backward()
10563*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(s.grad is None or s.grad.abs().sum().item() == 0)
10564*da0073e9SAndroid Build Coastguard Worker
10565*da0073e9SAndroid Build Coastguard Worker    def _test_rnn_mod(self, mod, inp):
10566*da0073e9SAndroid Build Coastguard Worker        def flatten_out(mod, inp):
10567*da0073e9SAndroid Build Coastguard Worker            out = mod(inp)
10568*da0073e9SAndroid Build Coastguard Worker            return tuple([t if isinstance(t, torch.Tensor) else tt for t in out for tt in t])
10569*da0073e9SAndroid Build Coastguard Worker        gradcheckfunc = partial(flatten_out, mod)
10570*da0073e9SAndroid Build Coastguard Worker        with torch.backends.cudnn.flags(enabled=False):
10571*da0073e9SAndroid Build Coastguard Worker            gradcheck(gradcheckfunc, inp, check_batched_grad=False)
10572*da0073e9SAndroid Build Coastguard Worker            gradgradcheck(gradcheckfunc, inp, check_batched_grad=False)
10573*da0073e9SAndroid Build Coastguard Worker
10574*da0073e9SAndroid Build Coastguard Worker        if inp.is_cuda and not TEST_WITH_ROCM:
10575*da0073e9SAndroid Build Coastguard Worker            # Assert that we get a good error message around the unsupported CuDNN double backward
10576*da0073e9SAndroid Build Coastguard Worker            # NB: we trigger double backward using .backward() instead of autograd.grad due to
10577*da0073e9SAndroid Build Coastguard Worker            # https://github.com/pytorch/pytorch/issues/37874
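            # Double backward through the cuDNN RNN kernels is not implemented, so the
            # second .backward() below is expected to raise. A user-side workaround
            # (sketch) is to run that region with cuDNN disabled, as the gradcheck
            # calls above already do:
            #     with torch.backends.cudnn.flags(enabled=False):
            #         grad0.sum().backward()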
10578*da0073e9SAndroid Build Coastguard Worker            with torch.backends.cudnn.flags(enabled=True):
10579*da0073e9SAndroid Build Coastguard Worker                result = gradcheckfunc(inp)
10580*da0073e9SAndroid Build Coastguard Worker                result[0].sum().backward(create_graph=True)
10581*da0073e9SAndroid Build Coastguard Worker                grad0 = next(mod.parameters()).grad
10582*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError,
10583*da0073e9SAndroid Build Coastguard Worker                                            "please disable the CuDNN backend temporarily"):
10584*da0073e9SAndroid Build Coastguard Worker                    grad0.sum().backward()
10585*da0073e9SAndroid Build Coastguard Worker
10586*da0073e9SAndroid Build Coastguard Worker                # Here we avoid the backward(create_graph=True) memory leak
10587*da0073e9SAndroid Build Coastguard Worker                # described in https://github.com/pytorch/pytorch/issues/7343
10588*da0073e9SAndroid Build Coastguard Worker                for param in mod.parameters():
10589*da0073e9SAndroid Build Coastguard Worker                    param.grad = None
10590*da0073e9SAndroid Build Coastguard Worker                inp.grad = None
10591*da0073e9SAndroid Build Coastguard Worker
10592*da0073e9SAndroid Build Coastguard Worker    # Merge into OpInfo?
10593*da0073e9SAndroid Build Coastguard Worker    @skipMeta  # LSTM cell reuses output which was resized
10594*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
10595*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
10596*da0073e9SAndroid Build Coastguard Worker    def test_LSTM_grad_and_gradgrad(self, device, dtype):
10597*da0073e9SAndroid Build Coastguard Worker        hsize = 4
10598*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(1, 3, hsize, device=device, dtype=dtype, requires_grad=True)
10599*da0073e9SAndroid Build Coastguard Worker        for bias in [True, False]:
10600*da0073e9SAndroid Build Coastguard Worker            mod = torch.nn.LSTM(hsize, hsize, bias=bias).to(device).to(dtype)
10601*da0073e9SAndroid Build Coastguard Worker            self._test_rnn_mod(mod, inp)
10602*da0073e9SAndroid Build Coastguard Worker
10603*da0073e9SAndroid Build Coastguard Worker    @skipMeta  # GRU cell reuses output which was resized
10604*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
10605*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
10606*da0073e9SAndroid Build Coastguard Worker    def test_GRU_grad_and_gradgrad(self, device, dtype):
10607*da0073e9SAndroid Build Coastguard Worker        hsize = 4
10608*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(1, 3, hsize, device=device, dtype=dtype, requires_grad=True)
10609*da0073e9SAndroid Build Coastguard Worker        for bias in [True, False]:
10610*da0073e9SAndroid Build Coastguard Worker            mod = torch.nn.GRU(hsize, hsize, bias=bias).to(device).to(dtype)
10611*da0073e9SAndroid Build Coastguard Worker            self._test_rnn_mod(mod, inp)
10612*da0073e9SAndroid Build Coastguard Worker
10613*da0073e9SAndroid Build Coastguard Worker    @skipMeta
10614*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32, torch.bfloat16)
10615*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
10616*da0073e9SAndroid Build Coastguard Worker    def test_LSTM_differentiable_backward_using_oneDNN(self, dtype):
10617*da0073e9SAndroid Build Coastguard Worker        batch = 10
10618*da0073e9SAndroid Build Coastguard Worker        seq_len = 12
10619*da0073e9SAndroid Build Coastguard Worker        input = 3
10620*da0073e9SAndroid Build Coastguard Worker        Net = nn.LSTM(input, 3, 20, batch_first=True)
10621*da0073e9SAndroid Build Coastguard Worker        import copy
10622*da0073e9SAndroid Build Coastguard Worker        Net_clone = copy.deepcopy(Net)
10623*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(batch, seq_len, input)
10624*da0073e9SAndroid Build Coastguard Worker        x1 = x.clone().requires_grad_(True)
10625*da0073e9SAndroid Build Coastguard Worker        x2 = x.clone().requires_grad_(True)
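        # The same input is run once with the oneDNN (mkldnn) path disabled and once
        # with it enabled; first- and second-order gradients should agree.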
10626*da0073e9SAndroid Build Coastguard Worker
10627*da0073e9SAndroid Build Coastguard Worker        torch._C._set_mkldnn_enabled(False)
10628*da0073e9SAndroid Build Coastguard Worker        out1, _ = Net(x1)
10629*da0073e9SAndroid Build Coastguard Worker        der_out1 = torch.autograd.grad(out1, x1,
10630*da0073e9SAndroid Build Coastguard Worker                                       grad_outputs=torch.ones_like(out1),
10631*da0073e9SAndroid Build Coastguard Worker                                       retain_graph=True,
10632*da0073e9SAndroid Build Coastguard Worker                                       create_graph=True)[0]
10633*da0073e9SAndroid Build Coastguard Worker        loss1 = der_out1.sum()
10634*da0073e9SAndroid Build Coastguard Worker        loss1.backward(retain_graph=True)
10635*da0073e9SAndroid Build Coastguard Worker
10636*da0073e9SAndroid Build Coastguard Worker        torch._C._set_mkldnn_enabled(True)
10637*da0073e9SAndroid Build Coastguard Worker        out2, _ = Net(x2)
10638*da0073e9SAndroid Build Coastguard Worker        der_out2 = torch.autograd.grad(out2, x2,
10639*da0073e9SAndroid Build Coastguard Worker                                       grad_outputs=torch.ones_like(out2),
10640*da0073e9SAndroid Build Coastguard Worker                                       retain_graph=True,
10641*da0073e9SAndroid Build Coastguard Worker                                       create_graph=True)[0]
10642*da0073e9SAndroid Build Coastguard Worker        loss2 = der_out2.sum()
10643*da0073e9SAndroid Build Coastguard Worker        loss2.backward(retain_graph=True)
10644*da0073e9SAndroid Build Coastguard Worker        assert torch.allclose(der_out1, der_out2)
10645*da0073e9SAndroid Build Coastguard Worker        assert torch.allclose(x1.grad, x2.grad)
10646*da0073e9SAndroid Build Coastguard Worker
10647*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
10648*da0073e9SAndroid Build Coastguard Worker    def test_upsamplingNearest1d_launch_config(self, device):
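        # A 2**25-element batch forces a large CUDA launch configuration; the result
        # must still match the CPU reference.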
10649*da0073e9SAndroid Build Coastguard Worker        m = nn.Upsample(scale_factor=2)
10650*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(2**25, 1, 1, device=device)
10651*da0073e9SAndroid Build Coastguard Worker        out = m(inp)
10652*da0073e9SAndroid Build Coastguard Worker        inp_ref = inp.cpu()
10653*da0073e9SAndroid Build Coastguard Worker        out_ref = m(inp_ref)
10654*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref, out)
10655*da0073e9SAndroid Build Coastguard Worker
10656*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
10657*da0073e9SAndroid Build Coastguard Worker    def test_upsamplingNearest2d_launch_config(self, device):
10658*da0073e9SAndroid Build Coastguard Worker        m = nn.Upsample(scale_factor=2)
10659*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(2**25, 1, 1, 1, device=device)
10660*da0073e9SAndroid Build Coastguard Worker        out = m(inp)
10661*da0073e9SAndroid Build Coastguard Worker        inp_ref = inp.cpu()
10662*da0073e9SAndroid Build Coastguard Worker        out_ref = m(inp_ref)
10663*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref, out)
10664*da0073e9SAndroid Build Coastguard Worker
10665*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
10666*da0073e9SAndroid Build Coastguard Worker    @gcIfJetson
10667*da0073e9SAndroid Build Coastguard Worker    def test_upsamplingNearest3d_launch_config(self, device):
10668*da0073e9SAndroid Build Coastguard Worker        m = nn.Upsample(scale_factor=2)
10669*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(2**25, 1, 1, 1, 1, device=device)
10670*da0073e9SAndroid Build Coastguard Worker        out = m(inp)
10671*da0073e9SAndroid Build Coastguard Worker        inp_ref = inp.cpu()
10672*da0073e9SAndroid Build Coastguard Worker        out_ref = m(inp_ref)
10673*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref, out)
10674*da0073e9SAndroid Build Coastguard Worker
10675*da0073e9SAndroid Build Coastguard Worker    @unittest.expectedFailure
10676*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm
10677*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
10678*da0073e9SAndroid Build Coastguard Worker    def test_upsamplingNearest2d_launch_fail(self, device):
10679*da0073e9SAndroid Build Coastguard Worker        m = nn.Upsample(scale_factor=2)
10680*da0073e9SAndroid Build Coastguard Worker        # launch grid_y == 2**16 (larger than maximum y-dimension limit 65535)
10681*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(1, 1, 2**15, 2**8, device=device)
10682*da0073e9SAndroid Build Coastguard Worker        out = m(inp)
10683*da0073e9SAndroid Build Coastguard Worker
10684*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
10685*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfNotRocm
10686*da0073e9SAndroid Build Coastguard Worker    def test_upsamplingNearest2d_launch_rocm(self, device):
10687*da0073e9SAndroid Build Coastguard Worker        # test_upsamplingNearest2d_launch_fail should run OK on ROCm
10688*da0073e9SAndroid Build Coastguard Worker        m = nn.Upsample(scale_factor=2)
10689*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(1, 1, 2**15, 2**8, device=device)
10690*da0073e9SAndroid Build Coastguard Worker        out = m(inp)
10691*da0073e9SAndroid Build Coastguard Worker
10692*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
10693*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfCudnnVersionLessThan(7600)
10694*da0073e9SAndroid Build Coastguard Worker    def test_CTCLoss_cudnn(self, device):
10695*da0073e9SAndroid Build Coastguard Worker        def _helper(zero_infinity):
10696*da0073e9SAndroid Build Coastguard Worker            target_lengths = [30, 25, 20]
10697*da0073e9SAndroid Build Coastguard Worker            input_lengths = [50, 50, 50]
10698*da0073e9SAndroid Build Coastguard Worker            targets = torch.randint(1, 15, (sum(target_lengths),), dtype=torch.int)
10699*da0073e9SAndroid Build Coastguard Worker            log_probs = torch.randn(50, 3, 15, dtype=torch.float, device=device).log_softmax(2).requires_grad_()
10700*da0073e9SAndroid Build Coastguard Worker
10701*da0073e9SAndroid Build Coastguard Worker            log_probs_ref = log_probs.detach().clone().requires_grad_()
10702*da0073e9SAndroid Build Coastguard Worker
10703*da0073e9SAndroid Build Coastguard Worker            with torch.backends.cudnn.flags(enabled=True):
10704*da0073e9SAndroid Build Coastguard Worker                res = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, zero_infinity=zero_infinity)
10705*da0073e9SAndroid Build Coastguard Worker                res.backward()
10706*da0073e9SAndroid Build Coastguard Worker
10707*da0073e9SAndroid Build Coastguard Worker            expected = ctcloss_reference(log_probs, targets.cuda(), input_lengths, target_lengths).float()
10708*da0073e9SAndroid Build Coastguard Worker
10709*da0073e9SAndroid Build Coastguard Worker            with torch.backends.cudnn.flags(enabled=False):
10710*da0073e9SAndroid Build Coastguard Worker                res2 = torch.nn.functional.ctc_loss(log_probs_ref, targets.cuda().long(), input_lengths, target_lengths,
10711*da0073e9SAndroid Build Coastguard Worker                                                    zero_infinity=zero_infinity)
10712*da0073e9SAndroid Build Coastguard Worker                res2.backward()
10713*da0073e9SAndroid Build Coastguard Worker
10714*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, expected)
10715*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res2, res)
10716*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(log_probs.grad, log_probs_ref.grad)
10717*da0073e9SAndroid Build Coastguard Worker
10718*da0073e9SAndroid Build Coastguard Worker        _helper(zero_infinity=True)
10719*da0073e9SAndroid Build Coastguard Worker        _helper(zero_infinity=False)
10720*da0073e9SAndroid Build Coastguard Worker
10721*da0073e9SAndroid Build Coastguard Worker    def _CTCLoss_gen_losses(self, device, input_length, vocab_size, target_length, reduction, use_module_form):
10722*da0073e9SAndroid Build Coastguard Worker        batch_size = 1
10723*da0073e9SAndroid Build Coastguard Worker        log_probs = torch.randn(input_length, batch_size, vocab_size, dtype=torch.float, device=device) \
10724*da0073e9SAndroid Build Coastguard Worker                         .log_softmax(2).requires_grad_()
10725*da0073e9SAndroid Build Coastguard Worker        targets = torch.randint(low=1, high=vocab_size - 1, size=(batch_size, target_length),
10726*da0073e9SAndroid Build Coastguard Worker                                dtype=torch.int, device=device)
10727*da0073e9SAndroid Build Coastguard Worker        input_lengths = batch_size * [input_length]
10728*da0073e9SAndroid Build Coastguard Worker        target_lengths = batch_size * [target_length]
10729*da0073e9SAndroid Build Coastguard Worker
10730*da0073e9SAndroid Build Coastguard Worker        log_probs_no_bd = log_probs.squeeze(1).detach().clone().requires_grad_()
10731*da0073e9SAndroid Build Coastguard Worker        targets_no_bd = targets.squeeze(0).detach().clone()
10732*da0073e9SAndroid Build Coastguard Worker        input_lengths_no_bd = torch.tensor(input_length)
10733*da0073e9SAndroid Build Coastguard Worker        target_lengths_no_bd = torch.tensor(target_length)
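        # The *_no_bd variants drop the batch dimension (batch_size == 1) so the
        # unbatched CTC-loss path can be compared against the batched one.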
10734*da0073e9SAndroid Build Coastguard Worker
10735*da0073e9SAndroid Build Coastguard Worker        # currently only lengths 2 and 1, but left flexible for additional cases
10736*da0073e9SAndroid Build Coastguard Worker        log_probs_refs = [log_probs.detach().clone().requires_grad_() for _ in range(2)]
10737*da0073e9SAndroid Build Coastguard Worker        log_probs_no_bd_refs = [log_probs_no_bd.detach().clone().requires_grad_() for _ in range(1)]
10738*da0073e9SAndroid Build Coastguard Worker
10739*da0073e9SAndroid Build Coastguard Worker        losses = []
10740*da0073e9SAndroid Build Coastguard Worker        losses_no_bd = []
10741*da0073e9SAndroid Build Coastguard Worker
10742*da0073e9SAndroid Build Coastguard Worker        has_cuda = torch.cuda.is_available()
10743*da0073e9SAndroid Build Coastguard Worker        has_cudnn = has_cuda and 'cuda' in device and self.has_cudnn()
10744*da0073e9SAndroid Build Coastguard Worker        # cuDNN requires a CPU target
10745*da0073e9SAndroid Build Coastguard Worker        if has_cuda and has_cudnn:
10746*da0073e9SAndroid Build Coastguard Worker            targets = targets.cpu()
10747*da0073e9SAndroid Build Coastguard Worker            targets_no_bd = targets_no_bd.cpu()
10748*da0073e9SAndroid Build Coastguard Worker
10749*da0073e9SAndroid Build Coastguard Worker        ctc_loss = (
10750*da0073e9SAndroid Build Coastguard Worker            nn.CTCLoss(reduction=reduction, zero_infinity=True)
10751*da0073e9SAndroid Build Coastguard Worker            if use_module_form
10752*da0073e9SAndroid Build Coastguard Worker            else partial(torch.nn.functional.ctc_loss, reduction=reduction, zero_infinity=True)
10753*da0073e9SAndroid Build Coastguard Worker        )
10754*da0073e9SAndroid Build Coastguard Worker
10755*da0073e9SAndroid Build Coastguard Worker        with torch.backends.cudnn.flags(enabled=has_cudnn):
10756*da0073e9SAndroid Build Coastguard Worker            # batched case. log_probs.shape = (T, N, C), targets = (N, S), input_lengths/target_lengths = (N,)
10757*da0073e9SAndroid Build Coastguard Worker            losses.append(ctc_loss(log_probs_refs[0], targets, input_lengths, target_lengths))
10758*da0073e9SAndroid Build Coastguard Worker            # batched case. input.shape = (T, N, C), targets = (S,), input_lengths/target_lengths = (N,)
10759*da0073e9SAndroid Build Coastguard Worker            losses.append(ctc_loss(log_probs_refs[1], targets_no_bd, input_lengths, target_lengths))
10760*da0073e9SAndroid Build Coastguard Worker            # unbatched case. input.shape = (T, C), targets = (S,), input_lengths/target_lengths = (N,)
10761*da0073e9SAndroid Build Coastguard Worker            losses_no_bd.append(ctc_loss(log_probs_no_bd_refs[0], targets_no_bd,
10762*da0073e9SAndroid Build Coastguard Worker                                         input_lengths_no_bd, target_lengths_no_bd))
10763*da0073e9SAndroid Build Coastguard Worker
10764*da0073e9SAndroid Build Coastguard Worker            for loss in losses + losses_no_bd:
10765*da0073e9SAndroid Build Coastguard Worker                loss.backward()
10766*da0073e9SAndroid Build Coastguard Worker
10767*da0073e9SAndroid Build Coastguard Worker        return losses, losses_no_bd, log_probs_refs, log_probs_no_bd_refs
10768*da0073e9SAndroid Build Coastguard Worker
10769*da0073e9SAndroid Build Coastguard Worker    def _assertEqual_list(self, expected, list_to_compare, atol=None, rtol=None):
10770*da0073e9SAndroid Build Coastguard Worker        for ele in list_to_compare:
10771*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, ele, atol=atol, rtol=rtol)
10772*da0073e9SAndroid Build Coastguard Worker
10773*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # NotImplementedError: aten::_ctc_loss https://github.com/pytorch/pytorch/issues/77764
10774*da0073e9SAndroid Build Coastguard Worker    @parametrize_test("reduction", ['none', 'mean', 'sum'])
10775*da0073e9SAndroid Build Coastguard Worker    @parametrize_test("use_module_form", [True, False])
10776*da0073e9SAndroid Build Coastguard Worker    def test_CTCLoss_no_batch_dim(self, device, reduction, use_module_form):
10777*da0073e9SAndroid Build Coastguard Worker        input_length = 40
10778*da0073e9SAndroid Build Coastguard Worker        vocab_size = 3
10779*da0073e9SAndroid Build Coastguard Worker        target_length = 12
10780*da0073e9SAndroid Build Coastguard Worker
10781*da0073e9SAndroid Build Coastguard Worker        args = self._CTCLoss_gen_losses(device, input_length, vocab_size, target_length, reduction, use_module_form)
10782*da0073e9SAndroid Build Coastguard Worker        losses, losses_no_bd, log_probs_refs, log_probs_no_bd_refs = args
10783*da0073e9SAndroid Build Coastguard Worker
10784*da0073e9SAndroid Build Coastguard Worker        # test output values
10785*da0073e9SAndroid Build Coastguard Worker        self._assertEqual_list(losses[0], losses[1:], atol=1e-4, rtol=0)
10786*da0073e9SAndroid Build Coastguard Worker        self._assertEqual_list(losses[0].squeeze(0), losses_no_bd, atol=1e-4, rtol=0)
10787*da0073e9SAndroid Build Coastguard Worker
10788*da0073e9SAndroid Build Coastguard Worker        # test gradient values
10789*da0073e9SAndroid Build Coastguard Worker        self._assertEqual_list(log_probs_refs[0].grad, [t.grad for t in log_probs_refs[1:]], atol=1e-4, rtol=0)
10790*da0073e9SAndroid Build Coastguard Worker        self._assertEqual_list(
10791*da0073e9SAndroid Build Coastguard Worker            log_probs_refs[0].grad.squeeze(1),
10792*da0073e9SAndroid Build Coastguard Worker            [t.grad for t in log_probs_no_bd_refs],
10793*da0073e9SAndroid Build Coastguard Worker            atol=1e-4,
10794*da0073e9SAndroid Build Coastguard Worker            rtol=0,
10795*da0073e9SAndroid Build Coastguard Worker        )
10796*da0073e9SAndroid Build Coastguard Worker
10797*da0073e9SAndroid Build Coastguard Worker        # checking the output's shape
10798*da0073e9SAndroid Build Coastguard Worker        # batch dim case should be (N,). no batch dim case should be ()
10799*da0073e9SAndroid Build Coastguard Worker        self._assertEqual_list((1,) if reduction == 'none' else (), [loss.shape for loss in losses])
10800*da0073e9SAndroid Build Coastguard Worker        self._assertEqual_list((), [loss.shape for loss in losses_no_bd])
10801*da0073e9SAndroid Build Coastguard Worker
10802*da0073e9SAndroid Build Coastguard Worker        # checking the gradient's shape
10803*da0073e9SAndroid Build Coastguard Worker        # batch dim case should have shape (T, N, C). no batch dim case should have shape (T, C)
10804*da0073e9SAndroid Build Coastguard Worker        self._assertEqual_list((input_length, 1, vocab_size), [t.grad.shape for t in log_probs_refs])
10805*da0073e9SAndroid Build Coastguard Worker        self._assertEqual_list((input_length, vocab_size), [t.grad.shape for t in log_probs_no_bd_refs])
10806*da0073e9SAndroid Build Coastguard Worker
10807*da0073e9SAndroid Build Coastguard Worker    def _ordered_sequence(self, device, dtype):
10808*da0073e9SAndroid Build Coastguard Worker        """Create ordered list of random sequences"""
10809*da0073e9SAndroid Build Coastguard Worker        seqs = [torch.empty(random.randint(1, 6), device=device, dtype=dtype)
10810*da0073e9SAndroid Build Coastguard Worker                for _ in range(5)]
10811*da0073e9SAndroid Build Coastguard Worker        seqs = [s.random_(-128, 128) for s in seqs]
10812*da0073e9SAndroid Build Coastguard Worker        ordered = sorted(seqs, key=len, reverse=True)
10813*da0073e9SAndroid Build Coastguard Worker        return ordered
10814*da0073e9SAndroid Build Coastguard Worker
10815*da0073e9SAndroid Build Coastguard Worker    def _padded_sequence(self, device, dtype):
10816*da0073e9SAndroid Build Coastguard Worker        """Create Tensor of random padded sequences"""
10817*da0073e9SAndroid Build Coastguard Worker        ordered = self._ordered_sequence(device, dtype)
10818*da0073e9SAndroid Build Coastguard Worker        lengths = [len(i) for i in ordered]
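        # pad_sequence (batch_first=False) returns a (max_len, batch) tensor padded with zeros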
10819*da0073e9SAndroid Build Coastguard Worker        padded_tensor = rnn_utils.pad_sequence(ordered)
10820*da0073e9SAndroid Build Coastguard Worker        return padded_tensor, lengths
10821*da0073e9SAndroid Build Coastguard Worker
10822*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
10823*da0073e9SAndroid Build Coastguard Worker    def test_device_mask(self, device):
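        # A PackedSequence built on CPU should move to CUDA via .to(device) and
        # unpack on the GPU with its dtype preserved.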
10824*da0073e9SAndroid Build Coastguard Worker        for enforce_sorted in [True, False]:
10825*da0073e9SAndroid Build Coastguard Worker            padded, lengths = self._padded_sequence('cpu', torch.float)
10826*da0073e9SAndroid Build Coastguard Worker            packed = rnn_utils.pack_padded_sequence(
10827*da0073e9SAndroid Build Coastguard Worker                padded, lengths, enforce_sorted=enforce_sorted)
10828*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(packed.is_cuda)
10829*da0073e9SAndroid Build Coastguard Worker            packed = packed.to(device)
10830*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(packed.is_cuda)
10831*da0073e9SAndroid Build Coastguard Worker            unpacked, _ = rnn_utils.pad_packed_sequence(packed)
10832*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(unpacked.is_cuda)
10833*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(unpacked.dtype, torch.float)
10834*da0073e9SAndroid Build Coastguard Worker
10835*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
10836*da0073e9SAndroid Build Coastguard Worker    def test_overwrite_module_params_on_conversion_cpu_device(self, device):
10837*da0073e9SAndroid Build Coastguard Worker        # Test that under the current default settings
10838*da0073e9SAndroid Build Coastguard Worker        # (`torch.__future__.get_overwrite_module_params_on_conversion() == False`),
10839*da0073e9SAndroid Build Coastguard Worker        # a view to a module's parameters is not pointing to the same storage as
10840*da0073e9SAndroid Build Coastguard Worker        # its base variable after converting the module to a different device.
10841*da0073e9SAndroid Build Coastguard Worker        m = nn.Linear(20, 10)
10842*da0073e9SAndroid Build Coastguard Worker        mw = m.weight[:]
10843*da0073e9SAndroid Build Coastguard Worker        m.to(device)
10844*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
10845*da0073e9SAndroid Build Coastguard Worker            # Without using `torch.no_grad()`, this will leak CUDA memory.
10846*da0073e9SAndroid Build Coastguard Worker            # (Issue is filed at https://github.com/pytorch/pytorch/issues/21875)
10847*da0073e9SAndroid Build Coastguard Worker            mw[0][0] = 5
10848*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(mw[0][0].device.type == "cpu")
10849*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(mw._base[0][0].device.type == "cuda")
10850*da0073e9SAndroid Build Coastguard Worker
10851*da0073e9SAndroid Build Coastguard Worker        try:
10852*da0073e9SAndroid Build Coastguard Worker            torch.__future__.set_overwrite_module_params_on_conversion(True)
10853*da0073e9SAndroid Build Coastguard Worker
10854*da0073e9SAndroid Build Coastguard Worker            # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
10855*da0073e9SAndroid Build Coastguard Worker            # a view to a module's parameters is still pointing to the same storage as
10856*da0073e9SAndroid Build Coastguard Worker            # its base variable after converting the module to a different device.
10857*da0073e9SAndroid Build Coastguard Worker            m = nn.Linear(20, 10)
10858*da0073e9SAndroid Build Coastguard Worker            mw = m.weight[:]
10859*da0073e9SAndroid Build Coastguard Worker            m.to(device)
10860*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
10861*da0073e9SAndroid Build Coastguard Worker                mw[0][0] = 5
10862*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(mw[0][0] == mw._base[0][0])
10863*da0073e9SAndroid Build Coastguard Worker
10864*da0073e9SAndroid Build Coastguard Worker            # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
10865*da0073e9SAndroid Build Coastguard Worker            # `cpu_module.to("cuda")` doesn't preserve previous references to
10866*da0073e9SAndroid Build Coastguard Worker            # `cpu_module`'s parameters or gradients.
10867*da0073e9SAndroid Build Coastguard Worker            m = nn.Linear(20, 10)
10868*da0073e9SAndroid Build Coastguard Worker            m.weight.grad = torch.randn(10, 20)
10869*da0073e9SAndroid Build Coastguard Worker            weight_ref = m.weight
10870*da0073e9SAndroid Build Coastguard Worker            weight_grad_ref = m.weight.grad
10871*da0073e9SAndroid Build Coastguard Worker            m.to(device)
10872*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(weight_ref.device, m.weight.device)
10873*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(weight_grad_ref.device, m.weight.grad.device)
10874*da0073e9SAndroid Build Coastguard Worker        finally:
10875*da0073e9SAndroid Build Coastguard Worker            torch.__future__.set_overwrite_module_params_on_conversion(False)
10876*da0073e9SAndroid Build Coastguard Worker
10877*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
10878*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.half, torch.float)
10879*da0073e9SAndroid Build Coastguard Worker    def test_softmax(self, device, dtype):
10880*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(32, 100, device=device, dtype=dtype, requires_grad=True)
10881*da0073e9SAndroid Build Coastguard Worker        inputf = input.to(torch.float).detach().requires_grad_(True)
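        # F.softmax(..., dtype=torch.float) casts the input to float before the op,
        # so the result should match an fp32 reference bit for bit.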
10882*da0073e9SAndroid Build Coastguard Worker        out = F.softmax(input, dim=-1, dtype=torch.float)
10883*da0073e9SAndroid Build Coastguard Worker        outf = F.softmax(inputf, dim=-1)
10884*da0073e9SAndroid Build Coastguard Worker        # should be bitwise equal
10885*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, outf, atol=0, rtol=0)
10886*da0073e9SAndroid Build Coastguard Worker        gO = torch.empty_like(outf).uniform_()
10887*da0073e9SAndroid Build Coastguard Worker        out.backward(gO)
10888*da0073e9SAndroid Build Coastguard Worker        outf.backward(gO)
10889*da0073e9SAndroid Build Coastguard Worker        # should be bitwise equal
10890*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(input.grad, inputf.grad.to(dtype), atol=0, rtol=0)
10891*da0073e9SAndroid Build Coastguard Worker
10892*da0073e9SAndroid Build Coastguard Worker    def _test_batchnorm_grad(self, device, dtype=torch.double):
10893*da0073e9SAndroid Build Coastguard Worker        bs, n_feat, size_feat = 4, 5, 6
10894*da0073e9SAndroid Build Coastguard Worker        input = torch.arange(bs * n_feat * size_feat, device=device,
10895*da0073e9SAndroid Build Coastguard Worker                             requires_grad=True, dtype=dtype).view(bs, n_feat, size_feat)
10896*da0073e9SAndroid Build Coastguard Worker        weight = torch.arange(1, n_feat + 1, device=device, requires_grad=True, dtype=dtype)
10897*da0073e9SAndroid Build Coastguard Worker        bias = torch.arange(n_feat, device=device, requires_grad=True, dtype=dtype)
10898*da0073e9SAndroid Build Coastguard Worker        running_mean = 1 - torch.arange(n_feat, device=device, dtype=dtype)
10899*da0073e9SAndroid Build Coastguard Worker        running_var = 2 * torch.arange(n_feat, device=device, dtype=dtype)
10900*da0073e9SAndroid Build Coastguard Worker        for training in [False, True]:
10901*da0073e9SAndroid Build Coastguard Worker            _assertGradAndGradgradChecks(self, F.batch_norm, (input, running_mean, running_var, weight, bias,
10902*da0073e9SAndroid Build Coastguard Worker                                                              training, 0.1, 0.0001))
10903*da0073e9SAndroid Build Coastguard Worker
10904*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
10905*da0073e9SAndroid Build Coastguard Worker    def test_batchnorm_grad(self, device):
10906*da0073e9SAndroid Build Coastguard Worker        self._test_batchnorm_grad(device)
10907*da0073e9SAndroid Build Coastguard Worker
10908*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda' and self.has_cudnn():
10909*da0073e9SAndroid Build Coastguard Worker            with torch.backends.cudnn.flags(enabled=False):
10910*da0073e9SAndroid Build Coastguard Worker                self._test_batchnorm_grad(device)
10911*da0073e9SAndroid Build Coastguard Worker
10912*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
10913*da0073e9SAndroid Build Coastguard Worker    def test_layernorm_half_precision(self):
10914*da0073e9SAndroid Build Coastguard Worker        width = 128
10915*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(1, 5, width, device="cuda", dtype=torch.half) * 0.1
10916*da0073e9SAndroid Build Coastguard Worker        normalized_shape = (width,)
10917*da0073e9SAndroid Build Coastguard Worker        weight = torch.ones(width, device="cuda", dtype=torch.half)
10918*da0073e9SAndroid Build Coastguard Worker        bias = torch.zeros(width, device="cuda", dtype=torch.half)
10919*da0073e9SAndroid Build Coastguard Worker        eps = 1e-5
10920*da0073e9SAndroid Build Coastguard Worker
10921*da0073e9SAndroid Build Coastguard Worker        output_fp16 = torch.layer_norm(input, normalized_shape, weight, bias, eps)
10922*da0073e9SAndroid Build Coastguard Worker        output_fp32 = torch.layer_norm(input.float(), normalized_shape, weight.float(), bias.float(), eps).half()
10923*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output_fp16, output_fp32, atol=0, rtol=0)
10924*da0073e9SAndroid Build Coastguard Worker
10925*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
10926*da0073e9SAndroid Build Coastguard Worker    def test_layernorm_weight_bias(self):
10927*da0073e9SAndroid Build Coastguard Worker        width = 128
10928*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(1, 5, width, device="cuda", dtype=torch.float32) * 0.1
10929*da0073e9SAndroid Build Coastguard Worker        normalized_shape = (width,)
10930*da0073e9SAndroid Build Coastguard Worker        data = torch.randn(width, device="cuda", dtype=torch.float32)
10931*da0073e9SAndroid Build Coastguard Worker        weight = torch.ones(width, device="cuda", dtype=torch.float32)
10932*da0073e9SAndroid Build Coastguard Worker        bias = torch.zeros(width, device="cuda", dtype=torch.float32)
10933*da0073e9SAndroid Build Coastguard Worker        eps = 1e-5
10934*da0073e9SAndroid Build Coastguard Worker
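        # weight=None should behave like an all-ones weight, and bias=None like an
        # all-zeros bias.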
10935*da0073e9SAndroid Build Coastguard Worker        out_none_weight = torch.layer_norm(input, normalized_shape, None, data, eps)
10936*da0073e9SAndroid Build Coastguard Worker        out_one_weight = torch.layer_norm(input, normalized_shape, weight, data, eps)
10937*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_none_weight, out_one_weight)
10938*da0073e9SAndroid Build Coastguard Worker
10939*da0073e9SAndroid Build Coastguard Worker        out_none_bias = torch.layer_norm(input, normalized_shape, data, None, eps)
10940*da0073e9SAndroid Build Coastguard Worker        out_zero_bias = torch.layer_norm(input, normalized_shape, data, bias, eps)
10941*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_none_bias, out_zero_bias)
10942*da0073e9SAndroid Build Coastguard Worker
10943*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
10944*da0073e9SAndroid Build Coastguard Worker    def test_hardsigmoid_grad(self, device):
10945*da0073e9SAndroid Build Coastguard Worker        inputs = (torch.randn(4, 16, 16, device=device, dtype=torch.double) - 0.5) * 10
10946*da0073e9SAndroid Build Coastguard Worker        inputs.requires_grad = True
10947*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(gradcheck(F.hardsigmoid, (inputs,)))
10948*da0073e9SAndroid Build Coastguard Worker
10949*da0073e9SAndroid Build Coastguard Worker    # currently fails on XLA
10950*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
10951*da0073e9SAndroid Build Coastguard Worker    def test_hardswish_grad(self, device):
10952*da0073e9SAndroid Build Coastguard Worker        inputs = (torch.randn(4, 16, 16, device=device, dtype=torch.double) - 0.5) * 10
10953*da0073e9SAndroid Build Coastguard Worker        inputs.requires_grad = True
10954*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(gradcheck(F.hardswish, (inputs,)))
10955*da0073e9SAndroid Build Coastguard Worker
10956*da0073e9SAndroid Build Coastguard Worker
10957*da0073e9SAndroid Build Coastguard Worker    def _test_batchnorm_eval(self, ndim, device, dtype, module_dtype=None):
10958*da0073e9SAndroid Build Coastguard Worker        module_dtype = module_dtype or dtype
10959*da0073e9SAndroid Build Coastguard Worker        module = nn.BatchNorm1d(3).to(device, module_dtype)
10960*da0073e9SAndroid Build Coastguard Worker        module.eval()
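        # In eval mode the module normalizes with its running statistics, so two
        # identical passes must produce identical outputs and input gradients.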
10961*da0073e9SAndroid Build Coastguard Worker
10962*da0073e9SAndroid Build Coastguard Worker        data = torch.rand([3] * ndim, device=device, dtype=dtype, requires_grad=True)
10963*da0073e9SAndroid Build Coastguard Worker        grad = torch.rand([3] * ndim, device=device, dtype=dtype)
10964*da0073e9SAndroid Build Coastguard Worker
10965*da0073e9SAndroid Build Coastguard Worker        # 1st pass
10966*da0073e9SAndroid Build Coastguard Worker        res1 = module(data)
10967*da0073e9SAndroid Build Coastguard Worker        res1.backward(grad)
10968*da0073e9SAndroid Build Coastguard Worker        grad1 = data.grad.clone()
10969*da0073e9SAndroid Build Coastguard Worker
10970*da0073e9SAndroid Build Coastguard Worker        # 2nd pass
10971*da0073e9SAndroid Build Coastguard Worker        if data.grad is not None:
10972*da0073e9SAndroid Build Coastguard Worker            data.grad.data.zero_()
10973*da0073e9SAndroid Build Coastguard Worker
10974*da0073e9SAndroid Build Coastguard Worker        res2 = module(data)
10975*da0073e9SAndroid Build Coastguard Worker        res2.backward(grad)
10976*da0073e9SAndroid Build Coastguard Worker        grad2 = data.grad.clone()
10977*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res2)
10978*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(grad1, grad2)
10979*da0073e9SAndroid Build Coastguard Worker
10980*da0073e9SAndroid Build Coastguard Worker        # track_running_stats=False
10981*da0073e9SAndroid Build Coastguard Worker        module = nn.BatchNorm1d(3, track_running_stats=False).to(device, module_dtype)
10982*da0073e9SAndroid Build Coastguard Worker
10983*da0073e9SAndroid Build Coastguard Worker        data = torch.rand(4, 3, device=device, dtype=dtype, requires_grad=True)
10984*da0073e9SAndroid Build Coastguard Worker        grad = torch.rand(4, 3, device=device, dtype=dtype)
10985*da0073e9SAndroid Build Coastguard Worker
10986*da0073e9SAndroid Build Coastguard Worker        # 1st pass
10987*da0073e9SAndroid Build Coastguard Worker        res1 = module(data)
10988*da0073e9SAndroid Build Coastguard Worker        res1.backward(grad)
10989*da0073e9SAndroid Build Coastguard Worker        grad1 = data.grad.clone()
10990*da0073e9SAndroid Build Coastguard Worker
10991*da0073e9SAndroid Build Coastguard Worker        # set eval
10992*da0073e9SAndroid Build Coastguard Worker        module.eval()
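        # With track_running_stats=False, eval() still normalizes with batch
        # statistics, so the result should match the training-mode pass above.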
10993*da0073e9SAndroid Build Coastguard Worker
10994*da0073e9SAndroid Build Coastguard Worker        # 2nd pass
10995*da0073e9SAndroid Build Coastguard Worker        if data.grad is not None:
10996*da0073e9SAndroid Build Coastguard Worker            data.grad.data.zero_()
10997*da0073e9SAndroid Build Coastguard Worker
10998*da0073e9SAndroid Build Coastguard Worker        res2 = module(data)
10999*da0073e9SAndroid Build Coastguard Worker        res2.backward(grad)
11000*da0073e9SAndroid Build Coastguard Worker        grad2 = data.grad.clone()
11001*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res2)
11002*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(grad1, grad2)
11003*da0073e9SAndroid Build Coastguard Worker
11004*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
11005*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.float, torch.bfloat16)
11006*da0073e9SAndroid Build Coastguard Worker    def test_batchnorm_eval(self, device, dtype):
11007*da0073e9SAndroid Build Coastguard Worker        self._test_batchnorm_eval(2, device, dtype)
11008*da0073e9SAndroid Build Coastguard Worker        self._test_batchnorm_eval(3, device, dtype)
11009*da0073e9SAndroid Build Coastguard Worker
11010*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda' and self.has_cudnn():
11011*da0073e9SAndroid Build Coastguard Worker            with torch.backends.cudnn.flags(enabled=False):
11012*da0073e9SAndroid Build Coastguard Worker                self._test_batchnorm_eval(2, device, dtype)
11013*da0073e9SAndroid Build Coastguard Worker                self._test_batchnorm_eval(3, device, dtype)
11014*da0073e9SAndroid Build Coastguard Worker
11015*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
11016*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bfloat16, torch.half)
11017*da0073e9SAndroid Build Coastguard Worker    def test_batchnorm_eval_mixed(self, device, dtype):
11018*da0073e9SAndroid Build Coastguard Worker        # Test bfloat16/half input with a float module
11019*da0073e9SAndroid Build Coastguard Worker        self._test_batchnorm_eval(2, device, dtype, torch.float)
11020*da0073e9SAndroid Build Coastguard Worker        self._test_batchnorm_eval(3, device, dtype, torch.float)
11021*da0073e9SAndroid Build Coastguard Worker
11022*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda' and self.has_cudnn():
11023*da0073e9SAndroid Build Coastguard Worker            with torch.backends.cudnn.flags(enabled=False):
11024*da0073e9SAndroid Build Coastguard Worker                self._test_batchnorm_eval(2, device, dtype, torch.float)
11025*da0073e9SAndroid Build Coastguard Worker                self._test_batchnorm_eval(3, device, dtype, torch.float)
11026*da0073e9SAndroid Build Coastguard Worker
11027*da0073e9SAndroid Build Coastguard Worker    def _test_batchnorm_affine(self, ndim, device, dtype, module_dtype=None):
11028*da0073e9SAndroid Build Coastguard Worker        # Compare affine against no-op weights and bias
11029*da0073e9SAndroid Build Coastguard Worker        module_dtype = module_dtype or dtype
11030*da0073e9SAndroid Build Coastguard Worker        module = nn.BatchNorm1d(3, affine=False).to(device, module_dtype)
11031*da0073e9SAndroid Build Coastguard Worker        module_affine = nn.BatchNorm1d(3, affine=True).to(device, module_dtype)
11032*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
11033*da0073e9SAndroid Build Coastguard Worker            module_affine.weight.fill_(1.0)
11034*da0073e9SAndroid Build Coastguard Worker            module_affine.bias.zero_()
11035*da0073e9SAndroid Build Coastguard Worker
11036*da0073e9SAndroid Build Coastguard Worker        data = torch.rand([3] * ndim, device=device, dtype=dtype, requires_grad=True)
11037*da0073e9SAndroid Build Coastguard Worker        grad = torch.ones_like(data, requires_grad=False)
11038*da0073e9SAndroid Build Coastguard Worker
11039*da0073e9SAndroid Build Coastguard Worker        # With weights all ones and bias all zeros
11040*da0073e9SAndroid Build Coastguard Worker        res1 = module_affine(data)
11041*da0073e9SAndroid Build Coastguard Worker        res1.backward(grad)
11042*da0073e9SAndroid Build Coastguard Worker        grad1 = data.grad.clone()
11043*da0073e9SAndroid Build Coastguard Worker        data.grad.zero_()
11044*da0073e9SAndroid Build Coastguard Worker
11045*da0073e9SAndroid Build Coastguard Worker        # Without any weights or bias
11046*da0073e9SAndroid Build Coastguard Worker        res2 = module(data)
11047*da0073e9SAndroid Build Coastguard Worker        res2.backward(grad)
11048*da0073e9SAndroid Build Coastguard Worker        grad2 = data.grad
11049*da0073e9SAndroid Build Coastguard Worker
11050*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res2)
11051*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(grad1, grad2)
11052*da0073e9SAndroid Build Coastguard Worker
11053*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
11054*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.float, torch.bfloat16)
11055*da0073e9SAndroid Build Coastguard Worker    def test_batchnorm_affine(self, device, dtype):
11056*da0073e9SAndroid Build Coastguard Worker        self._test_batchnorm_affine(2, device, dtype)
11057*da0073e9SAndroid Build Coastguard Worker        self._test_batchnorm_affine(3, device, dtype)
11058*da0073e9SAndroid Build Coastguard Worker
11059*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda' and self.has_cudnn():
11060*da0073e9SAndroid Build Coastguard Worker            with torch.backends.cudnn.flags(enabled=False):
11061*da0073e9SAndroid Build Coastguard Worker                self._test_batchnorm_affine(2, device, dtype)
11062*da0073e9SAndroid Build Coastguard Worker                self._test_batchnorm_affine(3, device, dtype)
11063*da0073e9SAndroid Build Coastguard Worker
11064*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
11065*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bfloat16, torch.half)
11066*da0073e9SAndroid Build Coastguard Worker    def test_batchnorm_affine_mixed(self, device, dtype):
11067*da0073e9SAndroid Build Coastguard Worker        cudnn_enabled = [False]
11068*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda' and self.has_cudnn():
11069*da0073e9SAndroid Build Coastguard Worker            # TODO: Test fails with cudnn, see gh-62034
11070*da0073e9SAndroid Build Coastguard Worker            # cudnn_enabled = [False, True]
11071*da0073e9SAndroid Build Coastguard Worker            pass
11072*da0073e9SAndroid Build Coastguard Worker
11073*da0073e9SAndroid Build Coastguard Worker        # Test bfloat16/half input with a float module
11074*da0073e9SAndroid Build Coastguard Worker        for enabled in cudnn_enabled:
11075*da0073e9SAndroid Build Coastguard Worker            with torch.backends.cudnn.flags(enabled=enabled):
11076*da0073e9SAndroid Build Coastguard Worker                self._test_batchnorm_affine(2, device, dtype, torch.float)
11077*da0073e9SAndroid Build Coastguard Worker                self._test_batchnorm_affine(3, device, dtype, torch.float)
11078*da0073e9SAndroid Build Coastguard Worker
11079*da0073e9SAndroid Build Coastguard Worker    def _test_batchnorm_simple_average(self, device, dtype, module_dtype=None):
11080*da0073e9SAndroid Build Coastguard Worker        module_dtype = module_dtype or dtype
11081*da0073e9SAndroid Build Coastguard Worker        module = nn.BatchNorm1d(3, momentum=None).to(dtype=module_dtype, device=device)
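        # momentum=None switches the running statistics to a cumulative (simple)
        # average of the per-batch statistics, which the final assertions rely on.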
11082*da0073e9SAndroid Build Coastguard Worker        zeros = torch.zeros(3, dtype=module_dtype, device=device)
11083*da0073e9SAndroid Build Coastguard Worker        ones = torch.ones(3, dtype=module_dtype, device=device)
11084*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(module.running_mean, zeros)
11085*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(module.running_var, ones)
11086*da0073e9SAndroid Build Coastguard Worker
11087*da0073e9SAndroid Build Coastguard Worker        data1 = torch.rand(4, 3, dtype=dtype, device=device)
11088*da0073e9SAndroid Build Coastguard Worker        data2 = torch.rand(4, 3, dtype=dtype, device=device)
11089*da0073e9SAndroid Build Coastguard Worker
11090*da0073e9SAndroid Build Coastguard Worker        # 1st pass
11091*da0073e9SAndroid Build Coastguard Worker        res1 = module(data1)
11092*da0073e9SAndroid Build Coastguard Worker        running_mean1 = module.running_mean.clone()
11093*da0073e9SAndroid Build Coastguard Worker        running_var1 = module.running_var.clone()
11094*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(running_mean1, zeros)
11095*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(running_var1, ones)
11096*da0073e9SAndroid Build Coastguard Worker
11097*da0073e9SAndroid Build Coastguard Worker        # reset stats
11098*da0073e9SAndroid Build Coastguard Worker        module.reset_running_stats()
11099*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(module.running_mean, zeros)
11100*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(module.running_var, ones)
11101*da0073e9SAndroid Build Coastguard Worker
11102*da0073e9SAndroid Build Coastguard Worker        # 2nd pass
11103*da0073e9SAndroid Build Coastguard Worker        res2 = module(data2)
11104*da0073e9SAndroid Build Coastguard Worker        running_mean2 = module.running_mean.clone()
11105*da0073e9SAndroid Build Coastguard Worker        running_var2 = module.running_var.clone()
11106*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(running_mean2, zeros)
11107*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(running_var2, ones)
11108*da0073e9SAndroid Build Coastguard Worker
11109*da0073e9SAndroid Build Coastguard Worker        # reset stats
11110*da0073e9SAndroid Build Coastguard Worker        module.reset_running_stats()
11111*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(module.running_mean, zeros)
11112*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(module.running_var, ones)
11113*da0073e9SAndroid Build Coastguard Worker
11114*da0073e9SAndroid Build Coastguard Worker        # 3rd (combined) pass
11115*da0073e9SAndroid Build Coastguard Worker        res3 = module(data1)
11116*da0073e9SAndroid Build Coastguard Worker        res4 = module(data2)
11117*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res3, res1)
11118*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res4, res2)
11119*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(module.running_mean, (running_mean1 + running_mean2) / 2)
11120*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(module.running_var, (running_var1 + running_var2) / 2)
11121*da0073e9SAndroid Build Coastguard Worker
11122*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
11123*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.float, torch.bfloat16)
11124*da0073e9SAndroid Build Coastguard Worker    def test_batchnorm_simple_average(self, device, dtype):
11125*da0073e9SAndroid Build Coastguard Worker        self._test_batchnorm_simple_average(device, dtype)
11126*da0073e9SAndroid Build Coastguard Worker
11127*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda' and self.has_cudnn():
11128*da0073e9SAndroid Build Coastguard Worker            with torch.backends.cudnn.flags(enabled=False):
11129*da0073e9SAndroid Build Coastguard Worker                self._test_batchnorm_simple_average(device, dtype)
11130*da0073e9SAndroid Build Coastguard Worker
11131*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
11132*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bfloat16, torch.half)
11133*da0073e9SAndroid Build Coastguard Worker    def test_batchnorm_simple_average_mixed(self, device, dtype):
11134*da0073e9SAndroid Build Coastguard Worker        self._test_batchnorm_simple_average(device, dtype, torch.float)
11135*da0073e9SAndroid Build Coastguard Worker
11136*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda' and self.has_cudnn():
11137*da0073e9SAndroid Build Coastguard Worker            with torch.backends.cudnn.flags(enabled=False):
11138*da0073e9SAndroid Build Coastguard Worker                self._test_batchnorm_simple_average(device, dtype, torch.float)
11139*da0073e9SAndroid Build Coastguard Worker
11140*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
11141*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
11142*da0073e9SAndroid Build Coastguard Worker    def test_grid_sample_nan_inf(self, device, dtype):
11143*da0073e9SAndroid Build Coastguard Worker        input = torch.zeros([1, 1, 3, 3], device=device, dtype=dtype)
11144*da0073e9SAndroid Build Coastguard Worker        grid = torch.tensor([[[[nan, 0], [0, inf]]]], device=device, dtype=dtype)
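        # NaN/Inf grid coordinates must not propagate NaNs into the output; with an
        # all-zero input every padding mode should therefore produce zeros.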
11145*da0073e9SAndroid Build Coastguard Worker        for padding_mode in ('reflection', 'border', 'zeros'):
11146*da0073e9SAndroid Build Coastguard Worker            sample = torch.nn.functional.grid_sample(input=input, grid=grid, mode='nearest',
11147*da0073e9SAndroid Build Coastguard Worker                                                     padding_mode=padding_mode, align_corners=False)
11148*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(sample, torch.zeros([1, 1, 1, 2], device=device, dtype=dtype))
11149*da0073e9SAndroid Build Coastguard Worker
11150*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # NotImplementedError aten::_ctc_loss https://github.com/pytorch/pytorch/issues/77764
11151*da0073e9SAndroid Build Coastguard Worker    def test_CTCLoss_empty_target(self, device):
11152*da0073e9SAndroid Build Coastguard Worker        target_lengths = [0, 0, 0]
11153*da0073e9SAndroid Build Coastguard Worker        input_lengths = [50, 50, 50]
11154*da0073e9SAndroid Build Coastguard Worker        targets = torch.randint(1, 15, (0,), dtype=torch.long, device=device)
11155*da0073e9SAndroid Build Coastguard Worker        log_probs = torch.randn(50, 3, 15, dtype=torch.double, device=device).log_softmax(2)
11156*da0073e9SAndroid Build Coastguard Worker        loss = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
11157*da0073e9SAndroid Build Coastguard Worker        self.assertTrue((loss >= 0).all().item())
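        # With an empty target the only valid alignment is all blanks, so the loss
        # reduces to -sum_t log p(blank) for each batch element.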
11158*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(-log_probs.sum(0)[:, 0], loss)
11159*da0073e9SAndroid Build Coastguard Worker
11160*da0073e9SAndroid Build Coastguard Worker        target_lengths = [0, 9, 0]
11161*da0073e9SAndroid Build Coastguard Worker        input_lengths = [50, 50, 50]
11162*da0073e9SAndroid Build Coastguard Worker        targets = torch.randint(1, 15, (9,), dtype=torch.long, device=device)
11163*da0073e9SAndroid Build Coastguard Worker        log_probs = torch.randn(50, 3, 15, dtype=torch.double, device=device).log_softmax(2)
11164*da0073e9SAndroid Build Coastguard Worker        loss = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
11165*da0073e9SAndroid Build Coastguard Worker        self.assertTrue((loss >= 0).all().item())
11166*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(-log_probs.sum(0)[[0, 2], 0], loss[[0, 2]])
11167*da0073e9SAndroid Build Coastguard Worker
11168*da0073e9SAndroid Build Coastguard Worker    # Merge into OpInfo?
11169*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIf(True, """Test is flaky on Linux and Windows, typical error message:
11170*da0073e9SAndroid Build Coastguard Worker                          https://github.com/pytorch/pytorch/issues/34870""")
11171*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # NotImplementedError aten::_ctc_loss https://github.com/pytorch/pytorch/issues/77764
11172*da0073e9SAndroid Build Coastguard Worker    def test_ctc_loss(self, device):
11173*da0073e9SAndroid Build Coastguard Worker        batch_size = 64
11174*da0073e9SAndroid Build Coastguard Worker        num_labels = 101
11175*da0073e9SAndroid Build Coastguard Worker        target_length = 15
11176*da0073e9SAndroid Build Coastguard Worker        gradcheck_input_size = 10
11177*da0073e9SAndroid Build Coastguard Worker
11178*da0073e9SAndroid Build Coastguard Worker        ZERO_NONE = 0
11179*da0073e9SAndroid Build Coastguard Worker        ZERO_SOME = 1
11180*da0073e9SAndroid Build Coastguard Worker        ZERO_ALL = 2
11181*da0073e9SAndroid Build Coastguard Worker
11182*da0073e9SAndroid Build Coastguard Worker        # input_length, vary_lengths, zero_mode
11183*da0073e9SAndroid Build Coastguard Worker        tests = [(150, False, ZERO_NONE),
11184*da0073e9SAndroid Build Coastguard Worker                 (150, True, ZERO_NONE),
11185*da0073e9SAndroid Build Coastguard Worker                 (50, True, ZERO_SOME),
11186*da0073e9SAndroid Build Coastguard Worker                 (50, True, ZERO_ALL)]
11187*da0073e9SAndroid Build Coastguard Worker
11188*da0073e9SAndroid Build Coastguard Worker        if 'cuda' in device:
11189*da0073e9SAndroid Build Coastguard Worker            tests += [(50, False, ZERO_NONE),
11190*da0073e9SAndroid Build Coastguard Worker                      (50, True, ZERO_NONE),
11191*da0073e9SAndroid Build Coastguard Worker                      (150, True, ZERO_SOME),
11192*da0073e9SAndroid Build Coastguard Worker                      (150, True, ZERO_ALL)]
11193*da0073e9SAndroid Build Coastguard Worker
11194*da0073e9SAndroid Build Coastguard Worker        for input_length, vary_lengths, zero_mode in tests:
11195*da0073e9SAndroid Build Coastguard Worker            targets = torch.randint(1, num_labels, (batch_size, target_length),
11196*da0073e9SAndroid Build Coastguard Worker                                    device=device, dtype=torch.long)
11197*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(gradcheck_input_size, dtype=torch.double, device=device, requires_grad=True)
11198*da0073e9SAndroid Build Coastguard Worker            tile_factors = torch.randn(input_length * batch_size * num_labels // gradcheck_input_size + 1,
11199*da0073e9SAndroid Build Coastguard Worker                                       device=device)
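            # A small leaf tensor x is tiled up to the full (T, N, C) log-prob tensor
            # inside ctc_after_softmax below, keeping gradcheck cheap.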
11200*da0073e9SAndroid Build Coastguard Worker            input_lengths = [(torch.randint(input_length // 2, input_length + 1, ()).item()
11201*da0073e9SAndroid Build Coastguard Worker                              if vary_lengths or i == 0 else input_length) for i in range(batch_size)]
11202*da0073e9SAndroid Build Coastguard Worker            if zero_mode == ZERO_ALL:
11203*da0073e9SAndroid Build Coastguard Worker                target_lengths = [0 for _ in range(batch_size)]
11204*da0073e9SAndroid Build Coastguard Worker            else:
11205*da0073e9SAndroid Build Coastguard Worker                target_lengths = [(torch.randint(target_length // 2, target_length + 1, ()).item()
11206*da0073e9SAndroid Build Coastguard Worker                                   if vary_lengths else target_length) for _ in range(batch_size)]
11207*da0073e9SAndroid Build Coastguard Worker                if zero_mode == ZERO_SOME:
11208*da0073e9SAndroid Build Coastguard Worker                    idxes = torch.randint(0, batch_size, (10,))
11209*da0073e9SAndroid Build Coastguard Worker                    for i in idxes:
11210*da0073e9SAndroid Build Coastguard Worker                        target_lengths[i] = 0
11211*da0073e9SAndroid Build Coastguard Worker
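            # Build the full (input_length, batch_size, num_labels) input from the
            # small leaf tensor x via an outer product with tile_factors, so that
            # gradcheck only needs to perturb gradcheck_input_size entries.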
11212*da0073e9SAndroid Build Coastguard Worker            def ctc_after_softmax(x):
11213*da0073e9SAndroid Build Coastguard Worker                x_full = ((x[:, None] * tile_factors[None, :]).view(-1)[:input_length * batch_size * num_labels]
11214*da0073e9SAndroid Build Coastguard Worker                          .view(input_length, batch_size, num_labels))
11215*da0073e9SAndroid Build Coastguard Worker                log_probs = torch.log_softmax(x_full, 2)
11216*da0073e9SAndroid Build Coastguard Worker                return torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)
11217*da0073e9SAndroid Build Coastguard Worker
11218*da0073e9SAndroid Build Coastguard Worker            gradcheck(ctc_after_softmax, [x])
11219*da0073e9SAndroid Build Coastguard Worker
11220*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
11221*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfRocm(msg="skipped Cudnn test on ROCm")
11222*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfCudnnVersionLessThan(7600)
11223*da0073e9SAndroid Build Coastguard Worker    def test_ctc_loss_cudnn(self, device):
11224*da0073e9SAndroid Build Coastguard Worker        batch_size = 16
11225*da0073e9SAndroid Build Coastguard Worker        input_length = 30
11226*da0073e9SAndroid Build Coastguard Worker        num_labels = 101
11227*da0073e9SAndroid Build Coastguard Worker        target_length = 15
11228*da0073e9SAndroid Build Coastguard Worker        targets = torch.randint(1, num_labels, (batch_size * target_length,),
11229*da0073e9SAndroid Build Coastguard Worker                                device='cuda', dtype=torch.long)
11230*da0073e9SAndroid Build Coastguard Worker        log_probs = torch.log_softmax(torch.randn(input_length, batch_size, num_labels, device='cuda', dtype=torch.float), 2)
11231*da0073e9SAndroid Build Coastguard Worker        log_probs.requires_grad_()
11232*da0073e9SAndroid Build Coastguard Worker
11233*da0073e9SAndroid Build Coastguard Worker        input_lengths = batch_size * [input_length]
11234*da0073e9SAndroid Build Coastguard Worker        target_lengths = batch_size * [target_length]
11235*da0073e9SAndroid Build Coastguard Worker        grad_out = torch.randn(batch_size, device='cuda', dtype=torch.float)
11236*da0073e9SAndroid Build Coastguard Worker        with torch.backends.cudnn.flags(enabled=False):
11237*da0073e9SAndroid Build Coastguard Worker            loss_native = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
11238*da0073e9SAndroid Build Coastguard Worker            grad_native, = torch.autograd.grad(loss_native, log_probs, grad_out)
11239*da0073e9SAndroid Build Coastguard Worker        loss_cudnn = torch.nn.functional.ctc_loss(log_probs, targets.to('cpu', torch.int32),
11240*da0073e9SAndroid Build Coastguard Worker                                                  input_lengths, target_lengths, reduction='none')
11241*da0073e9SAndroid Build Coastguard Worker        self.assertIn("Cudnn", str(loss_cudnn.grad_fn))
11242*da0073e9SAndroid Build Coastguard Worker        grad_cudnn, = torch.autograd.grad(loss_cudnn, log_probs, grad_out)
11243*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(grad_cudnn, grad_native, atol=1e-4, rtol=0)
11244*da0073e9SAndroid Build Coastguard Worker
11245*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
11246*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfRocm(msg="skipped Cudnn test on ROCm")
11247*da0073e9SAndroid Build Coastguard Worker    @skipCUDAIfCudnnVersionLessThan(8000)
11248*da0073e9SAndroid Build Coastguard Worker    def test_ctc_loss_cudnn_tensor(self, device):
11249*da0073e9SAndroid Build Coastguard Worker        batch_size = 16
11250*da0073e9SAndroid Build Coastguard Worker        input_length = 30
11251*da0073e9SAndroid Build Coastguard Worker        num_labels = 101
11252*da0073e9SAndroid Build Coastguard Worker        target_length = 15
11253*da0073e9SAndroid Build Coastguard Worker        targets = torch.randint(1, num_labels, (batch_size * target_length,),
11254*da0073e9SAndroid Build Coastguard Worker                                device='cuda', dtype=torch.long)
11255*da0073e9SAndroid Build Coastguard Worker        log_probs = torch.log_softmax(torch.randn(input_length, batch_size, num_labels, device='cuda', dtype=torch.float), 2)
11256*da0073e9SAndroid Build Coastguard Worker        log_probs.requires_grad_()
11257*da0073e9SAndroid Build Coastguard Worker
11259*da0073e9SAndroid Build Coastguard Worker        input_lengths = torch.linspace(start=15, end=input_length, steps=batch_size, dtype=torch.long, device='cuda')
11260*da0073e9SAndroid Build Coastguard Worker        target_lengths = torch.tensor(batch_size * [target_length], dtype=torch.long, device='cuda')
11261*da0073e9SAndroid Build Coastguard Worker        grad_out = torch.randn(batch_size, device='cuda', dtype=torch.float)
11262*da0073e9SAndroid Build Coastguard Worker        with torch.backends.cudnn.flags(enabled=False):
11263*da0073e9SAndroid Build Coastguard Worker            loss_native = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
11264*da0073e9SAndroid Build Coastguard Worker            grad_native, = torch.autograd.grad(loss_native, log_probs, grad_out)
11265*da0073e9SAndroid Build Coastguard Worker        loss_cudnn = torch.nn.functional.ctc_loss(log_probs,
11266*da0073e9SAndroid Build Coastguard Worker                                                  targets.to('cuda', torch.int32),
11267*da0073e9SAndroid Build Coastguard Worker                                                  input_lengths.to('cuda', torch.int32),
11268*da0073e9SAndroid Build Coastguard Worker                                                  target_lengths.to('cuda', torch.int32),
11269*da0073e9SAndroid Build Coastguard Worker                                                  reduction='none')
11270*da0073e9SAndroid Build Coastguard Worker        self.assertIn("Cudnn", str(loss_cudnn.grad_fn))
11271*da0073e9SAndroid Build Coastguard Worker        grad_cudnn, = torch.autograd.grad(loss_cudnn, log_probs, grad_out)
11272*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(grad_cudnn, grad_native, atol=1e-4, rtol=0)
11273*da0073e9SAndroid Build Coastguard Worker
11274*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # RuntimeError: LSTM with projections is not currently supported with MPS.
11275*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.half, torch.float, torch.double)
11276*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
11277*da0073e9SAndroid Build Coastguard Worker    @tf32_on_and_off(0.005)
11278*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")
11279*da0073e9SAndroid Build Coastguard Worker    def test_variable_sequence(self, device, dtype):
11280*da0073e9SAndroid Build Coastguard Worker        def pad(var, length):
11281*da0073e9SAndroid Build Coastguard Worker            if var.size(0) == length:
11282*da0073e9SAndroid Build Coastguard Worker                return var
11283*da0073e9SAndroid Build Coastguard Worker            return torch.cat([var, var.new_zeros(length - var.size(0), *var.size()[1:])])
11284*da0073e9SAndroid Build Coastguard Worker
11285*da0073e9SAndroid Build Coastguard Worker        def maybe_index_tuple(maybe_tuple_of_tensors, index):
11286*da0073e9SAndroid Build Coastguard Worker            if maybe_tuple_of_tensors is None:
11287*da0073e9SAndroid Build Coastguard Worker                return None
11288*da0073e9SAndroid Build Coastguard Worker            return tuple(maybe_tuple_of_tensors[j][:, index:index + 1, :].contiguous()
11289*da0073e9SAndroid Build Coastguard Worker                         for j in range(2))
11290*da0073e9SAndroid Build Coastguard Worker
11291*da0073e9SAndroid Build Coastguard Worker        def check_lengths(lengths, enforce_sorted, use_default_hiddens, proj_size):
11292*da0073e9SAndroid Build Coastguard Worker            input_size = 3
11293*da0073e9SAndroid Build Coastguard Worker            hidden_size = 4
11294*da0073e9SAndroid Build Coastguard Worker            num_layers = 2
11295*da0073e9SAndroid Build Coastguard Worker            bidirectional = True
11296*da0073e9SAndroid Build Coastguard Worker
11297*da0073e9SAndroid Build Coastguard Worker            max_length = max(lengths)
11298*da0073e9SAndroid Build Coastguard Worker            x_leaf = torch.randn(max_length, len(lengths), input_size, device=device,
11299*da0073e9SAndroid Build Coastguard Worker                                 dtype=dtype, requires_grad=True)
11300*da0073e9SAndroid Build Coastguard Worker            num_directions = 2 if bidirectional else 1
11301*da0073e9SAndroid Build Coastguard Worker            lstm = nn.LSTM(input_size, hidden_size, bidirectional=bidirectional,
11302*da0073e9SAndroid Build Coastguard Worker                           num_layers=num_layers, proj_size=proj_size).to(device, dtype)
11303*da0073e9SAndroid Build Coastguard Worker            lstm2 = deepcopy(lstm).to(device, dtype)
11304*da0073e9SAndroid Build Coastguard Worker            x = x_leaf
11305*da0073e9SAndroid Build Coastguard Worker
11306*da0073e9SAndroid Build Coastguard Worker            hidden0 = None
11307*da0073e9SAndroid Build Coastguard Worker            if not use_default_hiddens:
11308*da0073e9SAndroid Build Coastguard Worker                real_hidden_size = hidden_size if proj_size == 0 else proj_size
11309*da0073e9SAndroid Build Coastguard Worker                hidden0 = (torch.randn(num_directions * num_layers, len(lengths), real_hidden_size,
11310*da0073e9SAndroid Build Coastguard Worker                                       device=device, dtype=dtype),
11311*da0073e9SAndroid Build Coastguard Worker                           torch.randn(num_directions * num_layers, len(lengths), hidden_size,
11312*da0073e9SAndroid Build Coastguard Worker                                       device=device, dtype=dtype))
11313*da0073e9SAndroid Build Coastguard Worker
11314*da0073e9SAndroid Build Coastguard Worker            # Compute sequences separately
11315*da0073e9SAndroid Build Coastguard Worker            seq_outs = []
11316*da0073e9SAndroid Build Coastguard Worker            seq_hiddens = []
11317*da0073e9SAndroid Build Coastguard Worker            for i, l in enumerate(lengths):
11318*da0073e9SAndroid Build Coastguard Worker                hidden_i = maybe_index_tuple(hidden0, i)
11319*da0073e9SAndroid Build Coastguard Worker                out, hid = lstm2(x[:l, i:i + 1], hidden_i)
11320*da0073e9SAndroid Build Coastguard Worker                out_pad = pad(out, max_length)
11321*da0073e9SAndroid Build Coastguard Worker                seq_outs.append(out_pad)
11322*da0073e9SAndroid Build Coastguard Worker                seq_hiddens.append(hid)
11323*da0073e9SAndroid Build Coastguard Worker            seq_out = torch.cat(seq_outs, 1)
11324*da0073e9SAndroid Build Coastguard Worker            seq_hidden = tuple(torch.cat(hids, 1) for hids in zip(*seq_hiddens))
11325*da0073e9SAndroid Build Coastguard Worker
11326*da0073e9SAndroid Build Coastguard Worker            # Use packed format
11327*da0073e9SAndroid Build Coastguard Worker            packed = rnn_utils.pack_padded_sequence(x, lengths, enforce_sorted=enforce_sorted)
11328*da0073e9SAndroid Build Coastguard Worker            packed_out, packed_hidden = lstm(packed, hidden0)
11329*da0073e9SAndroid Build Coastguard Worker            unpacked, unpacked_len = rnn_utils.pad_packed_sequence(packed_out)
11330*da0073e9SAndroid Build Coastguard Worker
11331*da0073e9SAndroid Build Coastguard Worker            # Check forward
11332*da0073e9SAndroid Build Coastguard Worker            prec = dtype2prec_DONTUSE[dtype]
11333*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(packed_hidden, seq_hidden, atol=prec, rtol=0)
11334*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(unpacked, seq_out, atol=prec, rtol=0)
11335*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(unpacked_len, lengths, atol=prec, rtol=0)
11336*da0073e9SAndroid Build Coastguard Worker
11337*da0073e9SAndroid Build Coastguard Worker            # Check backward
11338*da0073e9SAndroid Build Coastguard Worker            seq_out.sum().backward()
11339*da0073e9SAndroid Build Coastguard Worker            grad_x = x_leaf.grad.data.clone()
11340*da0073e9SAndroid Build Coastguard Worker            x_leaf.grad.data.zero_()
11341*da0073e9SAndroid Build Coastguard Worker            unpacked.sum().backward()
11342*da0073e9SAndroid Build Coastguard Worker
11343*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_leaf.grad, grad_x, atol=dtype2prec_DONTUSE[dtype], rtol=0)
11344*da0073e9SAndroid Build Coastguard Worker            for p1, p2 in zip(lstm.parameters(), lstm2.parameters()):
11345*da0073e9SAndroid Build Coastguard Worker                prec = dtype2prec_DONTUSE[dtype]
11346*da0073e9SAndroid Build Coastguard Worker                if dtype == torch.float16:
11347*da0073e9SAndroid Build Coastguard Worker                    prec = 4e-2
11348*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(p1.grad, p2.grad, atol=prec, rtol=0)
11349*da0073e9SAndroid Build Coastguard Worker
11350*da0073e9SAndroid Build Coastguard Worker        tests = [
11351*da0073e9SAndroid Build Coastguard Worker            # enforce_sorted, lengths
11352*da0073e9SAndroid Build Coastguard Worker            [True, [5]],
11353*da0073e9SAndroid Build Coastguard Worker            [False, [5]],
11354*da0073e9SAndroid Build Coastguard Worker            [True, [10, 10, 6, 2, 2, 1, 1]],
11355*da0073e9SAndroid Build Coastguard Worker            [False, [10, 10, 6, 2, 2, 1, 1]],
11356*da0073e9SAndroid Build Coastguard Worker            [False, [2, 1, 3, 2, 10, 5, 3]],
11357*da0073e9SAndroid Build Coastguard Worker        ]
11358*da0073e9SAndroid Build Coastguard Worker
11359*da0073e9SAndroid Build Coastguard Worker        for enforce_sorted, seq_lens in tests:
11360*da0073e9SAndroid Build Coastguard Worker            for use_default_hiddens in (True, False):
11361*da0073e9SAndroid Build Coastguard Worker                for proj_size in [0, 2]:
11362*da0073e9SAndroid Build Coastguard Worker                    check_lengths(seq_lens, enforce_sorted, use_default_hiddens, proj_size)
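
    # A minimal sketch (sizes and the helper name are illustrative, not from the
    # original suite) of the packing round trip the test above builds on:
    # pack_padded_sequence followed by pad_packed_sequence recovers both the padded
    # batch and the per-sequence lengths.
    @staticmethod
    def _pack_pad_roundtrip_sketch():
        lengths = [3, 2, 1]
        x = torch.zeros(3, len(lengths), 4)
        for i, l in enumerate(lengths):
            x[:l, i] = torch.randn(l, 4)
        packed = rnn_utils.pack_padded_sequence(x, lengths, enforce_sorted=True)
        unpacked, unpacked_len = rnn_utils.pad_packed_sequence(packed)
        return torch.equal(unpacked, x) and unpacked_len.tolist() == lengths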
11363*da0073e9SAndroid Build Coastguard Worker
11364*da0073e9SAndroid Build Coastguard Worker    def _test_batchnorm_update_stats(self, device, dtype=torch.float):
11365*da0073e9SAndroid Build Coastguard Worker        module = nn.BatchNorm1d(3).to(device, dtype)
11366*da0073e9SAndroid Build Coastguard Worker
11367*da0073e9SAndroid Build Coastguard Worker        data = torch.rand(4, 3, device=device, dtype=dtype)
11368*da0073e9SAndroid Build Coastguard Worker
11369*da0073e9SAndroid Build Coastguard Worker        # training pass
11370*da0073e9SAndroid Build Coastguard Worker        old_running_mean = module.running_mean.clone()
11371*da0073e9SAndroid Build Coastguard Worker        old_running_var = module.running_var.clone()
11372*da0073e9SAndroid Build Coastguard Worker        old_num_batches_tracked = module.num_batches_tracked.clone()
11373*da0073e9SAndroid Build Coastguard Worker        module(data)
11374*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(old_running_mean, module.running_mean)
11375*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(old_running_var, module.running_var)
11376*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(old_num_batches_tracked + 1, module.num_batches_tracked)
11377*da0073e9SAndroid Build Coastguard Worker
11378*da0073e9SAndroid Build Coastguard Worker        # eval pass
11379*da0073e9SAndroid Build Coastguard Worker        module.eval()
11380*da0073e9SAndroid Build Coastguard Worker        old_running_mean = module.running_mean.clone()
11381*da0073e9SAndroid Build Coastguard Worker        old_running_var = module.running_var.clone()
11382*da0073e9SAndroid Build Coastguard Worker        old_num_batches_tracked = module.num_batches_tracked.clone()
11383*da0073e9SAndroid Build Coastguard Worker        module(data)
11384*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(old_running_mean, module.running_mean)
11385*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(old_running_var, module.running_var)
11386*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(old_num_batches_tracked, module.num_batches_tracked)
11387*da0073e9SAndroid Build Coastguard Worker
11388*da0073e9SAndroid Build Coastguard Worker    def test_batchnorm_update_stats(self, device):
11389*da0073e9SAndroid Build Coastguard Worker        self._test_batchnorm_update_stats(device)
11390*da0073e9SAndroid Build Coastguard Worker
11391*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cuda' and self.has_cudnn():
11392*da0073e9SAndroid Build Coastguard Worker            with torch.backends.cudnn.flags(enabled=False):
11393*da0073e9SAndroid Build Coastguard Worker                self._test_batchnorm_update_stats(device)
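
    # A minimal sketch (assuming the default momentum of 0.1) of the update the
    # tests above observe: in training mode BatchNorm folds the batch statistics
    # into running_mean via an exponential moving average, while eval mode leaves
    # the buffers untouched.
    @staticmethod
    def _batchnorm_running_mean_sketch():
        bn = nn.BatchNorm1d(3)
        data = torch.rand(8, 3)
        old_mean = bn.running_mean.clone()
        bn(data)  # training mode: buffers are updated in place
        expected = (1 - bn.momentum) * old_mean + bn.momentum * data.mean(0)
        return torch.allclose(bn.running_mean, expected)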
11394*da0073e9SAndroid Build Coastguard Worker
11395*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
11396*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bfloat16, torch.float16)
11397*da0073e9SAndroid Build Coastguard Worker    def test_activations_bfloat16_half_cpu(self, device, dtype):
11398*da0073e9SAndroid Build Coastguard Worker        def test_helper(fn, device, inp_dims, prec=None):
11399*da0073e9SAndroid Build Coastguard Worker            torch.manual_seed(37)
11400*da0073e9SAndroid Build Coastguard Worker            # bfloat16/half compute
11401*da0073e9SAndroid Build Coastguard Worker            fn = fn.to(dtype=dtype)
11402*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(inp_dims, dtype=dtype, device=device, requires_grad=True)
11403*da0073e9SAndroid Build Coastguard Worker            out = fn(input)
11404*da0073e9SAndroid Build Coastguard Worker            grad_input = torch.randn_like(out, dtype=dtype, device=device)
11405*da0073e9SAndroid Build Coastguard Worker            out.backward(grad_input)
11406*da0073e9SAndroid Build Coastguard Worker
11407*da0073e9SAndroid Build Coastguard Worker            # fp32 compute
11408*da0073e9SAndroid Build Coastguard Worker            input2 = input.detach().clone().float().requires_grad_(True)
11409*da0073e9SAndroid Build Coastguard Worker            out2 = fn.float()(input2)
11410*da0073e9SAndroid Build Coastguard Worker            grad_input2 = grad_input.detach().clone().float()
11411*da0073e9SAndroid Build Coastguard Worker            out2.backward(grad_input2)
11412*da0073e9SAndroid Build Coastguard Worker
11413*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out.dtype, dtype)
11414*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input.grad.dtype, dtype)
11415*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, out2.to(dtype=dtype), atol=prec, rtol=prec)
11416*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input.grad.data, input2.grad.data.to(dtype=dtype), atol=prec, rtol=prec)
11417*da0073e9SAndroid Build Coastguard Worker
11418*da0073e9SAndroid Build Coastguard Worker        shapes = [[1, 3, 1, 6], [1, 3, 1, 128], [1, 3, 256, 256]]
11419*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
11420*da0073e9SAndroid Build Coastguard Worker            test_helper(torch.nn.LogSigmoid(), device, shape)
11421*da0073e9SAndroid Build Coastguard Worker            test_helper(torch.nn.Hardsigmoid(), device, shape)
11422*da0073e9SAndroid Build Coastguard Worker            test_helper(torch.nn.Hardshrink(), device, shape)
11423*da0073e9SAndroid Build Coastguard Worker            test_helper(torch.nn.Softshrink(), device, shape)
11424*da0073e9SAndroid Build Coastguard Worker            test_helper(torch.nn.Hardswish(), device, shape)
11425*da0073e9SAndroid Build Coastguard Worker            test_helper(torch.nn.Softplus(), device, shape)
11426*da0073e9SAndroid Build Coastguard Worker            test_helper(torch.nn.SiLU(), device, shape)
11427*da0073e9SAndroid Build Coastguard Worker            test_helper(torch.nn.Hardtanh(), device, shape)
11428*da0073e9SAndroid Build Coastguard Worker            test_helper(torch.nn.Mish(), device, shape)
11429*da0073e9SAndroid Build Coastguard Worker            test_helper(torch.nn.ELU(), device, shape)
11430*da0073e9SAndroid Build Coastguard Worker            test_helper(torch.nn.PReLU(), device, shape)
11431*da0073e9SAndroid Build Coastguard Worker            test_helper(torch.nn.GLU(), device, shape, prec=1e-2)
11432*da0073e9SAndroid Build Coastguard Worker            test_helper(torch.nn.Threshold(0.1, 20), device, shape)
11433*da0073e9SAndroid Build Coastguard Worker            test_helper(torch.nn.GELU(), device, shape)
11435*da0073e9SAndroid Build Coastguard Worker            test_helper(torch.nn.LeakyReLU(), device, shape)
11436*da0073e9SAndroid Build Coastguard Worker
11437*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
11438*da0073e9SAndroid Build Coastguard Worker    def test_activations_bfloat16(self, device):
11439*da0073e9SAndroid Build Coastguard Worker        _test_bfloat16_ops(self, torch.nn.ReLU(), device, inp_dims=(5), prec=1e-2)
11440*da0073e9SAndroid Build Coastguard Worker        _test_bfloat16_ops(self, torch.nn.Threshold(0.1, 20), device, inp_dims=(5), prec=1e-2)
11441*da0073e9SAndroid Build Coastguard Worker        _test_bfloat16_ops(self, torch.nn.ELU(), device, inp_dims=(5), prec=1e-2)
11442*da0073e9SAndroid Build Coastguard Worker        _test_bfloat16_ops(self, torch.nn.Softplus(), device, inp_dims=(5), prec=1e-2)
11443*da0073e9SAndroid Build Coastguard Worker        _test_bfloat16_ops(self, torch.nn.Hardshrink(), device, inp_dims=(5), prec=1e-2)
11444*da0073e9SAndroid Build Coastguard Worker        _test_bfloat16_ops(self, torch.nn.Softshrink(), device, inp_dims=(5), prec=1e-2)
11445*da0073e9SAndroid Build Coastguard Worker        _test_bfloat16_ops(self, torch.nn.LeakyReLU(), device, inp_dims=(5), prec=1e-2)
11446*da0073e9SAndroid Build Coastguard Worker
11447*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
11448*da0073e9SAndroid Build Coastguard Worker    def test_softmax_bfloat16(self, device):
11449*da0073e9SAndroid Build Coastguard Worker        for dim in [0, 1, 2, 3]:
11450*da0073e9SAndroid Build Coastguard Worker            _test_bfloat16_ops(self, torch.nn.Softmax(dim=dim), device, inp_dims=(16, 33, 15, 16), prec=1e-2)
11451*da0073e9SAndroid Build Coastguard Worker            # test softmax with large input value which causes exp() to overflow
11452*da0073e9SAndroid Build Coastguard Worker            _test_bfloat16_ops(self, torch.nn.Softmax(dim=dim), device, inp_dims=(16, 33, 15, 16), prec=0.05, scale_factor=1000.0)
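
    # A minimal sketch (values are illustrative) of why the scale_factor=1000.0 case
    # above is interesting: a naive softmax overflows exp() for large inputs, while
    # shifting by the maximum first stays finite because
    # softmax(x) == softmax(x - x.max()).
    @staticmethod
    def _softmax_overflow_sketch():
        x = torch.tensor([1000.0, 999.0, 998.0], dtype=torch.bfloat16)
        naive = torch.exp(x) / torch.exp(x).sum()  # exp(1000) overflows to inf -> nan
        shifted = torch.exp(x - x.max()) / torch.exp(x - x.max()).sum()
        return torch.isnan(naive).any() and torch.isfinite(shifted).all()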
11453*da0073e9SAndroid Build Coastguard Worker
11454*da0073e9SAndroid Build Coastguard Worker    def test_nll_loss_mismatched_batch(self, device):
11455*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((10, 3), requires_grad=True, device=device)
11456*da0073e9SAndroid Build Coastguard Worker        # t should have size (10,)
11457*da0073e9SAndroid Build Coastguard Worker        t = torch.zeros((3,), dtype=torch.int64, device=device)
11458*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, 'Expected.*batch_size'):
11459*da0073e9SAndroid Build Coastguard Worker            F.nll_loss(x, t)
11460*da0073e9SAndroid Build Coastguard Worker
11461*da0073e9SAndroid Build Coastguard Worker    def test_nll_loss_out_of_bounds_ignore_index(self, device):
11462*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(6, 3, requires_grad=True, device=device)
11463*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor([0, 1, 255, 0, 1, 2], dtype=torch.int64, device=device)
11464*da0073e9SAndroid Build Coastguard Worker        for reduction in ['mean', 'none']:
11465*da0073e9SAndroid Build Coastguard Worker            F.nll_loss(x, t, ignore_index=255, reduction=reduction).sum().backward()
11466*da0073e9SAndroid Build Coastguard Worker
11467*da0073e9SAndroid Build Coastguard Worker    def test_nll_loss_invalid_target_dim(self, device):
11468*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((10, 3), device=device)
11469*da0073e9SAndroid Build Coastguard Worker        t = torch.zeros((10, 2), dtype=torch.int64, device=device)
11470*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "1D target tensor expected"):
11471*da0073e9SAndroid Build Coastguard Worker            F.nll_loss(x, t)
11472*da0073e9SAndroid Build Coastguard Worker
11473*da0073e9SAndroid Build Coastguard Worker    def test_nll_loss_invalid_weights(self, device):
11474*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((10, 3), device=device)
11475*da0073e9SAndroid Build Coastguard Worker        t = torch.empty(10, dtype=torch.int64, device=device).random_(0, 3)
11476*da0073e9SAndroid Build Coastguard Worker        invalid_weights = [
11477*da0073e9SAndroid Build Coastguard Worker            torch.randn(4, device=device),
11478*da0073e9SAndroid Build Coastguard Worker            torch.randn(1, 3, device=device),
11479*da0073e9SAndroid Build Coastguard Worker        ]
11480*da0073e9SAndroid Build Coastguard Worker        msg = "weight tensor should be defined either for all 3 classes or no classes"
11481*da0073e9SAndroid Build Coastguard Worker        for weight in invalid_weights:
11482*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, msg):
11483*da0073e9SAndroid Build Coastguard Worker                F.nll_loss(x, t, weight=weight)
11484*da0073e9SAndroid Build Coastguard Worker
11485*da0073e9SAndroid Build Coastguard Worker    # Ref: https://github.com/pytorch/pytorch/issues/85005
11486*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
11487*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest("120GB", "cpu")
11488*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest("45GB", "cuda")
11489*da0073e9SAndroid Build Coastguard Worker    @parametrize_test("reduction", ("none", "mean", "sum"))
11490*da0073e9SAndroid Build Coastguard Worker    def test_nll_loss_large_tensor(self, device, reduction):
11491*da0073e9SAndroid Build Coastguard Worker        shape = [int(2 ** 16), int(2 ** 16) + 1]
11492*da0073e9SAndroid Build Coastguard Worker
11493*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(shape, device=device, dtype=torch.float32, requires_grad=True)
11494*da0073e9SAndroid Build Coastguard Worker        labels = torch.randint(shape[0], (shape[0],), dtype=torch.long, device=device)
11495*da0073e9SAndroid Build Coastguard Worker
11496*da0073e9SAndroid Build Coastguard Worker        out = F.nll_loss(input, labels, reduction=reduction)
11497*da0073e9SAndroid Build Coastguard Worker
11498*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
11499*da0073e9SAndroid Build Coastguard Worker            input_cpu = input.cpu().float().requires_grad_()
11500*da0073e9SAndroid Build Coastguard Worker            labels_cpu = labels.cpu()
11501*da0073e9SAndroid Build Coastguard Worker        out_cpu = F.nll_loss(input_cpu, labels_cpu, reduction=reduction)
11502*da0073e9SAndroid Build Coastguard Worker        # workaround to reduce memory usage vs. self.assertEqual, see #84944
11503*da0073e9SAndroid Build Coastguard Worker        rtol, atol = torch.testing._comparison.get_tolerances(torch.float32, rtol=None, atol=None)
11504*da0073e9SAndroid Build Coastguard Worker        # the "sum" reduction accumulates more rounding error in the forward
11505*da0073e9SAndroid Build Coastguard Worker        # result, so loosen the tolerances for that comparison only
11506*da0073e9SAndroid Build Coastguard Worker        fwd_rtol, fwd_atol = (7 * rtol, 3 * atol) if reduction == "sum" else (rtol, atol)
11507*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
11508*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.allclose(out.cpu(), out_cpu, rtol=fwd_rtol, atol=fwd_atol))
11511*da0073e9SAndroid Build Coastguard Worker
11512*da0073e9SAndroid Build Coastguard Worker        if reduction != "none":
11513*da0073e9SAndroid Build Coastguard Worker            out.backward()
11514*da0073e9SAndroid Build Coastguard Worker            out_cpu.backward()
11515*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
11516*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.allclose(input.grad.cpu(), input_cpu.grad, rtol=rtol, atol=atol))
11517*da0073e9SAndroid Build Coastguard Worker
11518*da0073e9SAndroid Build Coastguard Worker    # Ref: https://github.com/pytorch/pytorch/issues/108345
11519*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
11520*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest("20GB", "cpu")
11521*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest("20GB", "cuda")
11522*da0073e9SAndroid Build Coastguard Worker    @parametrize_test("reduction", ("none", "mean", "sum"))
11523*da0073e9SAndroid Build Coastguard Worker    def test_cross_entropy_64bit(self, device, reduction):
11524*da0073e9SAndroid Build Coastguard Worker        labels = torch.zeros(190, 50, dtype=torch.long, device=device)
11525*da0073e9SAndroid Build Coastguard Worker        logits = torch.ones(190, 229000, 50, dtype=torch.float, device=device)
11526*da0073e9SAndroid Build Coastguard Worker        loss = torch.nn.functional.cross_entropy(logits, labels)
11527*da0073e9SAndroid Build Coastguard Worker        loss_cpu = torch.nn.functional.cross_entropy(logits.cpu(), labels.cpu())
11528*da0073e9SAndroid Build Coastguard Worker        print(logits.numel(), labels.numel(), loss.numel())
11529*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(loss_cpu, loss.cpu(), rtol=1e-4, atol=1e-4))
11530*da0073e9SAndroid Build Coastguard Worker
11531*da0073e9SAndroid Build Coastguard Worker    def _nll_loss_helper(self, input_size, reduction, expected, device):
11532*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(input_size, requires_grad=True, device=device)
11533*da0073e9SAndroid Build Coastguard Worker        num_channels = input_size[1]
11534*da0073e9SAndroid Build Coastguard Worker        target_size = (input_size[0], ) + tuple(input_size[2:])
11535*da0073e9SAndroid Build Coastguard Worker        target = torch.randint(num_channels, target_size, device=device)
11536*da0073e9SAndroid Build Coastguard Worker
11537*da0073e9SAndroid Build Coastguard Worker        output = F.nll_loss(input, target, reduction=reduction)
11538*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output, expected, exact_dtype=False)
11539*da0073e9SAndroid Build Coastguard Worker
11540*da0073e9SAndroid Build Coastguard Worker        output.sum().backward()
11541*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(input.grad.size(), input.size())
11542*da0073e9SAndroid Build Coastguard Worker
11543*da0073e9SAndroid Build Coastguard Worker    def test_nll_loss_empty_tensor_reduction_none(self, device):
11544*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([0, 3], "none", torch.empty([0], device=device), device)
11545*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([0, 3, 5, 7], "none", torch.empty([0, 5, 7], device=device), device)
11546*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([2, 3, 0, 7], "none", torch.empty([2, 0, 7], device=device), device)
11547*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([2, 3, 5, 0], "none", torch.empty([2, 5, 0], device=device), device)
11548*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([2, 3, 5, 7, 0], "none", torch.empty([2, 5, 7, 0], device=device), device)
11549*da0073e9SAndroid Build Coastguard Worker
11550*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # RuntimeError: [srcBuf length] > 0 INTERNAL ASSERT FAILED https://github.com/pytorch/pytorch/issues/134431
11551*da0073e9SAndroid Build Coastguard Worker    def test_nll_loss_empty_tensor_reduction_mean(self, device):
11552*da0073e9SAndroid Build Coastguard Worker        nan = torch.tensor(float('nan'), device=device)
11553*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([0, 3], "mean", nan, device)
11554*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([0, 3, 5, 7], "mean", nan, device)
11555*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([2, 3, 0, 7], "mean", nan, device)
11556*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([2, 3, 5, 0], "mean", nan, device)
11557*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([2, 3, 5, 7, 0], "mean", nan, device)
11558*da0073e9SAndroid Build Coastguard Worker
11559*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # RuntimeError: [srcBuf length] > 0 INTERNAL ASSERT FAILED https://github.com/pytorch/pytorch/issues/134431
11560*da0073e9SAndroid Build Coastguard Worker    def test_nll_loss_empty_tensor_reduction_sum(self, device):
11561*da0073e9SAndroid Build Coastguard Worker        zero = torch.tensor(0, device=device)
11562*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([0, 3], "sum", zero, device)
11563*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([0, 3, 5, 7], "sum", zero, device)
11564*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([2, 3, 0, 7], "sum", zero, device)
11565*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([2, 3, 5, 0], "sum", zero, device)
11566*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([2, 3, 5, 7, 0], "sum", zero, device)
11567*da0073e9SAndroid Build Coastguard Worker
11568*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # AssertionError: Expected nan but got 0.0.
11569*da0073e9SAndroid Build Coastguard Worker    def test_nll_loss_total_weight_is_zero(self, device):
11570*da0073e9SAndroid Build Coastguard Worker
11571*da0073e9SAndroid Build Coastguard Worker        def helper(input_size):
11572*da0073e9SAndroid Build Coastguard Worker            input = torch.ones(input_size, requires_grad=True, device=device)
11573*da0073e9SAndroid Build Coastguard Worker            num_channels = input_size[1]
11574*da0073e9SAndroid Build Coastguard Worker            target_size = (input_size[0], ) + tuple(input_size[2:])
11575*da0073e9SAndroid Build Coastguard Worker            target = torch.zeros(target_size, dtype=torch.long, device=device)
11576*da0073e9SAndroid Build Coastguard Worker            weight = torch.zeros([num_channels], device=device)
11577*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(F.nll_loss(input, target, weight, reduction="sum").item(), 0.)
11578*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(F.nll_loss(input, target, weight, reduction="mean").item(), float("nan"))
11579*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(F.nll_loss(input, target, weight, reduction="none"), torch.zeros(target.shape, device=device))
11580*da0073e9SAndroid Build Coastguard Worker
11581*da0073e9SAndroid Build Coastguard Worker        helper([2, 3])
11582*da0073e9SAndroid Build Coastguard Worker        helper([2, 3, 5, 7])
11583*da0073e9SAndroid Build Coastguard Worker        helper([2, 3, 5, 7, 9])
11584*da0073e9SAndroid Build Coastguard Worker
11585*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # AssertionError: Expected nan but got 0.0.
11586*da0073e9SAndroid Build Coastguard Worker    def test_nll_loss_all_ignored(self, device):
11587*da0073e9SAndroid Build Coastguard Worker
11588*da0073e9SAndroid Build Coastguard Worker        def helper(input_size):
11589*da0073e9SAndroid Build Coastguard Worker            input = torch.ones(input_size, device=device)
11590*da0073e9SAndroid Build Coastguard Worker            num_channels = input_size[1]
11591*da0073e9SAndroid Build Coastguard Worker            target_size = (input_size[0], ) + tuple(input_size[2:])
11592*da0073e9SAndroid Build Coastguard Worker            target = torch.zeros(target_size, dtype=torch.long, device=device)
11593*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(F.nll_loss(input, target, ignore_index=0, reduction="sum").item(), 0)
11594*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(F.nll_loss(input, target, ignore_index=0, reduction="mean").item(), float("nan"))
11595*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(F.nll_loss(input, target, ignore_index=0, reduction="none"), torch.zeros(target.shape, device=device))
11596*da0073e9SAndroid Build Coastguard Worker
11597*da0073e9SAndroid Build Coastguard Worker        helper([2, 3])
11598*da0073e9SAndroid Build Coastguard Worker        helper([2, 3, 5, 7])
11599*da0073e9SAndroid Build Coastguard Worker        helper([2, 3, 5, 7, 9])
11600*da0073e9SAndroid Build Coastguard Worker
11601*da0073e9SAndroid Build Coastguard Worker    def test_nll_loss_byte_target_matches_long(self, device):
11602*da0073e9SAndroid Build Coastguard Worker        N, C = 10, 4
11603*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(N, C, device=device, requires_grad=True)
11604*da0073e9SAndroid Build Coastguard Worker        target = torch.empty(N, dtype=torch.long, device=device).random_(0, C)
11605*da0073e9SAndroid Build Coastguard Worker
11606*da0073e9SAndroid Build Coastguard Worker        def compute_result_and_gradient(reduction, target_dtype):
11607*da0073e9SAndroid Build Coastguard Worker            input_ = input.detach()
11608*da0073e9SAndroid Build Coastguard Worker            input_.requires_grad_()
11609*da0073e9SAndroid Build Coastguard Worker
11610*da0073e9SAndroid Build Coastguard Worker            prob = F.log_softmax(input_, dim=-1)
11611*da0073e9SAndroid Build Coastguard Worker            loss = nn.NLLLoss(reduction=reduction)
11612*da0073e9SAndroid Build Coastguard Worker            result = loss(prob, target.to(target_dtype))
11613*da0073e9SAndroid Build Coastguard Worker            result.sum().backward()
11614*da0073e9SAndroid Build Coastguard Worker
11615*da0073e9SAndroid Build Coastguard Worker            return result, input_.grad
11616*da0073e9SAndroid Build Coastguard Worker
11617*da0073e9SAndroid Build Coastguard Worker        for reduction in ["none", "mean", "sum"]:
11618*da0073e9SAndroid Build Coastguard Worker            result_long, grad_long = compute_result_and_gradient(reduction, torch.long)
11619*da0073e9SAndroid Build Coastguard Worker            result_byte, grad_byte = compute_result_and_gradient(reduction, torch.uint8)
11620*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result_long, result_byte)
11621*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(grad_long, grad_byte)
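
    # A minimal sketch (not from the original suite) of the decomposition the helper
    # above relies on: F.cross_entropy is log_softmax followed by nll_loss, so the
    # two-step computation matches the fused call.
    @staticmethod
    def _cross_entropy_decomposition_sketch(N=10, C=4):
        input = torch.randn(N, C)
        target = torch.randint(0, C, (N,))
        fused = F.cross_entropy(input, target)
        composed = F.nll_loss(F.log_softmax(input, dim=-1), target)
        return torch.allclose(fused, composed)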
11622*da0073e9SAndroid Build Coastguard Worker
11623*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
11624*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm
11625*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float16, torch.float32)
11626*da0073e9SAndroid Build Coastguard Worker    def test_cross_entropy_loss_2d_out_of_bounds_class_index(self, device, dtype):
11627*da0073e9SAndroid Build Coastguard Worker        # Test for issue #117532
11628*da0073e9SAndroid Build Coastguard Worker        # Run in a different process to prevent the device-side assert from affecting other tests
11629*da0073e9SAndroid Build Coastguard Worker        stderr = TestCase.runWithPytorchAPIUsageStderr(f"""\
11630*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3
11631*da0073e9SAndroid Build Coastguard Worker
11632*da0073e9SAndroid Build Coastguard Workerimport torch
11633*da0073e9SAndroid Build Coastguard Workerimport torch.nn.functional as F
11634*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (run_tests, TestCase)
11635*da0073e9SAndroid Build Coastguard Worker
11636*da0073e9SAndroid Build Coastguard Workerclass TestThatContainsCUDAAssert(TestCase):
11637*da0073e9SAndroid Build Coastguard Worker    def test_cross_entropy_loss_2d_out_of_bounds_class_index(self):
11638*da0073e9SAndroid Build Coastguard Worker        device = '{str(device)}'
11639*da0073e9SAndroid Build Coastguard Worker        dtype = {str(dtype).strip("'")}
11640*da0073e9SAndroid Build Coastguard Worker        ignore_index = 255
11641*da0073e9SAndroid Build Coastguard Worker        b = 10
11642*da0073e9SAndroid Build Coastguard Worker        n_classes = 3
11643*da0073e9SAndroid Build Coastguard Worker        w = 768
11644*da0073e9SAndroid Build Coastguard Worker        h = 1024
11645*da0073e9SAndroid Build Coastguard Worker        pred = torch.randn(b, n_classes, w, h, dtype=dtype, device=device)
11646*da0073e9SAndroid Build Coastguard Worker        labels = torch.zeros(b, w, h, dtype=torch.int64, device=device)
11648*da0073e9SAndroid Build Coastguard Worker        # Set an invalid (out-of-bounds) class index
11649*da0073e9SAndroid Build Coastguard Worker        labels[5, 200, 200] = 254
11650*da0073e9SAndroid Build Coastguard Worker
11651*da0073e9SAndroid Build Coastguard Worker        x = F.cross_entropy(
11652*da0073e9SAndroid Build Coastguard Worker            pred, labels, reduction="none", ignore_index=ignore_index
11653*da0073e9SAndroid Build Coastguard Worker        )
11654*da0073e9SAndroid Build Coastguard Worker        torch.cuda.synchronize()
11655*da0073e9SAndroid Build Coastguard Worker
11656*da0073e9SAndroid Build Coastguard Worker
11657*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__':
11658*da0073e9SAndroid Build Coastguard Worker    run_tests()
11659*da0073e9SAndroid Build Coastguard Worker        """)
11660*da0073e9SAndroid Build Coastguard Worker        self.assertIn('CUDA error: device-side assert triggered', stderr)
11661*da0073e9SAndroid Build Coastguard Worker
11664*da0073e9SAndroid Build Coastguard Worker    def test_cross_entropy_loss_prob_target_all_reductions(self, device):
11665*da0073e9SAndroid Build Coastguard Worker        # Test with k-dimensional loss.
11666*da0073e9SAndroid Build Coastguard Worker        for k in range(5):
11667*da0073e9SAndroid Build Coastguard Worker            N, C = 5, 4
11668*da0073e9SAndroid Build Coastguard Worker            other_dims = [torch.randint(2, 5, size=(1,)).item() for _ in range(k)]
11669*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(N, C, *other_dims, device=device, requires_grad=True)
11670*da0073e9SAndroid Build Coastguard Worker            target = torch.randn(N, C, *other_dims, device=device, requires_grad=True)
11671*da0073e9SAndroid Build Coastguard Worker            weight = torch.randn(C, device=device).abs()
11672*da0073e9SAndroid Build Coastguard Worker
11673*da0073e9SAndroid Build Coastguard Worker            for reduction, w in product(['none', 'mean', 'sum'], [None, weight]):
11674*da0073e9SAndroid Build Coastguard Worker                m = torch.nn.CrossEntropyLoss(weight=w, reduction=reduction)
11675*da0073e9SAndroid Build Coastguard Worker                output = m(input, target)
11676*da0073e9SAndroid Build Coastguard Worker                output_ref = loss_reference_fns['CrossEntropyLoss'](
11677*da0073e9SAndroid Build Coastguard Worker                    input, target, reduction=reduction, weight=w)
11678*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(output, output_ref)
11679*da0073e9SAndroid Build Coastguard Worker
11680*da0073e9SAndroid Build Coastguard Worker    def test_cross_entropy_loss_prob_target_unit_weights(self, device):
11681*da0073e9SAndroid Build Coastguard Worker        # Test with k-dimensional loss.
11682*da0073e9SAndroid Build Coastguard Worker        for k in range(5):
11683*da0073e9SAndroid Build Coastguard Worker            N, C = 5, 4
11684*da0073e9SAndroid Build Coastguard Worker            other_dims = [torch.randint(2, 5, size=(1,)).item() for _ in range(k)]
11685*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(N, C, *other_dims, device=device, requires_grad=True)
11686*da0073e9SAndroid Build Coastguard Worker            target = torch.randn(N, C, *other_dims, device=device, requires_grad=True)
11687*da0073e9SAndroid Build Coastguard Worker
11688*da0073e9SAndroid Build Coastguard Worker            for reduction in ['none', 'mean', 'sum']:
11689*da0073e9SAndroid Build Coastguard Worker                # Ensure result with unit weights is equivalent to result without weights.
11690*da0073e9SAndroid Build Coastguard Worker                m = torch.nn.CrossEntropyLoss(reduction=reduction)
11691*da0073e9SAndroid Build Coastguard Worker                unit_weight = torch.ones(C, device=device, dtype=target.dtype)
11692*da0073e9SAndroid Build Coastguard Worker                m_unit = torch.nn.CrossEntropyLoss(weight=unit_weight, reduction=reduction)
11693*da0073e9SAndroid Build Coastguard Worker                output = m(input, target)
11694*da0073e9SAndroid Build Coastguard Worker                output_unit = m_unit(input, target)
11695*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(output, output_unit)
11696*da0073e9SAndroid Build Coastguard Worker
11697*da0073e9SAndroid Build Coastguard Worker    @parametrize_test('reduction', ['none', 'mean', 'sum'])
11698*da0073e9SAndroid Build Coastguard Worker    @parametrize_test('weighted', [False, True])
11699*da0073e9SAndroid Build Coastguard Worker    def test_cross_entropy_loss_prob_target_no_batch_dim(self, device, reduction, weighted):
11700*da0073e9SAndroid Build Coastguard Worker        C = 5
11701*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(C, device=device).log_softmax(dim=-1)
11702*da0073e9SAndroid Build Coastguard Worker        target = torch.randn(C, device=device).softmax(dim=-1)
11703*da0073e9SAndroid Build Coastguard Worker        weight = torch.randn(C, device=device) if weighted else None
11704*da0073e9SAndroid Build Coastguard Worker        m = nn.CrossEntropyLoss(reduction=reduction, weight=weight)
11705*da0073e9SAndroid Build Coastguard Worker        loss_no_batch = m(input, target)
11706*da0073e9SAndroid Build Coastguard Worker        loss_batch = m(input.unsqueeze(0), target.unsqueeze(0))
11707*da0073e9SAndroid Build Coastguard Worker        if reduction == 'none':
11708*da0073e9SAndroid Build Coastguard Worker            loss_batch = loss_batch.squeeze(0)
11709*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(loss_no_batch, loss_batch)
11710*da0073e9SAndroid Build Coastguard Worker
11711*da0073e9SAndroid Build Coastguard Worker    def test_cross_entropy_loss_index_target_unit_weights(self, device):
11712*da0073e9SAndroid Build Coastguard Worker        # Test with k-dimensional loss.
11713*da0073e9SAndroid Build Coastguard Worker        for k in range(5):
11714*da0073e9SAndroid Build Coastguard Worker            N, C = 5, 4
11715*da0073e9SAndroid Build Coastguard Worker            other_dims = [torch.randint(2, 5, size=(1,)).item() for _ in range(k)]
11716*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(N, C, *other_dims, device=device, requires_grad=True)
11717*da0073e9SAndroid Build Coastguard Worker            target = torch.empty(N, *other_dims, dtype=torch.long, device=device).random_(0, C)
11718*da0073e9SAndroid Build Coastguard Worker
11719*da0073e9SAndroid Build Coastguard Worker            for reduction in ['none', 'mean', 'sum']:
11720*da0073e9SAndroid Build Coastguard Worker                # Ensure result with unit weights is equivalent to result without weights.
11721*da0073e9SAndroid Build Coastguard Worker                m = torch.nn.CrossEntropyLoss(reduction=reduction)
11722*da0073e9SAndroid Build Coastguard Worker                unit_weight = torch.ones(C, device=device, dtype=input.dtype)
11723*da0073e9SAndroid Build Coastguard Worker                m_unit = torch.nn.CrossEntropyLoss(weight=unit_weight, reduction=reduction)
11724*da0073e9SAndroid Build Coastguard Worker                output = m(input, target)
11725*da0073e9SAndroid Build Coastguard Worker                output_unit = m_unit(input, target)
11726*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(output, output_unit)
11727*da0073e9SAndroid Build Coastguard Worker
11728*da0073e9SAndroid Build Coastguard Worker    def test_cross_entropy_loss_one_hot_target(self, device):
11729*da0073e9SAndroid Build Coastguard Worker        # Test with k-dimensional loss.
11730*da0073e9SAndroid Build Coastguard Worker        for k in range(5):
11731*da0073e9SAndroid Build Coastguard Worker            N, C = 5, 4
11732*da0073e9SAndroid Build Coastguard Worker            other_dims = [torch.randint(2, 5, size=(1,)).item() for _ in range(k)]
11733*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(N, C, *other_dims, device=device, requires_grad=True)
11734*da0073e9SAndroid Build Coastguard Worker            target = torch.empty(N, *other_dims, dtype=torch.long, device=device).random_(0, C)
11735*da0073e9SAndroid Build Coastguard Worker            weight = torch.randn(C, device=device).abs()
11736*da0073e9SAndroid Build Coastguard Worker
11737*da0073e9SAndroid Build Coastguard Worker            # Get one-hot representation of the target.
11738*da0073e9SAndroid Build Coastguard Worker            target_one_hot = F.one_hot(target, num_classes=C).to(input.dtype)
11739*da0073e9SAndroid Build Coastguard Worker            # Need to put the C dim at index 1.
11740*da0073e9SAndroid Build Coastguard Worker            target_one_hot = target_one_hot.permute(0, -1, *range(1, target_one_hot.dim() - 1))
11741*da0073e9SAndroid Build Coastguard Worker
11742*da0073e9SAndroid Build Coastguard Worker            for reduction, w in product(['none', 'mean', 'sum'], [None, weight]):
11743*da0073e9SAndroid Build Coastguard Worker                # Skip this case for now because soft and hard label CE are not consistent
11744*da0073e9SAndroid Build Coastguard Worker                # in the way they apply class weights (see issue #61309).
11745*da0073e9SAndroid Build Coastguard Worker                if reduction == 'mean' and w is not None:
11746*da0073e9SAndroid Build Coastguard Worker                    continue
11747*da0073e9SAndroid Build Coastguard Worker
11748*da0073e9SAndroid Build Coastguard Worker                # Ensure loss computed with class indices matches loss
11749*da0073e9SAndroid Build Coastguard Worker                # computed with one-hot class probs.
11750*da0073e9SAndroid Build Coastguard Worker                m = torch.nn.CrossEntropyLoss(weight=w, reduction=reduction)
11751*da0073e9SAndroid Build Coastguard Worker                output = m(input, target)
11752*da0073e9SAndroid Build Coastguard Worker                output_one_hot = m(input, target_one_hot)
11753*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(output, output_one_hot)
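
    # A minimal sketch (shapes are illustrative) of the axis shuffle used above:
    # F.one_hot appends the class dimension last, while cross_entropy expects it at
    # dim 1, so the new trailing dimension is permuted into position 1.
    @staticmethod
    def _one_hot_channel_first_sketch(N=2, C=4, H=3, W=5):
        target = torch.randint(0, C, (N, H, W))
        one_hot = F.one_hot(target, num_classes=C)  # shape (N, H, W, C)
        channel_first = one_hot.permute(0, -1, *range(1, one_hot.dim() - 1))
        return channel_first.shape == (N, C, H, W)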
11754*da0073e9SAndroid Build Coastguard Worker
11755*da0073e9SAndroid Build Coastguard Worker    def test_cross_entropy_label_smoothing_errors(self, device):
11756*da0073e9SAndroid Build Coastguard Worker        N, C = 3, 4
11757*da0073e9SAndroid Build Coastguard Worker        input_args = [
11758*da0073e9SAndroid Build Coastguard Worker            (torch.randn((N, C), device=device), torch.arange(0, C, device=device)),
11759*da0073e9SAndroid Build Coastguard Worker            (torch.randn((N, C), device=device), torch.randn(N, C, device=device))
11760*da0073e9SAndroid Build Coastguard Worker        ]
11761*da0073e9SAndroid Build Coastguard Worker        for input_arg in input_args:
11762*da0073e9SAndroid Build Coastguard Worker            loss = nn.CrossEntropyLoss(label_smoothing=1.2)
11763*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError,
11764*da0073e9SAndroid Build Coastguard Worker                                        r"label_smoothing must be between 0\.0"):
11765*da0073e9SAndroid Build Coastguard Worker                loss(*input_arg)
11766*da0073e9SAndroid Build Coastguard Worker
11767*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
11768*da0073e9SAndroid Build Coastguard Worker    @set_default_dtype(torch.double)
11769*da0073e9SAndroid Build Coastguard Worker    def test_cross_entropy_label_smoothing_consistent_index_target_and_probs(self, device):
11770*da0073e9SAndroid Build Coastguard Worker        N, C = 10, 4
11771*da0073e9SAndroid Build Coastguard Worker        ks = range(5)
11772*da0073e9SAndroid Build Coastguard Worker        reductions = ['none', 'mean', 'sum']
11773*da0073e9SAndroid Build Coastguard Worker        label_smoothings = [0.05, 0.15]
11774*da0073e9SAndroid Build Coastguard Worker
11775*da0073e9SAndroid Build Coastguard Worker        for k, reduction, label_smoothing in product(ks, reductions, label_smoothings):
11776*da0073e9SAndroid Build Coastguard Worker            other_dims = [torch.randint(2, 5, size=(1,)).item() for _ in range(k)]
11777*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(N, C, *other_dims, device=device, requires_grad=True)
11778*da0073e9SAndroid Build Coastguard Worker            target = torch.empty(N, *other_dims, dtype=torch.long, device=device).random_(0, C)
11779*da0073e9SAndroid Build Coastguard Worker
11780*da0073e9SAndroid Build Coastguard Worker            # construct a target probability distribution that should give the same result as label_smoothing
11781*da0073e9SAndroid Build Coastguard Worker            target_proba = F.one_hot(target, num_classes=C)
11782*da0073e9SAndroid Build Coastguard Worker            # Need to put the C dim at index 1.
11783*da0073e9SAndroid Build Coastguard Worker            target_proba = target_proba.permute(0, -1, *range(1, target_proba.dim() - 1))
11784*da0073e9SAndroid Build Coastguard Worker            target_mask = (target_proba == 1)
11785*da0073e9SAndroid Build Coastguard Worker            target_proba = target_proba.to(dtype=input.dtype)
11786*da0073e9SAndroid Build Coastguard Worker
11787*da0073e9SAndroid Build Coastguard Worker            # y_k^ls = y_k * (1 - label_smoothing) + label_smoothing / n_classes
11788*da0073e9SAndroid Build Coastguard Worker            # Apply label smoothing to the one-hot target.
11789*da0073e9SAndroid Build Coastguard Worker            target_proba.masked_fill_(target_mask, 1 - label_smoothing + label_smoothing / C)
11790*da0073e9SAndroid Build Coastguard Worker            target_proba.masked_fill_(~target_mask, label_smoothing / C)
11791*da0073e9SAndroid Build Coastguard Worker
11792*da0073e9SAndroid Build Coastguard Worker            loss = nn.CrossEntropyLoss(reduction=reduction)
11793*da0073e9SAndroid Build Coastguard Worker            output_with_prob = loss(input, target_proba)
11794*da0073e9SAndroid Build Coastguard Worker
11795*da0073e9SAndroid Build Coastguard Worker            loss = nn.CrossEntropyLoss(
11796*da0073e9SAndroid Build Coastguard Worker                reduction=reduction, label_smoothing=label_smoothing)
11797*da0073e9SAndroid Build Coastguard Worker            output_with_index = loss(input, target)
11798*da0073e9SAndroid Build Coastguard Worker
11799*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output_with_prob, output_with_index,
11800*da0073e9SAndroid Build Coastguard Worker                             rtol=1e-07, atol=1e-05)
11801*da0073e9SAndroid Build Coastguard Worker
11802*da0073e9SAndroid Build Coastguard Worker    def test_cross_entropy_label_smoothing_with_probs(self, device):
11803*da0073e9SAndroid Build Coastguard Worker        N, C = 10, 4
11804*da0073e9SAndroid Build Coastguard Worker        ks = range(5)
11805*da0073e9SAndroid Build Coastguard Worker        reductions = ['none', 'mean', 'sum']
11806*da0073e9SAndroid Build Coastguard Worker        label_smoothings = [0.05, 0.15]
11807*da0073e9SAndroid Build Coastguard Worker
11808*da0073e9SAndroid Build Coastguard Worker        # Test with k-dimensional loss.
11809*da0073e9SAndroid Build Coastguard Worker        for k, label_smoothing in product(ks, label_smoothings):
11810*da0073e9SAndroid Build Coastguard Worker            other_dims = [torch.randint(2, 5, size=(1,)).item() for _ in range(k)]
11811*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(N, C, *other_dims, device=device, requires_grad=True)
11812*da0073e9SAndroid Build Coastguard Worker            target = F.log_softmax(torch.randn(N, C, *other_dims, device=device), dim=1)
11813*da0073e9SAndroid Build Coastguard Worker
11814*da0073e9SAndroid Build Coastguard Worker            for reduction in reductions:
11815*da0073e9SAndroid Build Coastguard Worker                # use with label_smoothing
11816*da0073e9SAndroid Build Coastguard Worker                loss = nn.CrossEntropyLoss(reduction=reduction, label_smoothing=label_smoothing)
11817*da0073e9SAndroid Build Coastguard Worker                output_with_smoothing = loss(input, target)
11818*da0073e9SAndroid Build Coastguard Worker
11819*da0073e9SAndroid Build Coastguard Worker                # manually smoothing target
11820*da0073e9SAndroid Build Coastguard Worker                # class_proba^ls = class_proba * (1 - label_smoothing) +
11821*da0073e9SAndroid Build Coastguard Worker                #                  label_smoothing / n_classes
11822*da0073e9SAndroid Build Coastguard Worker                target_with_smoothing = target * (1 - label_smoothing) + label_smoothing / C
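                # With probability targets, label_smoothing mixes the target with a uniform
                # distribution in exactly this way, so both paths compute cross entropy
                # against the same smoothed distribution and should agree.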
11823*da0073e9SAndroid Build Coastguard Worker                loss = nn.CrossEntropyLoss(reduction=reduction)
11824*da0073e9SAndroid Build Coastguard Worker                output_with_manual_smoothing = loss(input, target_with_smoothing)
11825*da0073e9SAndroid Build Coastguard Worker
11826*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(output_with_smoothing, output_with_manual_smoothing)
11827*da0073e9SAndroid Build Coastguard Worker
11828*da0073e9SAndroid Build Coastguard Worker
11829*da0073e9SAndroid Build Coastguard Worker    def test_cross_entropy_label_smoothing_weight_ignore_indices(self, device):
11830*da0073e9SAndroid Build Coastguard Worker        reductions = ['none', 'sum', 'mean']
11831*da0073e9SAndroid Build Coastguard Worker        label_smoothings = [0.05, 0.15]
11832*da0073e9SAndroid Build Coastguard Worker
11833*da0073e9SAndroid Build Coastguard Worker        wgt = torch.tensor([0.3, 0.6], device=device)
11834*da0073e9SAndroid Build Coastguard Worker        inp1 = torch.tensor([[0.3, 0.4], [1, 2]], device=device)
11835*da0073e9SAndroid Build Coastguard Worker        inp2 = torch.tensor([[0.3, 0.6], [1, 2]], device=device)
11836*da0073e9SAndroid Build Coastguard Worker
11837*da0073e9SAndroid Build Coastguard Worker        targ_default_ignore_index = torch.tensor([-100, 1], device=device)
11838*da0073e9SAndroid Build Coastguard Worker        targ_negative_ignore_index = torch.tensor([-2, 1], device=device)
11839*da0073e9SAndroid Build Coastguard Worker        targ_positive_ignore_index = torch.tensor([2, 1], device=device)
11840*da0073e9SAndroid Build Coastguard Worker
11841*da0073e9SAndroid Build Coastguard Worker        for reduction, label_smoothing, weight in product(reductions, label_smoothings, (None, wgt)):
11842*da0073e9SAndroid Build Coastguard Worker            def check_equal(loss, inp_targ_1, inp_targ_2):
11843*da0073e9SAndroid Build Coastguard Worker                inp1, targ1 = inp_targ_1
11844*da0073e9SAndroid Build Coastguard Worker                inp2, targ2 = inp_targ_2
11845*da0073e9SAndroid Build Coastguard Worker                l1 = loss(inp1, targ1)
11846*da0073e9SAndroid Build Coastguard Worker                l2 = loss(inp2, targ2)
11847*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(l1, l2)
11848*da0073e9SAndroid Build Coastguard Worker
11849*da0073e9SAndroid Build Coastguard Worker            # Default ignore_index
11850*da0073e9SAndroid Build Coastguard Worker            loss = nn.CrossEntropyLoss(reduction=reduction,
11851*da0073e9SAndroid Build Coastguard Worker                                       label_smoothing=label_smoothing,
11852*da0073e9SAndroid Build Coastguard Worker                                       weight=weight)
11853*da0073e9SAndroid Build Coastguard Worker            check_equal(loss, (inp1, targ_default_ignore_index), (inp2, targ_default_ignore_index))
11854*da0073e9SAndroid Build Coastguard Worker            if reduction != 'none':
11855*da0073e9SAndroid Build Coastguard Worker                # Check that we correctly tally the denominator for `mean`
11856*da0073e9SAndroid Build Coastguard Worker                # i.e. we don't count the ignored_idx at all.
11857*da0073e9SAndroid Build Coastguard Worker                check_equal(loss, (inp1, targ_default_ignore_index), (inp2[1:], targ_default_ignore_index[1:]))
11858*da0073e9SAndroid Build Coastguard Worker
11859*da0073e9SAndroid Build Coastguard Worker            # negative ignore_index
11860*da0073e9SAndroid Build Coastguard Worker            loss = nn.CrossEntropyLoss(reduction=reduction,
11861*da0073e9SAndroid Build Coastguard Worker                                       label_smoothing=label_smoothing,
11862*da0073e9SAndroid Build Coastguard Worker                                       ignore_index=-2,
11863*da0073e9SAndroid Build Coastguard Worker                                       weight=weight)
11864*da0073e9SAndroid Build Coastguard Worker            check_equal(loss, (inp1, targ_negative_ignore_index), (inp2, targ_negative_ignore_index))
11865*da0073e9SAndroid Build Coastguard Worker            if reduction != 'none':
11866*da0073e9SAndroid Build Coastguard Worker                # Check that we correctly tally the denominator for `mean`
11867*da0073e9SAndroid Build Coastguard Worker                # i.e. we don't count the ignored_idx at all.
11868*da0073e9SAndroid Build Coastguard Worker                check_equal(loss, (inp1, targ_negative_ignore_index), (inp2[1:], targ_negative_ignore_index[1:]))
11869*da0073e9SAndroid Build Coastguard Worker
11870*da0073e9SAndroid Build Coastguard Worker            # positive ignore_index
11871*da0073e9SAndroid Build Coastguard Worker            loss = nn.CrossEntropyLoss(reduction=reduction,
11872*da0073e9SAndroid Build Coastguard Worker                                       label_smoothing=label_smoothing,
11873*da0073e9SAndroid Build Coastguard Worker                                       ignore_index=2,
11874*da0073e9SAndroid Build Coastguard Worker                                       weight=weight)
11875*da0073e9SAndroid Build Coastguard Worker            check_equal(loss, (inp1, targ_positive_ignore_index), (inp2, targ_positive_ignore_index))
11876*da0073e9SAndroid Build Coastguard Worker            if reduction != 'none':
11877*da0073e9SAndroid Build Coastguard Worker                # Check that we correctly tally the denominator for `mean`
11878*da0073e9SAndroid Build Coastguard Worker                # i.e. we don't count the ignored_idx at all.
11879*da0073e9SAndroid Build Coastguard Worker                check_equal(loss, (inp1, targ_positive_ignore_index), (inp2[1:], targ_positive_ignore_index[1:]))
11880*da0073e9SAndroid Build Coastguard Worker
11881*da0073e9SAndroid Build Coastguard Worker    # Ref: https://github.com/pytorch/pytorch/issues/85005
11882*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
11883*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest("45GB", "cpu")
11884*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest("70GB", "cuda")
11885*da0073e9SAndroid Build Coastguard Worker    @parametrize_test("reduction", ("none", "mean", "sum"))
11886*da0073e9SAndroid Build Coastguard Worker    def test_cross_entropy_large_tensor(self, device, reduction):
11887*da0073e9SAndroid Build Coastguard Worker        logits = torch.randn(int(2 ** 16), int(2 ** 16) + 1, dtype=torch.float32, device='cuda', requires_grad=True)
11888*da0073e9SAndroid Build Coastguard Worker        labels = torch.zeros(logits.size(0), dtype=torch.long, device='cuda')
11889*da0073e9SAndroid Build Coastguard Worker        loss = F.cross_entropy(logits, labels, reduction=reduction)
11890*da0073e9SAndroid Build Coastguard Worker        if reduction != "none":
11891*da0073e9SAndroid Build Coastguard Worker            loss.backward()
11892*da0073e9SAndroid Build Coastguard Worker
11893*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
11894*da0073e9SAndroid Build Coastguard Worker            logits_cpu = logits.cpu().detach().requires_grad_()
11895*da0073e9SAndroid Build Coastguard Worker            labels_cpu = labels.cpu().detach()
11896*da0073e9SAndroid Build Coastguard Worker        loss_cpu = F.cross_entropy(logits_cpu, labels_cpu, reduction=reduction)
11897*da0073e9SAndroid Build Coastguard Worker        if reduction != "none":
11898*da0073e9SAndroid Build Coastguard Worker            loss_cpu.backward()
11899*da0073e9SAndroid Build Coastguard Worker
11900*da0073e9SAndroid Build Coastguard Worker        # workaround to reduce memory usage vs. self.assertEqual, see #84944
11901*da0073e9SAndroid Build Coastguard Worker        rtol, atol = torch.testing._comparison.get_tolerances(torch.float32, rtol=None, atol=None)
11902*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(loss.cpu(), loss_cpu, rtol=rtol, atol=atol))
11903*da0073e9SAndroid Build Coastguard Worker        if reduction != "none":
11904*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.allclose(logits.grad.cpu(), logits_cpu.grad, rtol=rtol, atol=atol))
11905*da0073e9SAndroid Build Coastguard Worker
11906*da0073e9SAndroid Build Coastguard Worker    def test_smoothl1loss_backward_zero_beta(self, device):
11907*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(300, 256, requires_grad=True, device=device)
11908*da0073e9SAndroid Build Coastguard Worker        target = input.detach()
11909*da0073e9SAndroid Build Coastguard Worker
11910*da0073e9SAndroid Build Coastguard Worker        loss = F.smooth_l1_loss(input, target, beta=0.0, reduction='sum')
11911*da0073e9SAndroid Build Coastguard Worker        loss.backward()
11912*da0073e9SAndroid Build Coastguard Worker
11913*da0073e9SAndroid Build Coastguard Worker        grad_max_abs = input.grad.abs().max().item()
11914*da0073e9SAndroid Build Coastguard Worker        self.assertLessEqual(grad_max_abs, 1.0)
11915*da0073e9SAndroid Build Coastguard Worker
11916*da0073e9SAndroid Build Coastguard Worker    def test_softshrink_negative(self, device):
11917*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(5, device=device, requires_grad=True)
11918*da0073e9SAndroid Build Coastguard Worker        m = torch.nn.Softshrink(-1)
11919*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError,
11920*da0073e9SAndroid Build Coastguard Worker                                    r'lambda must be greater or equal to 0, but found to be -1\.'):
11921*da0073e9SAndroid Build Coastguard Worker            m(input)
11922*da0073e9SAndroid Build Coastguard Worker
11923*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
11924*da0073e9SAndroid Build Coastguard Worker    def test_fold(self, device):
11925*da0073e9SAndroid Build Coastguard Worker        def test_dtype(fn, input, dtype):
11926*da0073e9SAndroid Build Coastguard Worker            input = input.detach().clone().to(dtype=dtype).requires_grad_(True)
11927*da0073e9SAndroid Build Coastguard Worker            input2 = input.detach().clone().float().requires_grad_(True)
11928*da0073e9SAndroid Build Coastguard Worker            out = fn(input)
11929*da0073e9SAndroid Build Coastguard Worker            out.sum().backward()
11930*da0073e9SAndroid Build Coastguard Worker            out2 = fn(input2)
11931*da0073e9SAndroid Build Coastguard Worker            out2.sum().backward()
11932*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out.dtype, dtype)
11933*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input.grad.dtype, dtype)
11934*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, out2.to(dtype=dtype), atol=0.05, rtol=0)
11935*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input.grad, input2.grad.to(dtype=dtype))
11936*da0073e9SAndroid Build Coastguard Worker
11937*da0073e9SAndroid Build Coastguard Worker        def func(x):
11938*da0073e9SAndroid Build Coastguard Worker            return F.fold(x, output_size=(4, 5), kernel_size=(2, 2))
11939*da0073e9SAndroid Build Coastguard Worker
11940*da0073e9SAndroid Build Coastguard Worker        seeds = (44, 83, 71, 25, 999)
11941*da0073e9SAndroid Build Coastguard Worker        for sd in seeds:
11942*da0073e9SAndroid Build Coastguard Worker            torch.manual_seed(sd)
11943*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(1, 12, 12, device=device, requires_grad=True, dtype=torch.double)
11944*da0073e9SAndroid Build Coastguard Worker            gradcheck(func, [x], check_forward_ad=True)
11945*da0073e9SAndroid Build Coastguard Worker            gradgradcheck(func, [x], check_fwd_over_rev=True)
11946*da0073e9SAndroid Build Coastguard Worker            if device == 'cpu':
11947*da0073e9SAndroid Build Coastguard Worker                test_dtype(func, x, torch.bfloat16)
11948*da0073e9SAndroid Build Coastguard Worker
11949*da0073e9SAndroid Build Coastguard Worker
11950*da0073e9SAndroid Build Coastguard Worker    def test_logsigmoid_out(self, device):
11951*da0073e9SAndroid Build Coastguard Worker        # this isn't actually documented, but was broken previously:
11952*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/36499
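        # An `out=` tensor of the wrong size, or a non-contiguous one, should still
        # produce the same result as the plain call.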
11953*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, device=device).t()
11954*da0073e9SAndroid Build Coastguard Worker        empty_out = torch.randn(0, device=device)
11955*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(F.logsigmoid(x), F.logsigmoid(x, out=empty_out))
11956*da0073e9SAndroid Build Coastguard Worker
11957*da0073e9SAndroid Build Coastguard Worker        noncontig_out = torch.randn(2, 3, device=device).t()
11958*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(F.logsigmoid(x), F.logsigmoid(x, out=noncontig_out))
11959*da0073e9SAndroid Build Coastguard Worker
11960*da0073e9SAndroid Build Coastguard Worker    # Check that clip_grad_norm_ raises an error if the total norm of the
11961*da0073e9SAndroid Build Coastguard Worker    # parameters' gradients is non-finite
11962*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
11963*da0073e9SAndroid Build Coastguard Worker    def test_clip_grad_norm_error_if_nonfinite(self, device):
11964*da0073e9SAndroid Build Coastguard Worker        norms_pos = [0.1, 1, 2, 3.5, inf]
11965*da0073e9SAndroid Build Coastguard Worker        norms_neg = [-0.1, -1, -2, -3.5]
11966*da0073e9SAndroid Build Coastguard Worker        norms_except_0 = norms_pos + norms_neg
11967*da0073e9SAndroid Build Coastguard Worker        norms_all = norms_except_0 + [0]
11968*da0073e9SAndroid Build Coastguard Worker
11969*da0073e9SAndroid Build Coastguard Worker        # Each entry in test_cases has the following values, in this order:
11970*da0073e9SAndroid Build Coastguard Worker        #
11971*da0073e9SAndroid Build Coastguard Worker        # grad_only_one_elem    If True, only one element of the parameter's
11972*da0073e9SAndroid Build Coastguard Worker        #                       gradient is set to the scalar grad, and the
11973*da0073e9SAndroid Build Coastguard Worker        #                       rest of the elements are 0. If False, all grad
11974*da0073e9SAndroid Build Coastguard Worker        #                       elements are equal to the scalar.
11975*da0073e9SAndroid Build Coastguard Worker        #
11976*da0073e9SAndroid Build Coastguard Worker        # prefix_finite_grad_param  If True, prefix a parameter that has a grad
11977*da0073e9SAndroid Build Coastguard Worker        #                           of 1.
11978*da0073e9SAndroid Build Coastguard Worker        #
11979*da0073e9SAndroid Build Coastguard Worker        # scalars           Scalars to use as the parameter's grad, through
11980*da0073e9SAndroid Build Coastguard Worker        #                   multiplication
11981*da0073e9SAndroid Build Coastguard Worker        #
11982*da0073e9SAndroid Build Coastguard Worker        # norms_nonfinite   Norm types that should produce nonfinite total norm
11983*da0073e9SAndroid Build Coastguard Worker        #
11984*da0073e9SAndroid Build Coastguard Worker        # norms_finite      Norm types that should produce finite total norm
11985*da0073e9SAndroid Build Coastguard Worker        test_cases = [
11986*da0073e9SAndroid Build Coastguard Worker            # Test errors from an infinite grad
11987*da0073e9SAndroid Build Coastguard Worker            (False, False, [inf, -inf], norms_except_0, [0]),
11988*da0073e9SAndroid Build Coastguard Worker            (False, True, [inf, -inf], norms_pos, norms_neg + [0]),
11989*da0073e9SAndroid Build Coastguard Worker            (True, False, [inf, -inf], norms_pos, norms_neg + [0]),
11990*da0073e9SAndroid Build Coastguard Worker            (True, True, [inf, -inf], norms_pos, norms_neg + [0]),
11991*da0073e9SAndroid Build Coastguard Worker
11992*da0073e9SAndroid Build Coastguard Worker            # Test errors from a NaN grad
11993*da0073e9SAndroid Build Coastguard Worker            (False, False, [nan], norms_except_0, [0]),
11994*da0073e9SAndroid Build Coastguard Worker            (False, True, [nan], norms_except_0, [0]),
11995*da0073e9SAndroid Build Coastguard Worker            (True, False, [nan], norms_except_0, [0]),
11996*da0073e9SAndroid Build Coastguard Worker            (True, True, [nan], norms_except_0, [0]),
11997*da0073e9SAndroid Build Coastguard Worker
11998*da0073e9SAndroid Build Coastguard Worker            # Test a grad that should never error
11999*da0073e9SAndroid Build Coastguard Worker            (False, False, [2e22, -2e22], [], norms_all),
12000*da0073e9SAndroid Build Coastguard Worker            (False, True, [2e22, -2e22], [], norms_all),
12001*da0073e9SAndroid Build Coastguard Worker            (True, False, [2e22, -2e22], [], norms_all),
12002*da0073e9SAndroid Build Coastguard Worker            (True, True, [2e22, -2e22], [], norms_all),
12003*da0073e9SAndroid Build Coastguard Worker
12004*da0073e9SAndroid Build Coastguard Worker            # Test a grad that will overflow to inf for only some norm orders ((2e200)**2 exceeds float64's max of ~1.8e308, but e.g. the 1-norm and inf-norm stay finite)
12005*da0073e9SAndroid Build Coastguard Worker            (False, False, [2e200, -2e200], [3.5, 2, -2, -3.5], [inf, 1, 0.1, 0, -1, -0.1]),
12006*da0073e9SAndroid Build Coastguard Worker            (False, True, [2e200, -2e200], [3.5, 2], norms_neg + [inf, 1, 0.1, 0]),
12007*da0073e9SAndroid Build Coastguard Worker            (True, False, [2e200, -2e200], [3.5, 2], norms_neg + [inf, 1, 0.1, 0]),
12008*da0073e9SAndroid Build Coastguard Worker            (True, True, [2e200, -2e200], [3.5, 2], norms_neg + [inf, 1, 0.1, 0]),
12009*da0073e9SAndroid Build Coastguard Worker        ]
12010*da0073e9SAndroid Build Coastguard Worker
12011*da0073e9SAndroid Build Coastguard Worker        def gen_parameters(scalar, grad_only_one_elem, prefix_finite_grad_param):
12012*da0073e9SAndroid Build Coastguard Worker            param = torch.ones(10, dtype=torch.float64, device=device, requires_grad=True)
12013*da0073e9SAndroid Build Coastguard Worker
12014*da0073e9SAndroid Build Coastguard Worker            if grad_only_one_elem:
12015*da0073e9SAndroid Build Coastguard Worker                param[1].mul(scalar).sum().backward()
12016*da0073e9SAndroid Build Coastguard Worker            else:
12017*da0073e9SAndroid Build Coastguard Worker                param.mul(scalar).sum().backward()
12018*da0073e9SAndroid Build Coastguard Worker
12019*da0073e9SAndroid Build Coastguard Worker            if prefix_finite_grad_param:
12020*da0073e9SAndroid Build Coastguard Worker                prefix_param = torch.ones(1, dtype=torch.float64, device=device, requires_grad=True)
12021*da0073e9SAndroid Build Coastguard Worker                prefix_param.mul(1).sum().backward()
12022*da0073e9SAndroid Build Coastguard Worker                parameters = [prefix_param, param]
12023*da0073e9SAndroid Build Coastguard Worker            else:
12024*da0073e9SAndroid Build Coastguard Worker                parameters = [param]
12025*da0073e9SAndroid Build Coastguard Worker
12026*da0073e9SAndroid Build Coastguard Worker            return parameters
12027*da0073e9SAndroid Build Coastguard Worker
12028*da0073e9SAndroid Build Coastguard Worker        def run_test_case(norm_type, error_if_nonfinite, scalar, grad_only_one_elem, prefix_finite_grad_param, is_norm_nonfinite):
12029*da0073e9SAndroid Build Coastguard Worker            msg = (
12030*da0073e9SAndroid Build Coastguard Worker                f'norm_type: {norm_type}, '
12031*da0073e9SAndroid Build Coastguard Worker                f'error_if_nonfinite: {error_if_nonfinite}, '
12032*da0073e9SAndroid Build Coastguard Worker                f'scalar: {scalar}, '
12033*da0073e9SAndroid Build Coastguard Worker                f'grad_only_one_elem: {grad_only_one_elem}, '
12034*da0073e9SAndroid Build Coastguard Worker                f'prefix_finite_grad_param: {prefix_finite_grad_param}, '
12035*da0073e9SAndroid Build Coastguard Worker                f'is_norm_nonfinite: {is_norm_nonfinite}')
12036*da0073e9SAndroid Build Coastguard Worker
12037*da0073e9SAndroid Build Coastguard Worker            parameters = gen_parameters(scalar, grad_only_one_elem, prefix_finite_grad_param)
12038*da0073e9SAndroid Build Coastguard Worker
12039*da0073e9SAndroid Build Coastguard Worker            # Should only throw an error if the total norm is expected to be
12040*da0073e9SAndroid Build Coastguard Worker            # nonfinite and `error_if_nonfinite=True`
12041*da0073e9SAndroid Build Coastguard Worker            if is_norm_nonfinite and error_if_nonfinite:
12042*da0073e9SAndroid Build Coastguard Worker                error_msg = f'The total norm of order {float(norm_type)} for gradients'
12043*da0073e9SAndroid Build Coastguard Worker
12044*da0073e9SAndroid Build Coastguard Worker                grads_before = [p.grad.clone() for p in parameters]
12045*da0073e9SAndroid Build Coastguard Worker
12046*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, error_msg, msg=msg):
12047*da0073e9SAndroid Build Coastguard Worker                    clip_grad_norm_(parameters, 1, norm_type=norm_type, error_if_nonfinite=True)
12048*da0073e9SAndroid Build Coastguard Worker
12049*da0073e9SAndroid Build Coastguard Worker                # Grad should not change if error is thrown
12050*da0073e9SAndroid Build Coastguard Worker                grads_after = [p.grad for p in parameters]
12051*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(grads_before, grads_after, msg=msg)
12052*da0073e9SAndroid Build Coastguard Worker            else:
12053*da0073e9SAndroid Build Coastguard Worker                clip_grad_norm_(parameters, 1, norm_type=norm_type, error_if_nonfinite=error_if_nonfinite)
12054*da0073e9SAndroid Build Coastguard Worker
12055*da0073e9SAndroid Build Coastguard Worker        for grad_only_one_elem, prefix_finite_grad_param, scalars, norms_nonfinite, norms_finite in test_cases:
12056*da0073e9SAndroid Build Coastguard Worker            for error_if_nonfinite in [False, True]:
12057*da0073e9SAndroid Build Coastguard Worker                for norm_type, scalar in product(norms_nonfinite, scalars):
12058*da0073e9SAndroid Build Coastguard Worker                    run_test_case(norm_type, error_if_nonfinite, scalar, grad_only_one_elem, prefix_finite_grad_param, True)
12059*da0073e9SAndroid Build Coastguard Worker
12060*da0073e9SAndroid Build Coastguard Worker                for norm_type, scalar in product(norms_finite, scalars):
12061*da0073e9SAndroid Build Coastguard Worker                    run_test_case(norm_type, error_if_nonfinite, scalar, grad_only_one_elem, prefix_finite_grad_param, False)
12062*da0073e9SAndroid Build Coastguard Worker
12063*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
12064*da0073e9SAndroid Build Coastguard Worker    @deviceCountAtLeast(2)
12065*da0073e9SAndroid Build Coastguard Worker    @parametrize_test('foreach', (False, True))
12066*da0073e9SAndroid Build Coastguard Worker    def test_clip_grad_norm_multi_device(self, devices, foreach):
12067*da0073e9SAndroid Build Coastguard Worker        class TestModel(nn.Module):
12068*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
12069*da0073e9SAndroid Build Coastguard Worker                super().__init__()
12070*da0073e9SAndroid Build Coastguard Worker                self.layer1 = nn.Linear(10, 10)
12071*da0073e9SAndroid Build Coastguard Worker                self.layer2 = nn.Linear(10, 10)
12072*da0073e9SAndroid Build Coastguard Worker
12073*da0073e9SAndroid Build Coastguard Worker        test_model = TestModel()
12074*da0073e9SAndroid Build Coastguard Worker        test_model.layer1.to(devices[0])
12075*da0073e9SAndroid Build Coastguard Worker        test_model.layer2.to(devices[1])
12076*da0073e9SAndroid Build Coastguard Worker        ref_model = TestModel().to(devices[0])
12077*da0073e9SAndroid Build Coastguard Worker        for norm_type in [2., math.inf]:
12078*da0073e9SAndroid Build Coastguard Worker            for p in test_model.parameters():
12079*da0073e9SAndroid Build Coastguard Worker                p.grad = torch.ones_like(p)
12080*da0073e9SAndroid Build Coastguard Worker            for p in ref_model.parameters():
12081*da0073e9SAndroid Build Coastguard Worker                p.grad = torch.ones_like(p)
12082*da0073e9SAndroid Build Coastguard Worker            norm = clip_grad_norm_(test_model.parameters(), 0.5, norm_type=norm_type, foreach=foreach)
12083*da0073e9SAndroid Build Coastguard Worker            expected = clip_grad_norm_(ref_model.parameters(), 0.5, norm_type=norm_type, foreach=foreach)
12084*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(norm, expected)
12085*da0073e9SAndroid Build Coastguard Worker            for p, pe in zip(test_model.parameters(), ref_model.parameters()):
12086*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(p.grad.to(devices[0]), pe.grad)
12087*da0073e9SAndroid Build Coastguard Worker
12088*da0073e9SAndroid Build Coastguard Worker    def test_elu_inplace_overlap(self, device):
12089*da0073e9SAndroid Build Coastguard Worker        dtype = torch.bfloat16 if device != 'mps:0' else torch.float16
12090*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((1, 6), dtype=dtype, device=device).expand((6, 6))
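        # expand() makes every row a view of the same underlying storage, so an
        # in-place op would read elements it has already overwritten; such internal
        # overlap is expected to be rejected.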
12091*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
12092*da0073e9SAndroid Build Coastguard Worker            F.elu(x, inplace=True)
12093*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
12094*da0073e9SAndroid Build Coastguard Worker            F.elu_(x)
12095*da0073e9SAndroid Build Coastguard Worker
12096*da0073e9SAndroid Build Coastguard Worker    # Merge into OpInfo?
12097*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
12098*da0073e9SAndroid Build Coastguard Worker    def test_elu_inplace_with_neg_alpha(self, device):
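        # With a negative alpha the output alone does not determine which branch an
        # element went through, so the in-place variant cannot compute its backward
        # and is expected to ask for the out-of-place version.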
12099*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([-1., 1.], device=device, requires_grad=True)
12100*da0073e9SAndroid Build Coastguard Worker        b = torch.nn.functional.elu_(a.clone(), alpha=-2)
12101*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "call out-of-place version"):
12102*da0073e9SAndroid Build Coastguard Worker            b.backward(torch.ones(2, device=device))
12103*da0073e9SAndroid Build Coastguard Worker
12104*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([-1., 1.], device=device, requires_grad=True)
12105*da0073e9SAndroid Build Coastguard Worker        b = torch.nn.functional.celu_(a.clone(), alpha=-2)
12106*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "call out-of-place version"):
12107*da0073e9SAndroid Build Coastguard Worker            b.backward(torch.ones(2, device=device))
12108*da0073e9SAndroid Build Coastguard Worker
12109*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMeta  # https://github.com/pytorch/pytorch/issues/54897
12110*da0073e9SAndroid Build Coastguard Worker    def test_hardswish_inplace_overlap(self, device):
12111*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((1, 6), device=device).expand((6, 6))
12112*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
12113*da0073e9SAndroid Build Coastguard Worker            F.hardswish(x, inplace=True)
12114*da0073e9SAndroid Build Coastguard Worker
12115*da0073e9SAndroid Build Coastguard Worker    def test_silu_inplace_overlap(self, device):
12116*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((1, 6), device=device).expand((6, 6))
12117*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
12118*da0073e9SAndroid Build Coastguard Worker            F.silu(x, inplace=True)
12119*da0073e9SAndroid Build Coastguard Worker
12120*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
12121*da0073e9SAndroid Build Coastguard Worker    def test_mish_inplace_overlap(self, device):
12122*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((1, 6), device=device).expand((6, 6))
12123*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
12124*da0073e9SAndroid Build Coastguard Worker            F.mish(x, inplace=True)
12125*da0073e9SAndroid Build Coastguard Worker
12126*da0073e9SAndroid Build Coastguard Worker    def test_softplus_inplace_overlap(self, device):
12127*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((1, 6), device=device).expand((6, 6))
12128*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
12129*da0073e9SAndroid Build Coastguard Worker            F.softplus(x, out=x)
12130*da0073e9SAndroid Build Coastguard Worker
12131*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # TypeError: the MPS framework doesn't support float64
12132*da0073e9SAndroid Build Coastguard Worker    def test_softplus_low_threshold(self, device):
12133*da0073e9SAndroid Build Coastguard Worker        # Ensure gradients are computed correctly with a low threshold.
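        # With threshold=1 the switch from the softplus formula to the identity lies
        # just above the input (0.9), so gradcheck exercises the branch boundary.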
12134*da0073e9SAndroid Build Coastguard Worker        model = torch.nn.Softplus(threshold=1).double()
12135*da0073e9SAndroid Build Coastguard Worker        input = torch.tensor(0.9, device=device, dtype=torch.double,
12136*da0073e9SAndroid Build Coastguard Worker                             requires_grad=True)
12137*da0073e9SAndroid Build Coastguard Worker        output = model(input)
12138*da0073e9SAndroid Build Coastguard Worker        torch.autograd.gradcheck(model, input)
12139*da0073e9SAndroid Build Coastguard Worker
12140*da0073e9SAndroid Build Coastguard Worker    def test_softshrink_inplace_overlap(self, device):
12141*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((1, 6), device=device).expand((6, 6))
12142*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
12143*da0073e9SAndroid Build Coastguard Worker            F.softshrink(x, out=x)
12144*da0073e9SAndroid Build Coastguard Worker
12145*da0073e9SAndroid Build Coastguard Worker    def test_leaky_relu_inplace_overlap(self, device):
12146*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((1, 6), device=device).expand((6, 6))
12147*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
12148*da0073e9SAndroid Build Coastguard Worker            F.leaky_relu(x, inplace=True)
12149*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
12150*da0073e9SAndroid Build Coastguard Worker            F.leaky_relu_(x)
12151*da0073e9SAndroid Build Coastguard Worker
12152*da0073e9SAndroid Build Coastguard Worker    # Merge into OpInfo?
12153*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # NotImplementedError: aten::rrelu_with_noise_ https://github.com/pytorch/pytorch/issues/77764
12154*da0073e9SAndroid Build Coastguard Worker    def test_leaky_relu_inplace_with_neg_slope(self, device):
12155*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([-1., 1.], device=device, requires_grad=True)
12156*da0073e9SAndroid Build Coastguard Worker        b = torch.nn.functional.leaky_relu_(a.clone(), -2)
12157*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "call out-of-place version"):
12158*da0073e9SAndroid Build Coastguard Worker            b.backward(torch.ones(2, device=device))
12159*da0073e9SAndroid Build Coastguard Worker
12160*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([-1., 1.], device=device, requires_grad=True)
12161*da0073e9SAndroid Build Coastguard Worker        b = torch.nn.functional.rrelu_(a.clone(), -5.0, 1.0)
12162*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "call out-of-place version"):
12163*da0073e9SAndroid Build Coastguard Worker            b.backward(torch.ones(2, device=device))
12164*da0073e9SAndroid Build Coastguard Worker
12165*da0073e9SAndroid Build Coastguard Worker    # Merge into OpInfo?
12166*da0073e9SAndroid Build Coastguard Worker    def test_leaky_relu_inplace_with_zero_slope(self, device):
12167*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([-2., 0., 2.], device=device, requires_grad=True)
12168*da0073e9SAndroid Build Coastguard Worker        b = torch.nn.functional.leaky_relu_(a.clone(), 0.0)
12169*da0073e9SAndroid Build Coastguard Worker        b.backward(torch.ones(3, device=device))
12170*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor([0., 0., 1.], device=device)
12171*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.grad, expected)
12172*da0073e9SAndroid Build Coastguard Worker
12173*da0073e9SAndroid Build Coastguard Worker        dtype = torch.bfloat16 if device != 'mps:0' else torch.float16
12174*da0073e9SAndroid Build Coastguard Worker        a_bf16 = torch.tensor([-2., 0., 2.], device=device, dtype=dtype, requires_grad=True)
12175*da0073e9SAndroid Build Coastguard Worker        b_bf16 = torch.nn.functional.leaky_relu_(a_bf16.clone(), 0.0)
12176*da0073e9SAndroid Build Coastguard Worker        b_bf16.backward(torch.ones(3, device=device))
12177*da0073e9SAndroid Build Coastguard Worker        expected_bf16 = torch.tensor([0., 0., 1.], device=device, dtype=dtype)
12178*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a_bf16.grad, expected_bf16)
12179*da0073e9SAndroid Build Coastguard Worker
12180*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
12181*da0073e9SAndroid Build Coastguard Worker    def test_softshrink(self, device):
12182*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([[1.21, 0.56, 0.5001, 0.4999, 1.2357, -0.4999, -0.5001, -1.154,
12183*da0073e9SAndroid Build Coastguard Worker                           0.254, -0.24, -0.225, 0.104, 0.002, -0.001, 0.0574, 1.2344,
12184*da0073e9SAndroid Build Coastguard Worker                           0.1748, -0.1797, -0.8125, 0.2051, -1.1328, 1.2344, -0.1562, 2.3554,
12185*da0073e9SAndroid Build Coastguard Worker                           -0.1953, 0.0304, -0.3613, -1.3047, 1.0312, 0.1436, -0.6953, 0.5664,
12186*da0073e9SAndroid Build Coastguard Worker                           -0.5820, -0.3301, 0.8203, 0.6133, 0.5938],
12187*da0073e9SAndroid Build Coastguard Worker                          [-0.8203, -1.2344, -0.5234, 2.5312, -0.4551, -0.6875, -1.5547, -0.2217,
12188*da0073e9SAndroid Build Coastguard Worker                           -0.3027, 2.6406, 1.3047, 0.2344, -1.6719, 0.2773, -1.3516, 3.4575,
12189*da0073e9SAndroid Build Coastguard Worker                           0.4414, 0.2656, 2.1094, -1.5156, 1.2344, -0.4336, 0.6797, -3.5486,
12190*da0073e9SAndroid Build Coastguard Worker                           0.9766, -0.4062, 1.4844, 0.7500, -1.7578, 0.7461, 1.6094, 8.5458,
12191*da0073e9SAndroid Build Coastguard Worker                           0.3730, -0.3477, -1.0625, 0.3848, 0.0557]], device=device)
12192*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor([[0.71, 0.06, 0.0001, 0., 0.7357, 0., -0.0001, -0.654,
12193*da0073e9SAndroid Build Coastguard Worker                                  0., 0., 0., 0., 0., 0., 0., 0.7344,
12194*da0073e9SAndroid Build Coastguard Worker                                  0., 0., -0.3125, 0., -0.6328, 0.7344, 0., 1.8554,
12195*da0073e9SAndroid Build Coastguard Worker                                  0., 0., 0., -0.8047, 0.5312, 0., -0.1953, 0.0664,
12196*da0073e9SAndroid Build Coastguard Worker                                  -0.0820, 0.0, 0.3203, 0.1133, 0.0938],
12197*da0073e9SAndroid Build Coastguard Worker                                 [-0.3203, -0.7344, -0.0234, 2.0312, 0.0, -0.1875, -1.0547, 0.,
12198*da0073e9SAndroid Build Coastguard Worker                                  0.0, 2.1406, 0.8047, 0., -1.1719, 0., -0.8516, 2.9575,
12199*da0073e9SAndroid Build Coastguard Worker                                  0., 0., 1.6094, -1.0156, 0.7344, 0., 0.1797, -3.0486,
12200*da0073e9SAndroid Build Coastguard Worker                                  0.4766, 0., 0.9844, 0.2500, -1.2578, 0.2461, 1.1094, 8.0458,
12201*da0073e9SAndroid Build Coastguard Worker                                  0., 0., -0.5625, 0., 0.]])
12202*da0073e9SAndroid Build Coastguard Worker        softshrink = torch.nn.Softshrink()
12203*da0073e9SAndroid Build Coastguard Worker        out = softshrink(x)
12204*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, expected, atol=1e-2, rtol=0)
12205*da0073e9SAndroid Build Coastguard Worker
12206*da0073e9SAndroid Build Coastguard Worker    def test_threshold_inplace_overlap(self, device):
12207*da0073e9SAndroid Build Coastguard Worker        # Inplace threshold is okay, because it is idempotent
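        # (applying the same threshold twice gives the same result, so rereading an
        # element already written through an overlapping view cannot change the outcome)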
12208*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((1, 6), device=device).expand((6, 6))
12209*da0073e9SAndroid Build Coastguard Worker        F.threshold(x, 0.5, 0.5, inplace=True)
12210*da0073e9SAndroid Build Coastguard Worker        F.threshold_(x, 0.5, 0.5)
12211*da0073e9SAndroid Build Coastguard Worker
12212*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
12213*da0073e9SAndroid Build Coastguard Worker    def test_triplet_margin_with_distance_loss_default_parity(self, device):
12214*da0073e9SAndroid Build Coastguard Worker        # Test for `nn.TripletMarginWithDistanceLoss` and
12215*da0073e9SAndroid Build Coastguard Worker        # `F.triplet_margin_with_distance_loss`.  Checks
12216*da0073e9SAndroid Build Coastguard Worker        # for parity against the respective non-distance-agnostic
12217*da0073e9SAndroid Build Coastguard Worker        # implementations of triplet margin loss (`nn.TripletMarginLoss`
12218*da0073e9SAndroid Build Coastguard Worker        # and `F.triplet_margin_loss`) under *default args*.
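        # Parity is expected because the default distance function of
        # TripletMarginWithDistanceLoss is the Euclidean pairwise distance,
        # which matches TripletMarginLoss with its default p=2.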
12219*da0073e9SAndroid Build Coastguard Worker
12220*da0073e9SAndroid Build Coastguard Worker        for extra_args in \
12221*da0073e9SAndroid Build Coastguard Worker                itertools.product((0.5, 1, 1.5), (True, False), ('none', 'mean', 'sum')):
12222*da0073e9SAndroid Build Coastguard Worker            kwargs = {'margin': extra_args[0], 'swap': extra_args[1], 'reduction': extra_args[2]}
12223*da0073e9SAndroid Build Coastguard Worker
12224*da0073e9SAndroid Build Coastguard Worker            anchor = torch.randn(5, 10, device=device, requires_grad=True, dtype=torch.double)
12225*da0073e9SAndroid Build Coastguard Worker            positive = torch.randn(5, 10, device=device, requires_grad=True, dtype=torch.double)
12226*da0073e9SAndroid Build Coastguard Worker            negative = torch.randn(5, 10, device=device, requires_grad=True, dtype=torch.double)
12227*da0073e9SAndroid Build Coastguard Worker
12228*da0073e9SAndroid Build Coastguard Worker            # Test forward, functional
12229*da0073e9SAndroid Build Coastguard Worker            expected = F.triplet_margin_loss(anchor, positive, negative, **kwargs)
12230*da0073e9SAndroid Build Coastguard Worker            actual = F.triplet_margin_with_distance_loss(anchor, positive, negative, **kwargs)
12231*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expected, rtol=1e-6, atol=1e-6)
12232*da0073e9SAndroid Build Coastguard Worker
12233*da0073e9SAndroid Build Coastguard Worker            # Test forward, module
12234*da0073e9SAndroid Build Coastguard Worker            loss_ref = nn.TripletMarginLoss(**kwargs)
12235*da0073e9SAndroid Build Coastguard Worker            loss_op = nn.TripletMarginWithDistanceLoss(**kwargs)
12236*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(loss_op(anchor, positive, negative),
12237*da0073e9SAndroid Build Coastguard Worker                             loss_ref(anchor, positive, negative),
12238*da0073e9SAndroid Build Coastguard Worker                             rtol=1e-6, atol=1e-6)
12239*da0073e9SAndroid Build Coastguard Worker
12240*da0073e9SAndroid Build Coastguard Worker            # Test backward
12241*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_with_distance_loss(
12242*da0073e9SAndroid Build Coastguard Worker                a, p, n, **kwargs), (anchor, positive, negative)))
12243*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(gradcheck(lambda a, p, n: loss_op(a, p, n),
12244*da0073e9SAndroid Build Coastguard Worker                            (anchor, positive, negative)))
12245*da0073e9SAndroid Build Coastguard Worker
12246*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
12247*da0073e9SAndroid Build Coastguard Worker    def test_triplet_margin_with_distance_loss(self, device):
12248*da0073e9SAndroid Build Coastguard Worker        # Test for parity between `nn.TripletMarginWithDistanceLoss` and
12249*da0073e9SAndroid Build Coastguard Worker        # `F.triplet_margin_with_distance_loss`.
12250*da0073e9SAndroid Build Coastguard Worker
12251*da0073e9SAndroid Build Coastguard Worker        pairwise_distance = nn.PairwiseDistance()
12252*da0073e9SAndroid Build Coastguard Worker
12253*da0073e9SAndroid Build Coastguard Worker        def cosine_distance(x, y):
12254*da0073e9SAndroid Build Coastguard Worker            return 1.0 - F.cosine_similarity(x, y)
12255*da0073e9SAndroid Build Coastguard Worker
12256*da0073e9SAndroid Build Coastguard Worker        distance_functions = (pairwise_distance, cosine_distance,
12257*da0073e9SAndroid Build Coastguard Worker                              lambda x, y: 1.0 - F.cosine_similarity(x, y))
12258*da0073e9SAndroid Build Coastguard Worker
12259*da0073e9SAndroid Build Coastguard Worker        reductions = ('mean', 'none', 'sum')
12260*da0073e9SAndroid Build Coastguard Worker        margins = (1.0, 1.5, 0.5)
12261*da0073e9SAndroid Build Coastguard Worker        swaps = (True, False)
12262*da0073e9SAndroid Build Coastguard Worker
12263*da0073e9SAndroid Build Coastguard Worker        for distance_fn, reduction, margin, swap \
12264*da0073e9SAndroid Build Coastguard Worker                in itertools.product(distance_functions, reductions, margins, swaps):
12265*da0073e9SAndroid Build Coastguard Worker            anchor = torch.randn(5, 10, device=device, requires_grad=True, dtype=torch.double)
12266*da0073e9SAndroid Build Coastguard Worker            positive = torch.randn(5, 10, device=device, requires_grad=True, dtype=torch.double)
12267*da0073e9SAndroid Build Coastguard Worker            negative = torch.randn(5, 10, device=device, requires_grad=True, dtype=torch.double)
12268*da0073e9SAndroid Build Coastguard Worker
12269*da0073e9SAndroid Build Coastguard Worker            # Test backward
12270*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_with_distance_loss(
12271*da0073e9SAndroid Build Coastguard Worker                a, p, n, distance_function=distance_fn, reduction=reduction, margin=margin, swap=swap),
12272*da0073e9SAndroid Build Coastguard Worker                (anchor, positive, negative)))
12273*da0073e9SAndroid Build Coastguard Worker            loss_op = nn.TripletMarginWithDistanceLoss(distance_function=distance_fn,
12274*da0073e9SAndroid Build Coastguard Worker                                                       reduction=reduction, margin=margin, swap=swap)
12275*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(gradcheck(lambda a, p, n: loss_op(
12276*da0073e9SAndroid Build Coastguard Worker                a, p, n), (anchor, positive, negative)))
12277*da0073e9SAndroid Build Coastguard Worker            traced_loss_op = torch.jit.trace(loss_op, (anchor, positive, negative))
12278*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(gradcheck(lambda a, p, n: traced_loss_op(
12279*da0073e9SAndroid Build Coastguard Worker                a, p, n), (anchor, positive, negative)))
12280*da0073e9SAndroid Build Coastguard Worker
12281*da0073e9SAndroid Build Coastguard Worker            # Test forward parity
12282*da0073e9SAndroid Build Coastguard Worker            functional = F.triplet_margin_with_distance_loss(anchor, positive, negative,
12283*da0073e9SAndroid Build Coastguard Worker                                                             distance_function=distance_fn,
12284*da0073e9SAndroid Build Coastguard Worker                                                             reduction=reduction, margin=margin, swap=swap)
12285*da0073e9SAndroid Build Coastguard Worker            modular = loss_op(anchor, positive, negative)
12286*da0073e9SAndroid Build Coastguard Worker            traced = traced_loss_op(anchor, positive, negative)
12287*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(functional, modular, atol=1e-6, rtol=1e-6)
12288*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(traced, modular, atol=1e-6, rtol=1e-6)
12289*da0073e9SAndroid Build Coastguard Worker
12290*da0073e9SAndroid Build Coastguard Worker    @dtypesIfMPS(torch.cfloat, torch.float)
12291*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.cfloat, torch.cdouble, torch.float)
12292*da0073e9SAndroid Build Coastguard Worker    def test_to_complex(self, device, dtype):
12293*da0073e9SAndroid Build Coastguard Worker        m = nn.Linear(3, 5).to(device)
12294*da0073e9SAndroid Build Coastguard Worker        self.assertIs(m, m.to(device))
12295*da0073e9SAndroid Build Coastguard Worker        m.to(dtype)
12296*da0073e9SAndroid Build Coastguard Worker        self.assertIs(m.weight.dtype, dtype)
12297*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
12298*da0073e9SAndroid Build Coastguard Worker            # Trigger warning
12299*da0073e9SAndroid Build Coastguard Worker            m.to(torch.cfloat)
12300*da0073e9SAndroid Build Coastguard Worker            # Check warning occurs
12301*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 1)
12302*da0073e9SAndroid Build Coastguard Worker            self.assertTrue("Complex modules are a new feature" in str(w[-1].message))
12303*da0073e9SAndroid Build Coastguard Worker
12304*da0073e9SAndroid Build Coastguard Worker    @skipMeta
12305*da0073e9SAndroid Build Coastguard Worker    @dtypesIfMPS(torch.float32)
12306*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32, torch.float64)
12307*da0073e9SAndroid Build Coastguard Worker    def test_module_to_empty(self, device, dtype):
12308*da0073e9SAndroid Build Coastguard Worker        class MyModule(nn.Module):
12309*da0073e9SAndroid Build Coastguard Worker            def __init__(self, in_features, out_features, device=None, dtype=None):
12310*da0073e9SAndroid Build Coastguard Worker                super().__init__()
12311*da0073e9SAndroid Build Coastguard Worker                factory_kwargs = {"device": device, "dtype": dtype}
12312*da0073e9SAndroid Build Coastguard Worker                self.weight = nn.Parameter(torch.randn(in_features, out_features, **factory_kwargs))
12313*da0073e9SAndroid Build Coastguard Worker
12314*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
12315*da0073e9SAndroid Build Coastguard Worker                return x @ self.weight
12316*da0073e9SAndroid Build Coastguard Worker
12317*da0073e9SAndroid Build Coastguard Worker        # Test meta module instantiation.
12318*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(5, 10, device=device, dtype=dtype)
12319*da0073e9SAndroid Build Coastguard Worker        m = MyModule(10, 1, device='meta', dtype=dtype)
12320*da0073e9SAndroid Build Coastguard Worker        m(input)
12321*da0073e9SAndroid Build Coastguard Worker
12322*da0073e9SAndroid Build Coastguard Worker        # Test empty meta module error with torch.nn.Module.to().
12323*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
12324*da0073e9SAndroid Build Coastguard Worker            NotImplementedError,
12325*da0073e9SAndroid Build Coastguard Worker            re.escape(
12326*da0073e9SAndroid Build Coastguard Worker                "Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() "
12327*da0073e9SAndroid Build Coastguard Worker                "instead of torch.nn.Module.to() when moving module from meta to a different "
12328*da0073e9SAndroid Build Coastguard Worker                "device."
12329*da0073e9SAndroid Build Coastguard Worker            ),
12330*da0073e9SAndroid Build Coastguard Worker        ):
12331*da0073e9SAndroid Build Coastguard Worker            m.to(device)
12332*da0073e9SAndroid Build Coastguard Worker
12333*da0073e9SAndroid Build Coastguard Worker        # Test materializing meta module on a real device.
12334*da0073e9SAndroid Build Coastguard Worker        m.to_empty(device=device)
12335*da0073e9SAndroid Build Coastguard Worker        m(input)
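        # to_empty() only allocates storage on the target device; the values are
        # uninitialized, so re-initialize the weight before relying on the output.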
12336*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
12337*da0073e9SAndroid Build Coastguard Worker            torch.nn.init.kaiming_uniform_(m.weight)
12338*da0073e9SAndroid Build Coastguard Worker        m(input)
12339*da0073e9SAndroid Build Coastguard Worker
12340*da0073e9SAndroid Build Coastguard Worker        # Test creating meta module from materialized module.
12341*da0073e9SAndroid Build Coastguard Worker        m.to_empty(device='meta')
12342*da0073e9SAndroid Build Coastguard Worker        m(input)
12343*da0073e9SAndroid Build Coastguard Worker
12344*da0073e9SAndroid Build Coastguard Worker    def test_module_to_empty_non_recursive(self, device):
12345*da0073e9SAndroid Build Coastguard Worker        class Layer(nn.Module):
12346*da0073e9SAndroid Build Coastguard Worker            def __init__(self, in_features, out_features):
12347*da0073e9SAndroid Build Coastguard Worker                super().__init__()
12348*da0073e9SAndroid Build Coastguard Worker                self.weight = nn.Parameter(torch.randn(in_features, out_features))
12349*da0073e9SAndroid Build Coastguard Worker                self.register_buffer('buf', torch.randn(out_features))
12350*da0073e9SAndroid Build Coastguard Worker
12351*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
12352*da0073e9SAndroid Build Coastguard Worker                return x @ self.weight + self.buf
12353*da0073e9SAndroid Build Coastguard Worker
12354*da0073e9SAndroid Build Coastguard Worker        class MyModule(nn.Module):
12355*da0073e9SAndroid Build Coastguard Worker            def __init__(self, in_features, out_features):
12356*da0073e9SAndroid Build Coastguard Worker                super().__init__()
12357*da0073e9SAndroid Build Coastguard Worker                self.weight = nn.Parameter(torch.randn(in_features, out_features))
12358*da0073e9SAndroid Build Coastguard Worker                self.register_buffer('buf1', torch.randn(out_features))
12359*da0073e9SAndroid Build Coastguard Worker                self.layer = Layer(out_features, out_features)
12360*da0073e9SAndroid Build Coastguard Worker
12361*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
12362*da0073e9SAndroid Build Coastguard Worker                return self.layer(x @ self.weight + self.buf1)
12363*da0073e9SAndroid Build Coastguard Worker
12364*da0073e9SAndroid Build Coastguard Worker        with torch.device('meta'):
12365*da0073e9SAndroid Build Coastguard Worker            m = MyModule(3, 5)
12366*da0073e9SAndroid Build Coastguard Worker
12367*da0073e9SAndroid Build Coastguard Worker        m.to_empty(device=device, recurse=False)
12368*da0073e9SAndroid Build Coastguard Worker
12369*da0073e9SAndroid Build Coastguard Worker        # params/buffers of parent should have been materialized on device
12370*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(not m.weight.is_meta)
12371*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(not m.buf1.is_meta)
12372*da0073e9SAndroid Build Coastguard Worker
12373*da0073e9SAndroid Build Coastguard Worker        # parameters/buffers of children submodules should still be on meta
12374*da0073e9SAndroid Build Coastguard Worker        for p in (*m.layer.parameters(), *m.layer.buffers()):
12375*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(p.is_meta)
12376*da0073e9SAndroid Build Coastguard Worker
12377*da0073e9SAndroid Build Coastguard Worker    @skipMeta
12378*da0073e9SAndroid Build Coastguard Worker    def test_skip_init(self, device):
12379*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(1)
12380*da0073e9SAndroid Build Coastguard Worker        m_initialized = torch.nn.Linear(5, 1)
12381*da0073e9SAndroid Build Coastguard Worker        m_initialized.to(device)
12382*da0073e9SAndroid Build Coastguard Worker
12383*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(1)
12384*da0073e9SAndroid Build Coastguard Worker        m_uninitialized = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1, device=device)
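        # skip_init() constructs the module without running its parameter
        # initialization, so the values come from uninitialized storage and
        # should not match the seeded initialization above.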
12385*da0073e9SAndroid Build Coastguard Worker
12386*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m_initialized.weight.device, m_uninitialized.weight.device)
12387*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.allclose(m_initialized.weight, m_uninitialized.weight))
12388*da0073e9SAndroid Build Coastguard Worker
12389*da0073e9SAndroid Build Coastguard Worker    @skipIfRocm(msg='See https://github.com/pytorch/pytorch/issues/135150')
12390*da0073e9SAndroid Build Coastguard Worker    @skipIfMps  # TODO(hvaara): Investigate as possible bug. macOS 13 passes, while 14 and 15 fails.
12391*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
12392*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.double, torch.float, torch.half)
12393*da0073e9SAndroid Build Coastguard Worker    def test_transformerencoderlayer(self, device, dtype):
12394*da0073e9SAndroid Build Coastguard Worker        if TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION and dtype == torch.half:
12395*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Skip on ROCM due to Flash Attention tolerances")
12396*da0073e9SAndroid Build Coastguard Worker        # this is a deterministic test for TransformerEncoderLayer
12397*da0073e9SAndroid Build Coastguard Worker        d_model = 4
12398*da0073e9SAndroid Build Coastguard Worker        nhead = 2
12399*da0073e9SAndroid Build Coastguard Worker        dim_feedforward = 16
12400*da0073e9SAndroid Build Coastguard Worker        dropout = 0.0
12401*da0073e9SAndroid Build Coastguard Worker        bsz = 2
12402*da0073e9SAndroid Build Coastguard Worker
12403*da0073e9SAndroid Build Coastguard Worker        atol = 1e-5
12404*da0073e9SAndroid Build Coastguard Worker        rtol = 1e-7
12405*da0073e9SAndroid Build Coastguard Worker        if "cuda" in device:
12406*da0073e9SAndroid Build Coastguard Worker            atol = 1e-3
12407*da0073e9SAndroid Build Coastguard Worker            rtol = 1e-2
12408*da0073e9SAndroid Build Coastguard Worker
12409*da0073e9SAndroid Build Coastguard Worker        def _test(training, batch_first, atol, rtol):
12410*da0073e9SAndroid Build Coastguard Worker            def perm_fn(x):
12411*da0073e9SAndroid Build Coastguard Worker                return x.transpose(1, 0) if batch_first else x
12412*da0073e9SAndroid Build Coastguard Worker
12413*da0073e9SAndroid Build Coastguard Worker            model = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
12414*da0073e9SAndroid Build Coastguard Worker                                               batch_first=batch_first, device=device, dtype=dtype)
12415*da0073e9SAndroid Build Coastguard Worker
12416*da0073e9SAndroid Build Coastguard Worker            if not training:
12417*da0073e9SAndroid Build Coastguard Worker                assert dropout == 0
12418*da0073e9SAndroid Build Coastguard Worker                model = model.eval()
12419*da0073e9SAndroid Build Coastguard Worker
12420*da0073e9SAndroid Build Coastguard Worker            # set constant weights of the model
12421*da0073e9SAndroid Build Coastguard Worker            for idx, p in enumerate(model.parameters()):
12422*da0073e9SAndroid Build Coastguard Worker                x = p.data
12423*da0073e9SAndroid Build Coastguard Worker                sz = x.view(-1).size(0)
12424*da0073e9SAndroid Build Coastguard Worker                shape = x.shape
12425*da0073e9SAndroid Build Coastguard Worker                x = torch.cos(torch.arange(0, sz).float().view(shape))
12426*da0073e9SAndroid Build Coastguard Worker                p.data.copy_(x)
12427*da0073e9SAndroid Build Coastguard Worker
12428*da0073e9SAndroid Build Coastguard Worker            # deterministic input
12429*da0073e9SAndroid Build Coastguard Worker            encoder_input = torch.tensor([[[20., 30., 40., 50.]]], device=device, dtype=dtype)
12430*da0073e9SAndroid Build Coastguard Worker            result = model(encoder_input)
12431*da0073e9SAndroid Build Coastguard Worker            ref_output = torch.tensor([[[2.258703, 0.127985, -0.697881, 0.170862]]], device=device, dtype=dtype)
12432*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.shape, ref_output.shape)
12433*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
12434*da0073e9SAndroid Build Coastguard Worker            # 0 values are NOT masked. This shouldn't mask anything.
12435*da0073e9SAndroid Build Coastguard Worker            mask = torch.tensor([[0]], device=device) == 1
12436*da0073e9SAndroid Build Coastguard Worker            # TODO: enable fast path for calls with a mask!
12437*da0073e9SAndroid Build Coastguard Worker            result = model(encoder_input, src_key_padding_mask=mask)
12438*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.shape, ref_output.shape)
12439*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
12440*da0073e9SAndroid Build Coastguard Worker            mask = torch.tensor([[1]], device=device) == 1
12441*da0073e9SAndroid Build Coastguard Worker            result = model(encoder_input, src_key_padding_mask=mask)
12442*da0073e9SAndroid Build Coastguard Worker            fast_path_device = result.is_cuda or result.is_cpu
12443*da0073e9SAndroid Build Coastguard Worker            result = result.cpu().detach().numpy()
12444*da0073e9SAndroid Build Coastguard Worker            # Non Fast Paths
12445*da0073e9SAndroid Build Coastguard Worker            if training or not batch_first or TEST_WITH_CROSSREF or not fast_path_device:
12446*da0073e9SAndroid Build Coastguard Worker                # We changed the semantics on the non-fast path so that fully masked out rows return
12447*da0073e9SAndroid Build Coastguard Worker                # 0 from attention; thus NaNs should no longer be present and the output should be
12448*da0073e9SAndroid Build Coastguard Worker                # nonzero due to the skip connections.
12449*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(not np.isnan(result).any())
12450*da0073e9SAndroid Build Coastguard Worker            else:
12451*da0073e9SAndroid Build Coastguard Worker                # Fast Paths
12452*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(np.isnan(result).all())
12453*da0073e9SAndroid Build Coastguard Worker
12454*da0073e9SAndroid Build Coastguard Worker
12455*da0073e9SAndroid Build Coastguard Worker            # deterministic input
12456*da0073e9SAndroid Build Coastguard Worker            encoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]],
12457*da0073e9SAndroid Build Coastguard Worker                                                  [[5., 6., 7., 8.]]], device=device, dtype=dtype))
12458*da0073e9SAndroid Build Coastguard Worker            result = model(encoder_input)
12459*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[2.272644, 0.119035, -0.691669, 0.153486]],
12460*da0073e9SAndroid Build Coastguard Worker                                               [[2.272644, 0.119035, -0.691669, 0.153486]]], device=device, dtype=dtype))
12461*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.shape, ref_output.shape)
12462*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
12463*da0073e9SAndroid Build Coastguard Worker            # all 0 which is no masking
12464*da0073e9SAndroid Build Coastguard Worker            mask = torch.tensor([[0, 0]], device=device) == 1
12465*da0073e9SAndroid Build Coastguard Worker            result = model(encoder_input, src_key_padding_mask=mask)
12466*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.shape, ref_output.shape)
12467*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
12468*da0073e9SAndroid Build Coastguard Worker            mask = torch.tensor([[1, 0]], device=device) == 1
12469*da0073e9SAndroid Build Coastguard Worker            result = model(encoder_input, src_key_padding_mask=mask)
12470*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[2.301516, 0.092249, -0.679101, 0.103088]],
12471*da0073e9SAndroid Build Coastguard Worker                                               [[2.301516, 0.092249, -0.679101, 0.103088]]], device=device, dtype=dtype))
12472*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.shape, ref_output.shape)
12473*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
12474*da0073e9SAndroid Build Coastguard Worker
12475*da0073e9SAndroid Build Coastguard Worker            # deterministic input
12476*da0073e9SAndroid Build Coastguard Worker            encoder_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
12477*da0073e9SAndroid Build Coastguard Worker                                                   [0.5387, 0.1655, 0.3565, 0.0471]],
12478*da0073e9SAndroid Build Coastguard Worker                                                  [[0.8335, 0.2799, 0.5031, 0.2947],
12479*da0073e9SAndroid Build Coastguard Worker                                                   [0.1402, 0.0318, 0.7636, 0.1346]],
12480*da0073e9SAndroid Build Coastguard Worker                                                  [[0.6333, 0.9344, 0.1376, 0.9938],
12481*da0073e9SAndroid Build Coastguard Worker                                                   [0.8924, 0.2872, 0.6692, 0.2944]],
12482*da0073e9SAndroid Build Coastguard Worker                                                  [[0.9897, 0.6915, 0.3154, 0.1733],
12483*da0073e9SAndroid Build Coastguard Worker                                                   [0.8645, 0.3513, 0.3064, 0.0767]],
12484*da0073e9SAndroid Build Coastguard Worker                                                  [[0.8117, 0.2366, 0.4838, 0.7881],
12485*da0073e9SAndroid Build Coastguard Worker                                                   [0.3718, 0.4945, 0.9511, 0.0864]]], device=device, dtype=dtype))
12486*da0073e9SAndroid Build Coastguard Worker            result = model(encoder_input)
12487*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[2.428589, 0.020835, -0.602055, -0.085249],
12488*da0073e9SAndroid Build Coastguard Worker                                                [2.427987, 0.021213, -0.602496, -0.084103]],
12489*da0073e9SAndroid Build Coastguard Worker                                               [[2.424689, 0.019155, -0.604793, -0.085672],
12490*da0073e9SAndroid Build Coastguard Worker                                                [2.413863, 0.022211, -0.612486, -0.072490]],
12491*da0073e9SAndroid Build Coastguard Worker                                               [[2.433774, 0.021598, -0.598343, -0.087548],
12492*da0073e9SAndroid Build Coastguard Worker                                                [2.425104, 0.019748, -0.604515, -0.084839]],
12493*da0073e9SAndroid Build Coastguard Worker                                               [[2.436185, 0.022682, -0.596625, -0.087261],
12494*da0073e9SAndroid Build Coastguard Worker                                                [2.433556, 0.021891, -0.598509, -0.086832]],
12495*da0073e9SAndroid Build Coastguard Worker                                               [[2.416246, 0.017512, -0.610712, -0.082961],
12496*da0073e9SAndroid Build Coastguard Worker                                                [2.422901, 0.024187, -0.606178, -0.074929]]], device=device, dtype=dtype))
12497*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.shape, ref_output.shape)
12498*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
12499*da0073e9SAndroid Build Coastguard Worker
12500*da0073e9SAndroid Build Coastguard Worker            # all 0
12501*da0073e9SAndroid Build Coastguard Worker            mask = torch.zeros([2, 5], device=device) == 1
12502*da0073e9SAndroid Build Coastguard Worker            result = model(encoder_input, src_key_padding_mask=mask)
12503*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.shape, ref_output.shape)
12504*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
12505*da0073e9SAndroid Build Coastguard Worker            mask[0, 1] = 1
12506*da0073e9SAndroid Build Coastguard Worker            mask[1, 3] = 1
12507*da0073e9SAndroid Build Coastguard Worker            mask[1, 4] = 1
12508*da0073e9SAndroid Build Coastguard Worker            result = model(encoder_input, src_key_padding_mask=mask)
12509*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[2.429026, 0.020793, -0.601741, -0.085642],
12510*da0073e9SAndroid Build Coastguard Worker                                                [2.428811, 0.021445, -0.601912, -0.084252]],
12511*da0073e9SAndroid Build Coastguard Worker                                               [[2.425009, 0.019155, -0.604566, -0.085899],
12512*da0073e9SAndroid Build Coastguard Worker                                                [2.415408, 0.02249 , -0.611415, -0.073]],
12513*da0073e9SAndroid Build Coastguard Worker                                               [[2.434199, 0.021682, -0.598039, -0.087699],
12514*da0073e9SAndroid Build Coastguard Worker                                                [2.42598, 0.019941, -0.603896, -0.085091]],
12515*da0073e9SAndroid Build Coastguard Worker                                               [[2.436457, 0.022736, -0.59643 , -0.08736],
12516*da0073e9SAndroid Build Coastguard Worker                                                [2.434021, 0.022093, -0.598179, -0.08679]],
12517*da0073e9SAndroid Build Coastguard Worker                                               [[2.416531, 0.017498, -0.610513, -0.083181],
12518*da0073e9SAndroid Build Coastguard Worker                                                [2.4242, 0.024653, -0.605266, -0.074959]]], device=device, dtype=dtype))
12519*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.shape, ref_output.shape)
12520*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
12521*da0073e9SAndroid Build Coastguard Worker
12522*da0073e9SAndroid Build Coastguard Worker            # NestedTensor is currently only supported on the fast path,
12523*da0073e9SAndroid Build Coastguard Worker            # which is not used during training.
12524*da0073e9SAndroid Build Coastguard Worker            if (batch_first and not training and
12525*da0073e9SAndroid Build Coastguard Worker                    ('cuda' in str(device) or 'cpu' in str(device)) and not TEST_WITH_CROSSREF):
12526*da0073e9SAndroid Build Coastguard Worker                encoder_input[0][-1] = torch.zeros_like(encoder_input[0][1])
12527*da0073e9SAndroid Build Coastguard Worker                mask = torch.zeros(encoder_input.shape[:-1], device=device, dtype=torch.bool)
12528*da0073e9SAndroid Build Coastguard Worker                mask[0][-1] = True
12529*da0073e9SAndroid Build Coastguard Worker
12530*da0073e9SAndroid Build Coastguard Worker                nt = torch.nested.nested_tensor([encoder_input[0][:-1], encoder_input[1]], device=device)
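                # Added note: the nested tensor packs two sequences of unequal
                # length (4 and 5 time steps), so the position that the boolean
                # mask above marks as padding is simply absent from the input;
                # the fast path consumes this directly instead of a padding mask.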
12531*da0073e9SAndroid Build Coastguard Worker                result = model(nt)
12532*da0073e9SAndroid Build Coastguard Worker                ref_output = torch.tensor(
12533*da0073e9SAndroid Build Coastguard Worker                    [
12534*da0073e9SAndroid Build Coastguard Worker                        [
12535*da0073e9SAndroid Build Coastguard Worker                            [2.4268184, 0.02042419, -0.603311, -0.08476824],
12536*da0073e9SAndroid Build Coastguard Worker                            [2.423306, 0.01889652, -0.6057701, -0.08519465],
12537*da0073e9SAndroid Build Coastguard Worker                            [2.431538, 0.02078694, -0.5999354, -0.08746159],
12538*da0073e9SAndroid Build Coastguard Worker                            [2.4348664, 0.02212971, -0.5975677, -0.08733892],
12539*da0073e9SAndroid Build Coastguard Worker                            [2.423133, 0.02097577, -0.60594773, -0.08113337],
12540*da0073e9SAndroid Build Coastguard Worker                        ],
12541*da0073e9SAndroid Build Coastguard Worker                        [
12542*da0073e9SAndroid Build Coastguard Worker                            [2.4279876, 0.02121329, -0.60249615, -0.08410317],
12543*da0073e9SAndroid Build Coastguard Worker                            [2.4138637, 0.02221113, -0.6124869, -0.07249016],
12544*da0073e9SAndroid Build Coastguard Worker                            [2.4251041, 0.01974815, -0.6045152, -0.08483928],
12545*da0073e9SAndroid Build Coastguard Worker                            [2.4335563, 0.0218913, -0.59850943, -0.08683228],
12546*da0073e9SAndroid Build Coastguard Worker                            [2.4229012, 0.02418739, -0.6061784, -0.07492948],
12547*da0073e9SAndroid Build Coastguard Worker                        ],
12548*da0073e9SAndroid Build Coastguard Worker                    ],
12549*da0073e9SAndroid Build Coastguard Worker                    device=device, dtype=dtype
12550*da0073e9SAndroid Build Coastguard Worker                )
12551*da0073e9SAndroid Build Coastguard Worker                result = result.to_padded_tensor(0)
12552*da0073e9SAndroid Build Coastguard Worker                ref_output[0][-1] = torch.zeros_like(
12553*da0073e9SAndroid Build Coastguard Worker                    ref_output[0][-1], device=device, dtype=dtype
12554*da0073e9SAndroid Build Coastguard Worker                )
12555*da0073e9SAndroid Build Coastguard Worker                result[0][-1] = torch.zeros_like(
12556*da0073e9SAndroid Build Coastguard Worker                    result[0][-1], device=device, dtype=dtype
12557*da0073e9SAndroid Build Coastguard Worker                )
12558*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
12559*da0073e9SAndroid Build Coastguard Worker                if 'cuda' in device:
12560*da0073e9SAndroid Build Coastguard Worker                    if dtype == torch.float:
12561*da0073e9SAndroid Build Coastguard Worker                        atol = 2e-4
12562*da0073e9SAndroid Build Coastguard Worker                        rtol = 4e-3
12563*da0073e9SAndroid Build Coastguard Worker                    else:
12564*da0073e9SAndroid Build Coastguard Worker                        atol = 7e-4
12565*da0073e9SAndroid Build Coastguard Worker                        rtol = 2e-2
12566*da0073e9SAndroid Build Coastguard Worker                    torch.testing.assert_close(result, ref_output, atol=atol, rtol=rtol)
12567*da0073e9SAndroid Build Coastguard Worker                else:
12568*da0073e9SAndroid Build Coastguard Worker                    torch.testing.assert_close(result, ref_output)
12569*da0073e9SAndroid Build Coastguard Worker
12570*da0073e9SAndroid Build Coastguard Worker
12571*da0073e9SAndroid Build Coastguard Worker        for batch_first in (True, False):
12572*da0073e9SAndroid Build Coastguard Worker            for training in (True, False):
12573*da0073e9SAndroid Build Coastguard Worker                if training:
12574*da0073e9SAndroid Build Coastguard Worker                    cm = contextlib.nullcontext()
12575*da0073e9SAndroid Build Coastguard Worker                else:
12576*da0073e9SAndroid Build Coastguard Worker                    # Fast path requires inference mode.
12577*da0073e9SAndroid Build Coastguard Worker                    cm = torch.no_grad()
12578*da0073e9SAndroid Build Coastguard Worker                with cm:
12579*da0073e9SAndroid Build Coastguard Worker                    _test(batch_first=batch_first, training=training, atol=atol, rtol=rtol)
12580*da0073e9SAndroid Build Coastguard Worker
12581*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
12582*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
12583*da0073e9SAndroid Build Coastguard Worker    def test_transformerencoderlayer_fast_path(self, device, dtype):
12584*da0073e9SAndroid Build Coastguard Worker        """
12585*da0073e9SAndroid Build Coastguard Worker        Test transformer fast path on CPU with different valid mask types and shapes
12586*da0073e9SAndroid Build Coastguard Worker        """
12587*da0073e9SAndroid Build Coastguard Worker        d_model = 512
12588*da0073e9SAndroid Build Coastguard Worker        nhead = 8
12589*da0073e9SAndroid Build Coastguard Worker        batch_size = 32
12590*da0073e9SAndroid Build Coastguard Worker        src_len = 10
12591*da0073e9SAndroid Build Coastguard Worker
12592*da0073e9SAndroid Build Coastguard Worker        model = torch.nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True,
12593*da0073e9SAndroid Build Coastguard Worker                                                 device=device, dtype=dtype, dropout=0)
12594*da0073e9SAndroid Build Coastguard Worker        model.eval()
12595*da0073e9SAndroid Build Coastguard Worker
12596*da0073e9SAndroid Build Coastguard Worker        # Batched inputs
12597*da0073e9SAndroid Build Coastguard Worker        src = torch.rand(batch_size, src_len, d_model, dtype=dtype)
12598*da0073e9SAndroid Build Coastguard Worker
12599*da0073e9SAndroid Build Coastguard Worker        # Attention mask of shape (src_len, src_len)
12600*da0073e9SAndroid Build Coastguard Worker        src_mask = torch.zeros(src_len, src_len).to(torch.bool)
12601*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
12602*da0073e9SAndroid Build Coastguard Worker            model(src, src_mask=src_mask)
12603*da0073e9SAndroid Build Coastguard Worker
12604*da0073e9SAndroid Build Coastguard Worker        # Padding mask of shape (batch_size, src_len)
12605*da0073e9SAndroid Build Coastguard Worker        src_key_padding_mask = torch.zeros(batch_size, src_len).to(torch.bool)
12606*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
12607*da0073e9SAndroid Build Coastguard Worker            model(src, src_key_padding_mask=src_key_padding_mask)
12608*da0073e9SAndroid Build Coastguard Worker
12609*da0073e9SAndroid Build Coastguard Worker        # Provide both masks
12610*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
12611*da0073e9SAndroid Build Coastguard Worker            model(src, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask)
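        # Added summary of the mask conventions exercised above: src_mask is a
        # (src_len, src_len) attention mask applied to every batch element,
        # while src_key_padding_mask is (batch_size, src_len) and marks padded
        # positions per sequence; for boolean masks, True means "do not attend".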
12612*da0073e9SAndroid Build Coastguard Worker
12613*da0073e9SAndroid Build Coastguard Worker
12614*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
12615*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.half, torch.float)
12616*da0073e9SAndroid Build Coastguard Worker    def test_transformerencoderlayer_gelu(self, device, dtype):
12617*da0073e9SAndroid Build Coastguard Worker        if TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION and dtype == torch.half:
12618*da0073e9SAndroid Build Coastguard Worker            self.skipTest("Skip on ROCM due to Flash Attention tolerances")
12619*da0073e9SAndroid Build Coastguard Worker        # this is a deterministic test for TransformerEncoderLayer with gelu activation
12620*da0073e9SAndroid Build Coastguard Worker        d_model = 4
12621*da0073e9SAndroid Build Coastguard Worker        nhead = 2
12622*da0073e9SAndroid Build Coastguard Worker        dim_feedforward = 16
12623*da0073e9SAndroid Build Coastguard Worker        dropout = 0.0
12624*da0073e9SAndroid Build Coastguard Worker        bsz = 2
12625*da0073e9SAndroid Build Coastguard Worker
12626*da0073e9SAndroid Build Coastguard Worker        atol = 0
12627*da0073e9SAndroid Build Coastguard Worker        rtol = 1e-5
12628*da0073e9SAndroid Build Coastguard Worker        if "cuda" in device:
12629*da0073e9SAndroid Build Coastguard Worker            atol = 1e-3
12630*da0073e9SAndroid Build Coastguard Worker            rtol = 1e-2
12631*da0073e9SAndroid Build Coastguard Worker
12632*da0073e9SAndroid Build Coastguard Worker        def _test(activation, batch_first, training):
12633*da0073e9SAndroid Build Coastguard Worker            def perm_fn(x):
12634*da0073e9SAndroid Build Coastguard Worker                return x.transpose(1, 0) if batch_first else x
12635*da0073e9SAndroid Build Coastguard Worker
12636*da0073e9SAndroid Build Coastguard Worker            model = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
12637*da0073e9SAndroid Build Coastguard Worker                                               activation, batch_first=batch_first, device=device, dtype=dtype)
12638*da0073e9SAndroid Build Coastguard Worker            if not training:
12639*da0073e9SAndroid Build Coastguard Worker                assert dropout == 0
12640*da0073e9SAndroid Build Coastguard Worker                model = model.eval()
12641*da0073e9SAndroid Build Coastguard Worker
12642*da0073e9SAndroid Build Coastguard Worker            # set constant weights of the model
12643*da0073e9SAndroid Build Coastguard Worker            for idx, p in enumerate(model.parameters()):
12644*da0073e9SAndroid Build Coastguard Worker                x = p.data
12645*da0073e9SAndroid Build Coastguard Worker                sz = x.view(-1).size(0)
12646*da0073e9SAndroid Build Coastguard Worker                shape = x.shape
12647*da0073e9SAndroid Build Coastguard Worker                x = torch.cos(torch.arange(0, sz).float().view(shape))
12648*da0073e9SAndroid Build Coastguard Worker                p.data.copy_(x)
12649*da0073e9SAndroid Build Coastguard Worker
12650*da0073e9SAndroid Build Coastguard Worker            # deterministic input
12651*da0073e9SAndroid Build Coastguard Worker            encoder_input = torch.tensor([[[20., 30., 40., 50.]]], device=device, dtype=dtype)
12652*da0073e9SAndroid Build Coastguard Worker            result = model(encoder_input)
12653*da0073e9SAndroid Build Coastguard Worker            ref_output = torch.tensor([[[2.249815, 0.131006, -0.702199, 0.177868]]], device=device, dtype=dtype)
12654*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, rtol=rtol, atol=atol)
12655*da0073e9SAndroid Build Coastguard Worker
12656*da0073e9SAndroid Build Coastguard Worker            # deterministic input
12657*da0073e9SAndroid Build Coastguard Worker            encoder_input = perm_fn(torch.tensor([[[1., 2., 3., 4.]],
12658*da0073e9SAndroid Build Coastguard Worker                                                  [[5., 6., 7., 8.]]], device=device, dtype=dtype))
12659*da0073e9SAndroid Build Coastguard Worker            result = model(encoder_input)
12660*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[2.264103, 0.121417, -0.696012, 0.159724]],
12661*da0073e9SAndroid Build Coastguard Worker                                               [[2.264103, 0.121417, -0.696012, 0.159724]]], device=device, dtype=dtype))
12662*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, rtol=rtol, atol=atol)
12663*da0073e9SAndroid Build Coastguard Worker
12664*da0073e9SAndroid Build Coastguard Worker            # deterministic input
12665*da0073e9SAndroid Build Coastguard Worker            encoder_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
12666*da0073e9SAndroid Build Coastguard Worker                                                  [0.5387, 0.1655, 0.3565, 0.0471]],
12667*da0073e9SAndroid Build Coastguard Worker                                                  [[0.8335, 0.2799, 0.5031, 0.2947],
12668*da0073e9SAndroid Build Coastguard Worker                                                  [0.1402, 0.0318, 0.7636, 0.1346]],
12669*da0073e9SAndroid Build Coastguard Worker                                                  [[0.6333, 0.9344, 0.1376, 0.9938],
12670*da0073e9SAndroid Build Coastguard Worker                                                  [0.8924, 0.2872, 0.6692, 0.2944]],
12671*da0073e9SAndroid Build Coastguard Worker                                                  [[0.9897, 0.6915, 0.3154, 0.1733],
12672*da0073e9SAndroid Build Coastguard Worker                                                  [0.8645, 0.3513, 0.3064, 0.0767]],
12673*da0073e9SAndroid Build Coastguard Worker                                                  [[0.8117, 0.2366, 0.4838, 0.7881],
12674*da0073e9SAndroid Build Coastguard Worker                                                  [0.3718, 0.4945, 0.9511, 0.0864]]], device=device, dtype=dtype))
12675*da0073e9SAndroid Build Coastguard Worker            result = model(encoder_input)
12676*da0073e9SAndroid Build Coastguard Worker            ref_output = perm_fn(torch.tensor([[[2.42163188, 0.03227153, -0.60714219, -0.05908082],
12677*da0073e9SAndroid Build Coastguard Worker                                                [2.42151276, 0.03302179, -0.60722523, -0.05762651]],
12678*da0073e9SAndroid Build Coastguard Worker                                               [[2.41926761, 0.02974034, -0.60879519, -0.0621269],
12679*da0073e9SAndroid Build Coastguard Worker                                                [2.41626395, 0.03539356, -0.61087842, -0.04978623]],
12680*da0073e9SAndroid Build Coastguard Worker                                               [[2.42382808, 0.03218872, -0.6055963, -0.06073591],
12681*da0073e9SAndroid Build Coastguard Worker                                                [2.41983477, 0.03085259, -0.60840145, -0.06046414]],
12682*da0073e9SAndroid Build Coastguard Worker                                               [[2.42500749, 0.03328855, -0.60476388, -0.0595334],
12683*da0073e9SAndroid Build Coastguard Worker                                                [2.4237977, 0.03290575, -0.60561789, -0.05940082]],
12684*da0073e9SAndroid Build Coastguard Worker                                               [[2.41383916, 0.02686345, -0.61256377, -0.06380707],
12685*da0073e9SAndroid Build Coastguard Worker                                                [2.42000277, 0.03800944, -0.60824798, -0.04754947]]], device=device, dtype=dtype))
12686*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(result, ref_output, rtol=rtol, atol=atol)
12687*da0073e9SAndroid Build Coastguard Worker        for activation, batch_first, training in product(('gelu', F.gelu, nn.GELU()), (True, False), (True, False)):
12688*da0073e9SAndroid Build Coastguard Worker            # Fast path requires inference mode.
12689*da0073e9SAndroid Build Coastguard Worker            if training:
12690*da0073e9SAndroid Build Coastguard Worker                cm = contextlib.nullcontext()
12691*da0073e9SAndroid Build Coastguard Worker            else:
12692*da0073e9SAndroid Build Coastguard Worker                cm = torch.no_grad()
12693*da0073e9SAndroid Build Coastguard Worker            with cm:
12694*da0073e9SAndroid Build Coastguard Worker                _test(activation=activation, batch_first=batch_first, training=training)
12695*da0073e9SAndroid Build Coastguard Worker
12696*da0073e9SAndroid Build Coastguard Worker    @skipIfMps  # RuntimeError: foreach=True was passed, but can't use the foreach API on mps tensors
12697*da0073e9SAndroid Build Coastguard Worker    @parametrize_test('foreach', (False, True))
12698*da0073e9SAndroid Build Coastguard Worker    def test_clip_grad_value(self, foreach, device):
12699*da0073e9SAndroid Build Coastguard Worker        if torch.device(device).type == 'xla' and foreach:
12700*da0073e9SAndroid Build Coastguard Worker            raise SkipTest('foreach not supported on XLA')
12701*da0073e9SAndroid Build Coastguard Worker
12702*da0073e9SAndroid Build Coastguard Worker        l = nn.Linear(10, 10).to(device)
12703*da0073e9SAndroid Build Coastguard Worker        clip_value = 2.5
12704*da0073e9SAndroid Build Coastguard Worker
12705*da0073e9SAndroid Build Coastguard Worker        grad_w, grad_b = torch.arange(-50., 50, device=device).view(10, 10).div_(5), torch.ones(10, device=device).mul_(2)
12706*da0073e9SAndroid Build Coastguard Worker        for grad_list in [[grad_w, grad_b], [grad_w, None]]:
12707*da0073e9SAndroid Build Coastguard Worker            for p, g in zip(l.parameters(), grad_list):
12708*da0073e9SAndroid Build Coastguard Worker                p._grad = g.clone().view_as(p.data) if g is not None else g
12709*da0073e9SAndroid Build Coastguard Worker
12710*da0073e9SAndroid Build Coastguard Worker            clip_grad_value_(l.parameters(), clip_value, foreach=foreach)
12711*da0073e9SAndroid Build Coastguard Worker            for p in filter(lambda p: p.grad is not None, l.parameters()):
12712*da0073e9SAndroid Build Coastguard Worker                self.assertLessEqual(p.grad.data.max(), clip_value)
12713*da0073e9SAndroid Build Coastguard Worker                self.assertGreaterEqual(p.grad.data.min(), -clip_value)
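            # Added note: clip_grad_value_ clamps every gradient element in
            # place into [-clip_value, clip_value], i.e. it behaves like
            # p.grad.clamp_(-clip_value, clip_value) for each parameter, which
            # is exactly what the two bounds checks above verify.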
12714*da0073e9SAndroid Build Coastguard Worker
12715*da0073e9SAndroid Build Coastguard Worker        # Should accept a single Tensor as input
12716*da0073e9SAndroid Build Coastguard Worker        p1, p2 = torch.randn(10, 10, device=device), torch.randn(10, 10, device=device)
12717*da0073e9SAndroid Build Coastguard Worker        g = torch.arange(-50., 50, device=device).view(10, 10).div_(5)
12718*da0073e9SAndroid Build Coastguard Worker        p1._grad = g.clone()
12719*da0073e9SAndroid Build Coastguard Worker        p2._grad = g.clone()
12720*da0073e9SAndroid Build Coastguard Worker        clip_grad_value_(p1, clip_value, foreach=foreach)
12721*da0073e9SAndroid Build Coastguard Worker        clip_grad_value_([p2], clip_value, foreach=foreach)
12722*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(p1.grad, p2.grad)
12723*da0073e9SAndroid Build Coastguard Worker
12724*da0073e9SAndroid Build Coastguard Worker    @skipIfMps  # TypeError: the MPS framework doesn't support float64
12725*da0073e9SAndroid Build Coastguard Worker    @parametrize_test('foreach', (False, True))
12726*da0073e9SAndroid Build Coastguard Worker    @parametrize_test('norm_type', (0.5, 1.5, 2, 4, 'inf'))
12727*da0073e9SAndroid Build Coastguard Worker    def test_clip_grad_norm(self, norm_type, foreach, device):
12728*da0073e9SAndroid Build Coastguard Worker        if torch.device(device).type == 'xla' and foreach:
12729*da0073e9SAndroid Build Coastguard Worker            raise SkipTest('foreach not supported on XLA')
12730*da0073e9SAndroid Build Coastguard Worker
12731*da0073e9SAndroid Build Coastguard Worker        l = nn.Linear(10, 10).to(device)
12732*da0073e9SAndroid Build Coastguard Worker        max_norm = 2
12733*da0073e9SAndroid Build Coastguard Worker
12734*da0073e9SAndroid Build Coastguard Worker        def compute_norm(norm_type):
12735*da0073e9SAndroid Build Coastguard Worker            norm_type = float(norm_type)
12736*da0073e9SAndroid Build Coastguard Worker            if norm_type != inf:
12737*da0073e9SAndroid Build Coastguard Worker                total_norm = 0
12738*da0073e9SAndroid Build Coastguard Worker                for p in l.parameters():
12739*da0073e9SAndroid Build Coastguard Worker                    total_norm += p.grad.data.abs().pow(norm_type).sum()
12740*da0073e9SAndroid Build Coastguard Worker                return pow(total_norm, 1. / norm_type)
12741*da0073e9SAndroid Build Coastguard Worker            else:
12742*da0073e9SAndroid Build Coastguard Worker                return max(p.grad.data.abs().max() for p in l.parameters())
12743*da0073e9SAndroid Build Coastguard Worker
12744*da0073e9SAndroid Build Coastguard Worker        def compare_scaling(grads):
12745*da0073e9SAndroid Build Coastguard Worker            p_scale = [p.grad.data.div(g).view(-1) for p, g in zip(l.parameters(), grads)]
12746*da0073e9SAndroid Build Coastguard Worker            scale = torch.cat(p_scale)
12747*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(scale.std(), 0)
12748*da0073e9SAndroid Build Coastguard Worker            return scale[0]
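        # Added note: clip_grad_norm_ rescales all gradients by the single
        # factor min(1, max_norm / total_norm), so after clipping every
        # p.grad / original_grad ratio must be identical across parameters;
        # compare_scaling() checks that this ratio has zero variance.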
12749*da0073e9SAndroid Build Coastguard Worker
12750*da0073e9SAndroid Build Coastguard Worker        grads = torch.arange(1., 101, device=device).view(10, 10), torch.ones(10, device=device).div(1000)
12751*da0073e9SAndroid Build Coastguard Worker        for p, g in zip(l.parameters(), grads):
12752*da0073e9SAndroid Build Coastguard Worker            p._grad = g.clone().view_as(p.data)
12753*da0073e9SAndroid Build Coastguard Worker        norm_before = compute_norm(norm_type)
12754*da0073e9SAndroid Build Coastguard Worker        norm = clip_grad_norm_(l.parameters(), max_norm, norm_type=norm_type, foreach=foreach)
12755*da0073e9SAndroid Build Coastguard Worker        norm_after = compute_norm(norm_type)
12756*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(norm, norm_before)
12757*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(norm_after, max_norm)
12758*da0073e9SAndroid Build Coastguard Worker        self.assertLessEqual(norm_after, norm_before)
12759*da0073e9SAndroid Build Coastguard Worker        compare_scaling(grads)
12760*da0073e9SAndroid Build Coastguard Worker
12761*da0073e9SAndroid Build Coastguard Worker        # Small gradients should be left unchanged
12762*da0073e9SAndroid Build Coastguard Worker        grads = torch.rand(10, 10, device=device).div(10000), torch.ones(10, device=device).div(500)
12763*da0073e9SAndroid Build Coastguard Worker        for p, g in zip(l.parameters(), grads):
12764*da0073e9SAndroid Build Coastguard Worker            p.grad.data.copy_(g)
12765*da0073e9SAndroid Build Coastguard Worker        norm_before = compute_norm(norm_type)
12766*da0073e9SAndroid Build Coastguard Worker        norm = clip_grad_norm_(l.parameters(), max_norm, norm_type=norm_type, foreach=foreach)
12767*da0073e9SAndroid Build Coastguard Worker        norm_after = compute_norm(norm_type)
12768*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(norm, norm_before)
12769*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(norm_before, norm_after)
12770*da0073e9SAndroid Build Coastguard Worker        self.assertLessEqual(norm_after, max_norm)
12771*da0073e9SAndroid Build Coastguard Worker        scale = compare_scaling(grads)
12772*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scale, 1)
12773*da0073e9SAndroid Build Coastguard Worker
12774*da0073e9SAndroid Build Coastguard Worker        # Should accept a single Tensor as input
12775*da0073e9SAndroid Build Coastguard Worker        p1, p2 = torch.randn(10, 10, device=device), torch.randn(10, 10, device=device)
12776*da0073e9SAndroid Build Coastguard Worker        g = torch.arange(1., 101, device=device).view(10, 10)
12777*da0073e9SAndroid Build Coastguard Worker        p1._grad = g.clone()
12778*da0073e9SAndroid Build Coastguard Worker        p2._grad = g.clone()
12779*da0073e9SAndroid Build Coastguard Worker        clip_grad_norm_(p1, max_norm, norm_type=norm_type, foreach=foreach)
12780*da0073e9SAndroid Build Coastguard Worker        clip_grad_norm_([p2], max_norm, norm_type=norm_type, foreach=foreach)
12781*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(p1.grad, p2.grad)
12782*da0073e9SAndroid Build Coastguard Worker
12783*da0073e9SAndroid Build Coastguard Worker    # reference issue: https://github.com/pytorch/pytorch/issues/111484
12784*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
12785*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest("42GB", "cuda")
12786*da0073e9SAndroid Build Coastguard Worker    def test_softmax_forward_64bit_indexing(self, device):
12787*da0073e9SAndroid Build Coastguard Worker        batch_size = 70
12788*da0073e9SAndroid Build Coastguard Worker        seq_len = 2048
12789*da0073e9SAndroid Build Coastguard Worker        vocab_size = 50000
12790*da0073e9SAndroid Build Coastguard Worker
12791*da0073e9SAndroid Build Coastguard Worker        shift_labels = torch.zeros(batch_size, seq_len - 1, dtype=torch.long, device=device)
12792*da0073e9SAndroid Build Coastguard Worker        logits = torch.ones(batch_size, seq_len - 1, vocab_size, dtype=torch.float16, device=device)
12793*da0073e9SAndroid Build Coastguard Worker        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
12794*da0073e9SAndroid Build Coastguard Worker        nll = loss_fct(logits.permute(0, 2, 1), shift_labels).float()
12795*da0073e9SAndroid Build Coastguard Worker        rtol, atol = torch.testing._comparison.get_tolerances(torch.float16, rtol=None, atol=None)
12796*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nll, torch.ones_like(nll) * torch.log(torch.tensor(vocab_size)), rtol=rtol, atol=atol)
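        # Added reasoning for the reference value: with identical logits the
        # softmax is uniform over vocab_size classes, so the per-token
        # cross-entropy is -log(1 / vocab_size) = log(vocab_size), which is the
        # constant tensor compared against above.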
12797*da0073e9SAndroid Build Coastguard Worker
12798*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
12799*da0073e9SAndroid Build Coastguard Worker    @largeTensorTest("20GB", "cuda")
12800*da0073e9SAndroid Build Coastguard Worker    def test_softmax_backward_64bit_indexing(self, device):
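        # Added note: both element counts below exceed 2**31 - 1, so the softmax
        # backward kernel cannot address the tensor with 32-bit indices and must
        # take the 64-bit indexing path that this test exercises.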
12801*da0073e9SAndroid Build Coastguard Worker        for numel in (2147483650, 2147483650 + 1):
12802*da0073e9SAndroid Build Coastguard Worker            x = torch.empty([1, 1, numel], device=device, dtype=torch.float16)
12803*da0073e9SAndroid Build Coastguard Worker            x.fill_(1.0 / numel)
12804*da0073e9SAndroid Build Coastguard Worker            out = torch._softmax_backward_data(x, x, 2, x.dtype)
12805*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out[0, 0, 0], 1 / numel)
12806*da0073e9SAndroid Build Coastguard Worker
12807*da0073e9SAndroid Build Coastguard Worker    # reference issue: https://github.com/pytorch/pytorch/issues/68248
12808*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
12809*da0073e9SAndroid Build Coastguard Worker    def test_adaptiveavg_pool1d_shmem(self, device):
12810*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 256, 1, 5000, device=device).to(memory_format=torch.channels_last)
12811*da0073e9SAndroid Build Coastguard Worker        x_cpu = x.cpu()
12812*da0073e9SAndroid Build Coastguard Worker        x_cpu.requires_grad_()
12813*da0073e9SAndroid Build Coastguard Worker        x.requires_grad_()
12814*da0073e9SAndroid Build Coastguard Worker        y = torch.nn.functional.adaptive_avg_pool2d(x, (1, 256))
12815*da0073e9SAndroid Build Coastguard Worker        y_cpu = torch.nn.functional.adaptive_avg_pool2d(x_cpu, (1, 256))
12816*da0073e9SAndroid Build Coastguard Worker        grad = torch.randn_like(y)
12817*da0073e9SAndroid Build Coastguard Worker        grad_cpu = grad.cpu()
12818*da0073e9SAndroid Build Coastguard Worker        y.backward(grad)
12819*da0073e9SAndroid Build Coastguard Worker        y_cpu.backward(grad_cpu)
12820*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, x_cpu.grad)
12821*da0073e9SAndroid Build Coastguard Worker
12822*da0073e9SAndroid Build Coastguard Worker    @skipMeta
12823*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMPS  # NotImplementedError: aten::channel_shuffle https://github.com/pytorch/pytorch/issues/77764
12824*da0073e9SAndroid Build Coastguard Worker    def test_channel_shuffle(self, device):
12825*da0073e9SAndroid Build Coastguard Worker        #  3D tensor
12826*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(
12827*da0073e9SAndroid Build Coastguard Worker            [[[1, 2],
12828*da0073e9SAndroid Build Coastguard Worker              [5, 6],
12829*da0073e9SAndroid Build Coastguard Worker              [9, 10],
12830*da0073e9SAndroid Build Coastguard Worker              [13, 14],
12831*da0073e9SAndroid Build Coastguard Worker              ]], device=device
12832*da0073e9SAndroid Build Coastguard Worker        )
12833*da0073e9SAndroid Build Coastguard Worker        y_ref = torch.tensor(
12834*da0073e9SAndroid Build Coastguard Worker            [[[1, 2],
12835*da0073e9SAndroid Build Coastguard Worker              [9, 10],
12836*da0073e9SAndroid Build Coastguard Worker              [5, 6],
12837*da0073e9SAndroid Build Coastguard Worker              [13, 14],
12838*da0073e9SAndroid Build Coastguard Worker              ]], device=device
12839*da0073e9SAndroid Build Coastguard Worker        )
12840*da0073e9SAndroid Build Coastguard Worker        #  ChannelsFirst
12841*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
12842*da0073e9SAndroid Build Coastguard Worker            y = F.channel_shuffle(x, 2).to(device)
12843*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 0)
12844*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, y_ref)
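        # Added worked example: with 4 channels and groups=2, channel_shuffle
        # reshapes the channel dim to (2, 2), transposes, and flattens, so the
        # channel order [0, 1, 2, 3] becomes [0, 2, 1, 3] -- which is why rows
        # [5, 6] and [9, 10] swap places in y_ref above.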
12845*da0073e9SAndroid Build Coastguard Worker        #  ChannelsLast not supported for 3dim
12846*da0073e9SAndroid Build Coastguard Worker
12847*da0073e9SAndroid Build Coastguard Worker        #  4D tensor
12848*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(
12849*da0073e9SAndroid Build Coastguard Worker            [[[[1, 2],
12850*da0073e9SAndroid Build Coastguard Worker               [3, 4]],
12851*da0073e9SAndroid Build Coastguard Worker              [[5, 6],
12852*da0073e9SAndroid Build Coastguard Worker               [7, 8]],
12853*da0073e9SAndroid Build Coastguard Worker              [[9, 10],
12854*da0073e9SAndroid Build Coastguard Worker               [11, 12]],
12855*da0073e9SAndroid Build Coastguard Worker              [[13, 14],
12856*da0073e9SAndroid Build Coastguard Worker               [15, 16]],
12857*da0073e9SAndroid Build Coastguard Worker              ]], device=device
12858*da0073e9SAndroid Build Coastguard Worker        )
12859*da0073e9SAndroid Build Coastguard Worker        y_ref = torch.tensor(
12860*da0073e9SAndroid Build Coastguard Worker            [[[[1, 2],
12861*da0073e9SAndroid Build Coastguard Worker               [3, 4]],
12862*da0073e9SAndroid Build Coastguard Worker              [[9, 10],
12863*da0073e9SAndroid Build Coastguard Worker               [11, 12]],
12864*da0073e9SAndroid Build Coastguard Worker              [[5, 6],
12865*da0073e9SAndroid Build Coastguard Worker               [7, 8]],
12866*da0073e9SAndroid Build Coastguard Worker              [[13, 14],
12867*da0073e9SAndroid Build Coastguard Worker               [15, 16]],
12868*da0073e9SAndroid Build Coastguard Worker              ]], device=device
12869*da0073e9SAndroid Build Coastguard Worker        )
12870*da0073e9SAndroid Build Coastguard Worker        #  ChannelsFirst NCHW
12871*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
12872*da0073e9SAndroid Build Coastguard Worker            y = F.channel_shuffle(x, 2).to(device)
12873*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 0)
12874*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, y_ref)
12875*da0073e9SAndroid Build Coastguard Worker        #  ChannelsLast NHWC
12876*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
12877*da0073e9SAndroid Build Coastguard Worker            y = F.channel_shuffle(x.contiguous(memory_format=torch.channels_last), 2).to(device)
12878*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 0)
12879*da0073e9SAndroid Build Coastguard Worker        y = y.contiguous(memory_format=torch.contiguous_format)
12880*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, y_ref)
12881*da0073e9SAndroid Build Coastguard Worker
12882*da0073e9SAndroid Build Coastguard Worker        #  5D tensor
12883*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(
12884*da0073e9SAndroid Build Coastguard Worker            [[[[[1, 2],
12885*da0073e9SAndroid Build Coastguard Worker               [3, 4]]],
12886*da0073e9SAndroid Build Coastguard Worker              [[[5, 6],
12887*da0073e9SAndroid Build Coastguard Worker               [7, 8]]],
12888*da0073e9SAndroid Build Coastguard Worker              [[[9, 10],
12889*da0073e9SAndroid Build Coastguard Worker               [11, 12]]],
12890*da0073e9SAndroid Build Coastguard Worker              [[[13, 14],
12891*da0073e9SAndroid Build Coastguard Worker               [15, 16]]],
12892*da0073e9SAndroid Build Coastguard Worker              ]], device=device
12893*da0073e9SAndroid Build Coastguard Worker        )
12894*da0073e9SAndroid Build Coastguard Worker        y_ref = torch.tensor(
12895*da0073e9SAndroid Build Coastguard Worker            [[[[[1, 2],
12896*da0073e9SAndroid Build Coastguard Worker               [3, 4]]],
12897*da0073e9SAndroid Build Coastguard Worker              [[[9, 10],
12898*da0073e9SAndroid Build Coastguard Worker               [11, 12]]],
12899*da0073e9SAndroid Build Coastguard Worker              [[[5, 6],
12900*da0073e9SAndroid Build Coastguard Worker               [7, 8]]],
12901*da0073e9SAndroid Build Coastguard Worker              [[[13, 14],
12902*da0073e9SAndroid Build Coastguard Worker               [15, 16]]],
12903*da0073e9SAndroid Build Coastguard Worker              ]], device=device
12904*da0073e9SAndroid Build Coastguard Worker        )
12905*da0073e9SAndroid Build Coastguard Worker        #  ChannelsFirst NCDHW
12906*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
12907*da0073e9SAndroid Build Coastguard Worker            y = F.channel_shuffle(x, 2).to(device)
12908*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 0)
12909*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, y_ref)
12910*da0073e9SAndroid Build Coastguard Worker        #  ChannelsLast NDHWC
12911*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
12912*da0073e9SAndroid Build Coastguard Worker            y = F.channel_shuffle(x.contiguous(memory_format=torch.channels_last_3d), 2).to(device)
12913*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 0)
12914*da0073e9SAndroid Build Coastguard Worker        y = y.contiguous(memory_format=torch.contiguous_format)
12915*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, y_ref)
12916*da0073e9SAndroid Build Coastguard Worker
12917*da0073e9SAndroid Build Coastguard Worker
12918*da0073e9SAndroid Build Coastguard Workerclass TestFunctionalPickle(TestCase):
12919*da0073e9SAndroid Build Coastguard Worker
12920*da0073e9SAndroid Build Coastguard Worker    # issue gh-38137
12921*da0073e9SAndroid Build Coastguard Worker    def test_pickle_softsign(self):
12922*da0073e9SAndroid Build Coastguard Worker        # Make sure it does not throw an exception
12923*da0073e9SAndroid Build Coastguard Worker        s = pickle.dumps(F.softsign)
12924*da0073e9SAndroid Build Coastguard Worker
12925*da0073e9SAndroid Build Coastguard Worker
12926*da0073e9SAndroid Build Coastguard Workerclass TestFusionUtils(TestCase):
12927*da0073e9SAndroid Build Coastguard Worker    def test_fuse_conv_bn_requires_grad(self):
12928*da0073e9SAndroid Build Coastguard Worker        conv = torch.nn.Conv2d(3, 3, 3)
12929*da0073e9SAndroid Build Coastguard Worker        bn = torch.nn.BatchNorm2d(3)
12930*da0073e9SAndroid Build Coastguard Worker        cases = itertools.product([True, False], [True, False])
12931*da0073e9SAndroid Build Coastguard Worker        for w_rg, b_rg in cases:
12932*da0073e9SAndroid Build Coastguard Worker            conv.weight.requires_grad = w_rg
12933*da0073e9SAndroid Build Coastguard Worker            conv.bias.requires_grad = b_rg
12934*da0073e9SAndroid Build Coastguard Worker            weight, bias = \
12935*da0073e9SAndroid Build Coastguard Worker                fuse_conv_bn_weights(conv.weight, conv.bias,
12936*da0073e9SAndroid Build Coastguard Worker                                     bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias)
12937*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(weight.requires_grad, w_rg)
12938*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(bias.requires_grad, b_rg)
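        # Added note on the standard BN-folding algebra these helpers implement:
        # with s = bn.weight / sqrt(bn.running_var + bn.eps), the fused weights
        # are W_fused = W * s (broadcast over output channels) and
        # b_fused = (b - bn.running_mean) * s + bn.bias; requires_grad is simply
        # propagated from the original conv/linear parameters, which is what
        # this test and the linear variant below assert.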
12939*da0073e9SAndroid Build Coastguard Worker
12940*da0073e9SAndroid Build Coastguard Worker    def test_fuse_linear_bn_requires_grad(self):
12941*da0073e9SAndroid Build Coastguard Worker        linear = torch.nn.Linear(3, 3)
12942*da0073e9SAndroid Build Coastguard Worker        bn = torch.nn.BatchNorm1d(3)
12943*da0073e9SAndroid Build Coastguard Worker        cases = itertools.product([True, False], [True, False])
12944*da0073e9SAndroid Build Coastguard Worker        for w_rg, b_rg in cases:
12945*da0073e9SAndroid Build Coastguard Worker            linear.weight.requires_grad = w_rg
12946*da0073e9SAndroid Build Coastguard Worker            linear.bias.requires_grad = b_rg
12947*da0073e9SAndroid Build Coastguard Worker            weight, bias = \
12948*da0073e9SAndroid Build Coastguard Worker                fuse_linear_bn_weights(linear.weight, linear.bias,
12949*da0073e9SAndroid Build Coastguard Worker                                       bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias)
12950*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(weight.requires_grad, w_rg)
12951*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(bias.requires_grad, b_rg)
12952*da0073e9SAndroid Build Coastguard Worker
12953*da0073e9SAndroid Build Coastguard Workerclass TestUtils(TestCase):
12954*da0073e9SAndroid Build Coastguard Worker    def test_consume_prefix_in_state_dict_if_present(self):
12955*da0073e9SAndroid Build Coastguard Worker        class Block(nn.Module):
12956*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
12957*da0073e9SAndroid Build Coastguard Worker                super().__init__()
12958*da0073e9SAndroid Build Coastguard Worker                self.conv1 = nn.Conv2d(3, 3, 3, bias=True)
12959*da0073e9SAndroid Build Coastguard Worker                self.conv2 = nn.Conv2d(3, 3, 3, bias=False)
12960*da0073e9SAndroid Build Coastguard Worker
12961*da0073e9SAndroid Build Coastguard Worker        class Net(nn.Module):
12962*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
12963*da0073e9SAndroid Build Coastguard Worker                super().__init__()
12964*da0073e9SAndroid Build Coastguard Worker                self.linear1 = nn.Linear(5, 5)
12965*da0073e9SAndroid Build Coastguard Worker                self.linear2 = nn.Linear(5, 5)
12966*da0073e9SAndroid Build Coastguard Worker                self.bn = nn.BatchNorm2d(2)
12967*da0073e9SAndroid Build Coastguard Worker                self.block = Block()
12968*da0073e9SAndroid Build Coastguard Worker
12969*da0073e9SAndroid Build Coastguard Worker        # 0. Case non-DDP model empty state_dict
12970*da0073e9SAndroid Build Coastguard Worker        net = nn.Module()
12971*da0073e9SAndroid Build Coastguard Worker        state_dict = net.state_dict()
12972*da0073e9SAndroid Build Coastguard Worker        nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, 'module.')
12973*da0073e9SAndroid Build Coastguard Worker        # check they are the same preserving order
12974*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(state_dict.keys()), list(net.state_dict().keys()))
12975*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(state_dict._metadata.keys()), list(net.state_dict()._metadata.keys()))
12976*da0073e9SAndroid Build Coastguard Worker
12977*da0073e9SAndroid Build Coastguard Worker        # 1. Case non-DDP model test example state_dict
12978*da0073e9SAndroid Build Coastguard Worker        net = Net()
12979*da0073e9SAndroid Build Coastguard Worker        state_dict = net.state_dict()
12980*da0073e9SAndroid Build Coastguard Worker        nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, 'module.')
12981*da0073e9SAndroid Build Coastguard Worker        # Check they are the same preserving order
12982*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(state_dict.keys()), list(net.state_dict().keys()))
12983*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(state_dict._metadata.keys()), list(net.state_dict()._metadata.keys()))
12984*da0073e9SAndroid Build Coastguard Worker
12985*da0073e9SAndroid Build Coastguard Worker        # 2. Case DDP model test example state_dict
12986*da0073e9SAndroid Build Coastguard Worker        state_dict = net.state_dict()
12987*da0073e9SAndroid Build Coastguard Worker        metadata = state_dict._metadata
12988*da0073e9SAndroid Build Coastguard Worker        ddp_state_dict = OrderedDict((f'module.{k}', v) for k, v in state_dict.items())
12989*da0073e9SAndroid Build Coastguard Worker        ddp_state_dict._metadata = OrderedDict({'': metadata['']})
12990*da0073e9SAndroid Build Coastguard Worker        ddp_state_dict._metadata.update(('module' if k == '' else f'module.{k}', v) for k, v in metadata.items())
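        # Added note: DistributedDataParallel wraps the model as self.module, so
        # every saved key gains a 'module.' prefix; the lines above rebuild such
        # a checkpoint by hand, and consume_prefix_in_state_dict_if_present
        # strips the prefix (from keys and _metadata alike) so the dict loads
        # into the unwrapped Net.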
12991*da0073e9SAndroid Build Coastguard Worker        nn.modules.utils.consume_prefix_in_state_dict_if_present(ddp_state_dict, 'module.')
12992*da0073e9SAndroid Build Coastguard Worker        # Check they are the same preserving order
12993*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(state_dict.keys()), list(ddp_state_dict.keys()))
12994*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(state_dict._metadata.keys()), list(ddp_state_dict._metadata.keys()))
12995*da0073e9SAndroid Build Coastguard Worker
12996*da0073e9SAndroid Build Coastguard Worker
12997*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestNNDeviceType, globals(), allow_mps=True)
12998*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestNN)
12999*da0073e9SAndroid Build Coastguard Worker
13000*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__':
13001*da0073e9SAndroid Build Coastguard Worker    TestCase._default_dtype_check_enabled = True
13002*da0073e9SAndroid Build Coastguard Worker    run_tests()
13003