xref: /aosp_15_r20/external/pytorch/test/test_modules.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: nn"]
2
3from itertools import chain, product
4from inspect import signature, isgenerator
5from copy import deepcopy
6import tempfile
7from operator import methodcaller
8
9import torch
10
11from torch._subclasses.meta_utils import assert_metadata_eq
12from torch.testing._internal.common_cuda import with_tf32_off
13from torch.testing._internal.common_device_type import (
14    instantiate_device_type_tests, onlyCPU, onlyCUDA, toleranceOverride, tol, skipMeta)
15from torch.testing._internal.common_modules import module_db, modules, ModuleErrorEnum, TrainEvalMode
16from torch.testing._internal.common_utils import (
17    TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck,
18    gradgradcheck, parametrize, wrapSwapTensorsTest)
19from unittest.mock import patch, call
20
21
22class TestModule(TestCase):
23    _do_cuda_memory_leak_check = True
24    _do_cuda_non_default_stream = True
25    precision = 1e-5
26    rel_tol = 1e-5
27
28    def _assert_module_parameters_and_buffer_are(self, module, device, dtype):
29        # Check device placement and dtype for created parameters and buffers.
30        # Only verify floating point dtypes since that's what the kwarg or methods
31        # such as `float()` applies to.
32        if not isinstance(device, torch.device):
33            device = torch.device(device)
34
35        def _check_module(items, name, device=device, dtype=dtype):
36            for item_name, item in items:
37                self.assertEqual(
38                    item.device, device,
39                    f'{name} {item_name} is on device {item.device} instead of the expected device {device}')
40                if item.dtype.is_floating_point:
41                    self.assertEqual(
42                        item.dtype, dtype,
43                        f'{name} {item_name} is of dtype {item.dtype} instead of the expected dtype {dtype}')
44        _check_module(module.named_parameters(), "Parameter")
45        _check_module(module.named_buffers(), "Buffer")
46
47    @modules(module_db)
48    def test_forward(self, device, dtype, module_info, training):
49        module_cls = module_info.module_cls
50        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
51                                                       requires_grad=False, training=training)
52        dtype_to_method_caller = {
53            torch.float32: methodcaller("float"),
54            torch.float64: methodcaller("double"),
55        }
56        for module_input in module_inputs:
57            if module_input.forward_input is None:
58                continue
59
60            with freeze_rng_state():
61                # === Instantiate the module. ===
62                args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
63                m = module_cls(*args, **kwargs)
64                m.to(device).to(dtype)
65                m.train(training)
66
67                # === Do forward pass. ===
68                args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
69                outputs = m(*args, **kwargs)
70
71                # === Compare outputs to a reference if one is specified. ===
72                # TODO: Handle precision
73                reference_fn = module_input.reference_fn
74                if reference_fn is not None:
75                    ref_outputs = reference_fn(m, *args, **kwargs)
76                    self.assertEqual(outputs, ref_outputs)
77
78                # === Use the method call and verify the parameters and buffers ===
79                if dtype in dtype_to_method_caller:
80                    dtype_to_method_caller[dtype](m)
81                    m(*args, **kwargs)
82                    self._assert_module_parameters_and_buffer_are(m, device, dtype)
83
    # Tests passing factory kwargs (e.g. device / dtype) during module instantiation.
    # They should be applied to any created parameters and buffers.
    @modules(module_db)
    def test_factory_kwargs(self, device, dtype, module_info, training):
        """Verify that constructor-time factory kwargs (device / dtype) are
        honored by any parameters and buffers the module itself creates.

        Modules whose parameters/buffers all come from tensors the caller
        passed in are skipped, since factory kwargs don't apply to them.
        """
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=False, training=training)
        for module_input in module_inputs:
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

            # Check if this module creates parameters or registers buffers.
            # The mock magic here passes through to the real Parameter / register_buffer
            # logic and is only used to check call inputs.
            module_creates_params_or_buffers = False
            parameter_new = mock_wrapper(torch.nn.Parameter.__new__)
            with patch.object(torch.nn.Parameter, '__new__', parameter_new):
                register_buffer = mock_wrapper(torch.nn.Module.register_buffer)
                with patch.object(torch.nn.Module, 'register_buffer', register_buffer):
                    m = module_cls(*args, **kwargs)
                    m.train(training)

                    # Check if a parameter or buffer was created with a tensor not passed to the constructor.
                    # A call that only involves tensors the caller already supplied
                    # does not count as the module "creating" anything.
                    constructor_tensors = get_tensors_from(args, kwargs)
                    for mock in [parameter_new.mock, register_buffer.mock]:
                        for call_args, call_kwargs in mock.call_args_list:
                            call_tensors = get_tensors_from(call_args, call_kwargs)
                            if len(call_tensors) > 0 and not constructor_tensors.intersection(call_tensors):
                                module_creates_params_or_buffers = True
                                break

            if not module_creates_params_or_buffers:
                continue

            # Instantiate module with the factory kwargs.
            kwargs.update({
                'device': device,
                'dtype': dtype,
            })

            if issubclass(module_info.module_cls, torch.nn.modules.lazy.LazyModuleMixin):
                # Ensure device and dtype are passed to all UninitializedParameters and UninitializedBuffers.
                # For lazy modules, materialization happens later, so instead of
                # inspecting the (uninitialized) params we check the constructor calls.
                uninit_param_new = mock_wrapper(torch.nn.UninitializedParameter.__new__)
                with patch.object(torch.nn.UninitializedParameter, '__new__', uninit_param_new):
                    uninit_buffer_new = mock_wrapper(torch.nn.UninitializedBuffer.__new__)
                    with patch.object(torch.nn.UninitializedBuffer, '__new__', uninit_buffer_new):
                        m = module_cls(*args, **kwargs)
                        m.train(training)
                        uninit_param_new.mock.assert_has_calls(
                            [call(device=device, dtype=dtype) for _ in uninit_param_new.mock.mock_calls])
                        uninit_buffer_new.mock.assert_has_calls(
                            [call(device=device, dtype=dtype) for _ in uninit_buffer_new.mock.mock_calls])
            else:
                # Check device placement and dtype for created parameters and buffers.
                # Only verify floating point dtypes since that's what the kwarg applies to.
                m = module_cls(*args, **kwargs)
                m.train(training)
                self._assert_module_parameters_and_buffer_are(m, device, dtype)
141
142    @onlyCUDA
143    @modules(module_db)
144    def test_multiple_device_transfer(self, device, dtype, module_info, training):
145        module_cls = module_info.module_cls
146        module_inputs_device = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
147                                                              requires_grad=False, training=training)
148        module_inputs_cpu = module_info.module_inputs_func(module_info, device="cpu", dtype=dtype,
149                                                           requires_grad=False, training=training)
150        for module_input_device, module_input_cpu in zip(module_inputs_device, module_inputs_cpu):
151            if module_input_device.forward_input is None:
152                continue
153
154            with freeze_rng_state():
155                # === Instantiate the module. ===
156                args, kwargs = module_input_device.constructor_input.args, module_input_device.constructor_input.kwargs
157                m = module_cls(*args, **kwargs)
158                m.to(device).to(dtype)
159                m.train(training)
160
161                # === Do forward pass on GPU ===
162                input_device_args = module_input_device.forward_input.args
163                input_device_kwargs = module_input_device.forward_input.kwargs
164                m(*input_device_args, **input_device_kwargs)
165                self._assert_module_parameters_and_buffer_are(m, device, dtype)
166
167                # === Move to CPU ===
168                input_cpu_args = module_input_cpu.forward_input.args
169                input_cpu_kwargs = module_input_cpu.forward_input.kwargs
170                m.cpu()
171                m(*input_cpu_args, **input_cpu_kwargs)
172                self._assert_module_parameters_and_buffer_are(m, "cpu", dtype)
173
174                # === Move back to GPU and forward pass ===
175                m.cuda()
176                m(*input_device_args, **input_device_kwargs)
177                self._assert_module_parameters_and_buffer_are(m, device, dtype)
178
179                if torch.cuda.device_count() >= 2:
180                    # === test cross-GPU transfer works
181                    def _to_device1(objs):
182                        if isinstance(objs, (tuple, list)):
183                            return type(objs)(_to_device1(item) for item in objs)
184                        elif isinstance(objs, dict):
185                            return {name: _to_device1(item) for name, item in objs.items()}
186                        elif isinstance(objs, torch.Tensor):
187                            return objs.cuda(1)
188                        else:
189                            return objs
190                    input_device_1_args = _to_device1(input_device_args)
191                    input_device_1_kwargs = _to_device1(input_device_kwargs)
192
193                    m.cuda(1)
194                    with torch.cuda.device(1):
195                        m(*input_device_1_args, **input_device_1_kwargs)
196                    self._assert_module_parameters_and_buffer_are(m, torch.device("cuda:1"), dtype)
197
198    @modules(module_db)
199    def test_repr(self, device, dtype, module_info, training):
200        # Test module can be represented with repr and str without errors.
201        module_cls = module_info.module_cls
202        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
203                                                       requires_grad=False, training=training)
204        for module_input in module_inputs:
205            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
206            m = module_cls(*args, **kwargs)
207            m.to(device).to(dtype)
208            m.train(training)
209
210            # Check that these methods do not raise errors
211            m.__repr__()
212            str(m)
213
214    @modules(module_db)
215    def test_save_load(self, device, dtype, module_info, training):
216        # Test that module can be pickled and unpickled.
217        module_cls = module_info.module_cls
218        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
219                                                       requires_grad=False, training=training)
220        for module_input in module_inputs:
221            if module_input.forward_input is None:
222                continue
223
224            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
225
226            with freeze_rng_state():
227                # === Instantiate the module. ===
228                args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
229                m = module_cls(*args, **kwargs)
230                m.to(device).to(dtype)
231                m.train(training)
232                sd = m.state_dict()
233
234                # === Do forward pass. ===
235                args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
236                output = m(*args, **kwargs)
237
238                # === Check saved/loaded module gives the same output. ===
239                with tempfile.TemporaryFile() as f:
240                    torch.save(m, f)
241                    f.seek(0)
242                    # weights_only=False as this is legacy code that saves the model
243                    m_copy = torch.load(f, weights_only=False)
244                    output_from_copy = m_copy(*args, **kwargs)
245                    self.assertEqual(output, output_from_copy)
246
247                # === Check saved/loaded state_dict are the same (including weights_only load). ===
248                with tempfile.TemporaryFile() as f:
249                    torch.save(sd, f)
250                    f.seek(0)
251                    sd_copy = torch.load(f)
252                    self.assertEqual(sd_copy, sd)
253                    del sd_copy
254                    f.seek(0)
255                    sd_copy_wo = torch.load(f, weights_only=True)
256                    self.assertEqual(sd_copy_wo, sd)
257
258    @skipMeta
259    @modules([module_info for module_info in module_db
260              if 'inplace' in signature(module_info.module_cls).parameters])
261    def test_check_inplace(self, device, dtype, module_info, training):
262        # Check if the inplace variant of the module gives the same result as the out of place
263        # variant.
264        module_cls = module_info.module_cls
265        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
266                                                       requires_grad=True, training=training)
267        for module_input in module_inputs:
268            if module_input.forward_input is None:
269                continue
270
271            # === Instantiate the module. ===
272            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
273            m_op = module_cls(*args, **kwargs, inplace=False)
274            m_op.to(device).to(dtype)
275            m_op.train(training)
276            m_inplace = module_cls(*args, **kwargs, inplace=True)
277            m_inplace.to(device).to(dtype)
278            m_inplace.train(training)
279
280            # === Inplace modules only supports inplace operations on the first argument ===
281            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
282
283            # ===  Do not allow the first input to be in input_kwargs ===
284            forward_sig = signature(m_op).parameters
285            self.assertGreaterEqual(len(forward_sig), 1)
286            first_param_name = next(iter(forward_sig.items()))
287            self.assertNotIn(first_param_name, input_kwargs)
288
289            # === Out of place operation does not write to original tensor ===
290            self.assertGreaterEqual(len(input_args), 1)
291            input_version = input_args[0]._version
292            with freeze_rng_state():
293                output_op = m_op(*input_args, **input_kwargs)
294            self.assertEqual(input_args[0]._version, input_version)
295
296            # === Check that the inplace operation gives the same result ===
297            input_arg_copy = deepcopy(input_args)
298            input_arg_clone = tuple(i.clone() for i in input_arg_copy)
299            input_clone_version = input_arg_clone[0]._version
300            with freeze_rng_state():
301                output_ip = m_inplace(*input_arg_clone, **input_kwargs)
302            self.assertGreater(input_arg_clone[0]._version, input_clone_version)
303            self.assertEqual(output_op, output_ip)
304
305            # === Check that the gradients are the same ===
306            grad = output_op.data.clone().normal_()
307            output_op.backward(grad)
308            output_ip.backward(grad)
309            self.assertEqual(input_args[0].grad, input_arg_copy[0].grad)
310
311    def _traverse_obj(self, obj, func):
312        if isinstance(obj, (tuple, list)):
313            return type(obj)(self._traverse_obj(o, func) for o in obj)
314        elif isgenerator(obj):
315            return tuple(self._traverse_obj(o, func) for o in obj)
316        elif isinstance(obj, dict):
317            return {name: self._traverse_obj(o, func) for name, o in obj.items()}
318        elif isinstance(obj, (torch.Tensor, torch.nn.Parameter)):
319            return func(obj)
320        else:
321            return obj
322
323    def _retain_grad(self, obj):
324        # gradients needs to be retained to check for grad. This is useful when
325        # non-leafs are present in the graph.
326        def inner_retain_grad(obj):
327            if obj.requires_grad:
328                obj.retain_grad()
329        self._traverse_obj(obj, inner_retain_grad)
330
331    def _get_grads(self, obj):
332        def inner_get_grad(obj):
333            if obj.requires_grad:
334                return obj.grad
335        return self._traverse_obj(obj, inner_get_grad)
336
337    def _zero_grad(self, obj):
338        def inner_zero_grad(obj):
339            if obj.grad is not None:
340                obj.grad = None
341        self._traverse_obj(obj, inner_zero_grad)
342
    @modules(module_db)
    def test_non_contiguous_tensors(self, device, dtype, module_info, training):
        # Check modules work with non-contiguous tensors
        #
        # Strategy: run a reference forward/backward with the original
        # (contiguous) inputs, then re-run every combination of
        # {contiguous, non-contiguous} inputs x {contiguous, non-contiguous}
        # grad outputs and check outputs and all gradients match the reference.

        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=True, training=training)

        def _make_non_contiguous(obj):
            # Returns a copy of `obj` where every non-scalar tensor is replaced
            # by a non-contiguous view with the same values and requires_grad.
            def inner_make_non_contiguous(obj):
                # Scalar tensors can not be made non-contiguous
                if not isinstance(obj, torch.Tensor) or obj.dim() == 0:
                    return obj

                # Duplicate each element along the last dim, then stride over
                # every other element: same values, non-contiguous layout.
                out = torch.repeat_interleave(obj, 2, dim=-1)
                out = out[..., ::2].detach()
                out.requires_grad = obj.requires_grad
                return out
            return self._traverse_obj(obj, inner_make_non_contiguous)

        def _can_be_noncontiguous(obj):
            # True if `obj` (or any nested element) is a non-scalar tensor.
            if isinstance(obj, (tuple, list)):
                return any(_can_be_noncontiguous(o) for o in obj)
            elif isinstance(obj, dict):
                return any(_can_be_noncontiguous(o) for o in obj.values())
            # scalar tensors can not be non-contiguous
            return isinstance(obj, torch.Tensor) and obj.dim() != 0

        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
            if not (_can_be_noncontiguous(input_args) or _can_be_noncontiguous(input_kwargs)):
                continue

            # === Instantiate the module. ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)
            m.train(training)

            self._retain_grad((input_args, input_kwargs))

            # === Forward with default input
            with freeze_rng_state():
                default_output = m(*input_args, **input_kwargs)
                if isinstance(default_output, torch.Tensor):
                    grad_output = default_output.clone().detach_().normal_()
                    default_output.backward(grad_output, retain_graph=True)
                else:
                    # Non-tensor output: build matching random grad outputs and
                    # backward each grad-requiring leaf individually.
                    grad_output = tuple(self._traverse_obj(o, lambda o: o.clone().detach_().normal_() if o.requires_grad else None)
                                        for o in default_output)
                    flattened_default_output = torch.utils._pytree.tree_leaves(default_output)
                    flattened_grad_output = torch.utils._pytree.tree_leaves(grad_output)
                    for o, g_o in zip(flattened_default_output, flattened_grad_output):
                        if (o.requires_grad):
                            o.backward(g_o, retain_graph=True)

            # Snapshot reference gradients (inputs and parameters) for comparison.
            default_input_args_grad, default_input_kwargs_grad = deepcopy(self._get_grads((input_args, input_kwargs)))
            default_param_grad = deepcopy([p.grad for p in m.parameters()])

            # === Construct non-contiguous tensors ===
            nc_input_args, nc_input_kwargs = _make_non_contiguous((input_args, input_kwargs))
            nc_grad_output = _make_non_contiguous(grad_output)

            # === Compare results with non-contiguous and contiguous tensors ===
            inputs = [(input_args, input_kwargs), (nc_input_args, nc_input_kwargs)]
            grads = [grad_output, nc_grad_output]

            for (in_args, in_kwargs), g_out in product(inputs, grads):
                g_out_copy = deepcopy(g_out)
                self._zero_grad((in_args, in_kwargs))
                self._zero_grad(m.parameters())

                with freeze_rng_state():
                    out = m(*in_args, **in_kwargs)
                    if isinstance(out, torch.Tensor):
                        out.backward(g_out_copy, retain_graph=True)
                    else:
                        flattened_out = torch.utils._pytree.tree_leaves(out)
                        flattened_g_out_copy = torch.utils._pytree.tree_leaves(g_out_copy)
                        for o, g_o in zip(flattened_out, flattened_g_out_copy):
                            if o.requires_grad:
                                o.backward(g_o, retain_graph=True)

                # Outputs and gradients must match the contiguous reference run.
                input_args_grad, input_kwargs_grad = self._get_grads((in_args, in_kwargs))
                self.assertEqual(out, default_output)
                self.assertEqual(input_args_grad, default_input_args_grad, atol=1e-4, rtol=0)
                self.assertEqual(input_kwargs_grad, default_input_kwargs_grad, atol=1e-4, rtol=0)

                param_grad = [p.grad for p in m.parameters()]
                self.assertEqual(param_grad, default_param_grad)
436
    def _test_gradients_helper(self, device, dtype, module_info, training, check):
        # Check gradients
        #
        # Runs `check` (gradcheck or gradgradcheck) over the inputs, parameters
        # and tensor kwargs of each sample: first the total derivative w.r.t.
        # everything at once, then partial derivatives w.r.t. each parameter
        # and each tensor kwarg individually.
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=True, training=training)
        # === Set nondet tol for gradcheck to user-defined value if on CUDA and cudNN is enabled
        gradcheck_nondet_tol = 0.0
        if (torch.device(device).type == 'cuda' and torch.backends.cudnn.enabled):
            gradcheck_nondet_tol = module_info.gradcheck_nondet_tol

        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            # === Instantiate the module. ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)
            m.train(training)

            params = tuple(m.parameters())

            # === Lazy modules need to see an input to initialize params before gradcheck is run. ===
            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
            if issubclass(module_info.module_cls, torch.nn.modules.lazy.LazyModuleMixin):
                with torch.no_grad():
                    m(*input_args, **input_kwargs)

            # === Perform gradient check on the input_args ===
            # Split kwargs into tensor ones (differentiated) and the rest
            # (passed through unchanged).
            other_kwargs = {}
            kwarg_tensors = []
            for name, obj in input_kwargs.items():
                if isinstance(obj, torch.Tensor):
                    kwarg_tensors.append((name, obj))
                else:
                    other_kwargs[name] = obj

            def fn_to_gradcheck(*flat_input_and_params):
                # Rebuild (input_args..., params..., kwarg tensors...) from the
                # flat argument list using the pytree spec captured below.
                # NOTE(review): when kwarg_tensors is empty, `[-0:]` selects the
                # whole sequence, but the zip below then produces no kwargs, so
                # the slice oddity is harmless.
                input_and_params = torch.utils._pytree.tree_unflatten(flat_input_and_params, flat_spec)
                new_input_args = input_and_params[:len(input_args)]
                kwarg_args = input_and_params[-len(kwarg_tensors):]
                new_kwargs = {name: obj for (name, _), obj in zip(kwarg_tensors, kwarg_args)}

                with freeze_rng_state():
                    output = m(*new_input_args, **new_kwargs, **other_kwargs)
                    output_flattened = torch.utils._pytree.tree_leaves(output)
                    return output_flattened

            # check total derivative
            grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
            flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)

            self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))

            # check partial derivatives
            # First disable grad on everything, then re-enable one item at a
            # time so `check` differentiates w.r.t. that item only.
            old_params_requires_grad = [p.requires_grad for p in params]
            for p in params:
                p.requires_grad = False

            old_kwargs_requires_grad = [obj.requires_grad for (_, obj) in kwarg_tensors]
            for (_, obj) in kwarg_tensors:
                obj.requires_grad = False

            for p, old in zip(params, old_params_requires_grad):
                p.requires_grad = old
                grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
                flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
                self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
                p.requires_grad = False

            for (_, obj), old in zip(kwarg_tensors, old_kwargs_requires_grad):
                obj.requires_grad = old
                grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
                flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
                self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
                obj.requires_grad = False
513
    @modules(module_db, allowed_dtypes=[torch.double])
    def test_grad(self, device, dtype, module_info, training):
        """First-order gradient check (gradcheck) over module inputs, params
        and tensor kwargs; double precision only for numerical stability."""
        self._test_gradients_helper(device, dtype, module_info, training, gradcheck)
517
    @modules([m for m in module_db if m.supports_gradgrad],
             allowed_dtypes=[torch.double])
    def test_gradgrad(self, device, dtype, module_info, training):
        """Second-order gradient check (gradgradcheck); only for modules that
        declare gradgrad support, double precision only."""
        self._test_gradients_helper(device, dtype, module_info, training, gradgradcheck)
522
    @onlyCUDA
    @with_tf32_off  # Turn off TF32 to compute at full precision https://github.com/pytorch/pytorch/issues/86798
    @toleranceOverride({torch.float32: tol(5e-2, 0),
                        torch.float64: tol(4e-4, 0)})
    @modules(module_db)
    def test_cpu_gpu_parity(self, device, dtype, module_info, training):
        """Check that forward outputs and backward gradients match between a
        CPU and a GPU instance of the module with identical parameters.
        Backward is repeated several times to exercise graph reuse."""
        # TODO: RNN / GRU / LSTM don't support backwards on eval mode for cuDNN; skip this in a
        # nicer way for eval mode only.
        # See https://github.com/pytorch/pytorch/issues/79161
        rnn_modules = {torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM}
        if (module_info.module_cls in rnn_modules
                and not training
                and 'cuda' in device
                and torch.backends.cudnn.enabled):
            return

        # Test cpu and gpu results are the same
        module_cls = module_info.module_cls
        module_inputs_cpu = module_info.module_inputs_func(module_info, device="cpu", dtype=dtype,
                                                           requires_grad=True, training=training)

        def _to_device(obj):
            # Recursively copy tensors to `device` (as fresh leaves preserving
            # requires_grad); other objects are deep-copied.
            if isinstance(obj, torch.Tensor):
                res = obj.detach().to(device=device)
                res.requires_grad = obj.requires_grad
                return res
            elif isinstance(obj, tuple):
                return tuple(_to_device(o) for o in obj)
            elif isinstance(obj, dict):
                return {key: _to_device(o) for key, o in obj.items()}
            else:
                return deepcopy(obj)

        for module_input in module_inputs_cpu:
            # === Move input from cpu to device ===
            cpu_forward_args = module_input.forward_input.args
            cpu_forward_kwargs = module_input.forward_input.kwargs

            gpu_forward_args, gpu_forward_kwargs = _to_device((cpu_forward_args, cpu_forward_kwargs))

            self._retain_grad((cpu_forward_args, cpu_forward_kwargs, gpu_forward_args, gpu_forward_kwargs))

            # === Construct module on cpu and gpu ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

            cpu_module = module_cls(*args, **kwargs).to(dtype).to("cpu")
            cpu_module.train(training)
            gpu_module = module_cls(*args, **kwargs).to(dtype).to(device)
            gpu_module.train(training)

            # === Lazy modules need to see an input to initialize params ===
            if issubclass(module_cls, torch.nn.modules.lazy.LazyModuleMixin):
                with torch.no_grad():
                    cpu_module(*cpu_forward_args, **cpu_forward_kwargs)
                    gpu_module(*gpu_forward_args, **gpu_forward_kwargs)

            # Sync parameters so both modules start from identical weights.
            for cpu_p, gpu_p in zip(cpu_module.parameters(), gpu_module.parameters()):
                gpu_p.data.copy_(cpu_p)

            # === Compare forward output between cpu and gpu ===
            cpu_outputs = cpu_module(*cpu_forward_args, **cpu_forward_kwargs)
            gpu_outputs = gpu_module(*gpu_forward_args, **gpu_forward_kwargs)

            self.assertEqual(cpu_outputs, gpu_outputs)

            # === Run backwards on CPU and GPU and compare results ===
            def check_backward(cpu_output, gpu_output):
                # Same random grad output on both devices; compare input grads,
                # parameter grads, and kwarg grads.
                cpu_grad_output = cpu_output.clone().normal_()
                gpu_grad_output = cpu_grad_output.type_as(gpu_output)

                cpu_output.backward(cpu_grad_output, retain_graph=True)
                gpu_output.backward(gpu_grad_output, retain_graph=True)

                cpu_grad_input = self._get_grads(cpu_forward_args)
                gpu_grad_input = self._get_grads(gpu_forward_args)
                self.assertEqual(cpu_grad_input, gpu_grad_input)

                for cpu_p, gpu_p in zip(cpu_module.parameters(), gpu_module.parameters()):
                    self.assertEqual(cpu_p.grad, gpu_p.grad)

                cpu_grad_kwarg_input = self._get_grads(cpu_forward_kwargs)
                gpu_grad_kwarg_input = self._get_grads(gpu_forward_kwargs)
                self.assertEqual(cpu_grad_kwarg_input, gpu_grad_kwarg_input)

            # NOTE: grads accumulate across iterations (they are never zeroed
            # here); parity still holds because both devices accumulate alike.
            for _ in range(5):
                if isinstance(cpu_outputs, torch.Tensor):
                    check_backward(cpu_outputs, gpu_outputs)
                else:
                    flatten_cpu_outputs = torch.utils._pytree.tree_leaves(cpu_outputs)
                    flatten_gpu_outputs = torch.utils._pytree.tree_leaves(gpu_outputs)
                    for cpu_output, gpu_output in zip(flatten_cpu_outputs, flatten_gpu_outputs):
                        if cpu_output.requires_grad:
                            check_backward(cpu_output, gpu_output)
616
    @with_tf32_off
    @modules(module_db)
    def test_memory_format(self, device, dtype, module_info, training):
        """Check that a module produces the same forward outputs and gradients
        regardless of the memory format (contiguous / channels_last /
        channels_last_3d) of its inputs and of its own parameters, and that the
        memory format of outputs and grad inputs matches expectations."""
        # sm86/sm80 CUDA devices need looser tolerances for some modules.
        is_sm86or80 = device.startswith("cuda") and (torch.cuda.get_device_capability(0) == (8, 6)
                                                     or torch.cuda.get_device_capability(0) == (8, 0))
        # TODO tighten it to a specific module
        atol, rtol = (3e-3, 7e-3) if is_sm86or80 else (None, None)
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=True, training=training)
        # Whether converting the module itself is expected to change the
        # memory format of its outputs (per the ModuleInfo entry).
        module_memformat_affects_out = module_info.module_memformat_affects_out

        def _get_mem_formats(channels_last=False, channels_last_3d=False):
            # Returns (input memory formats, module memory formats) to test,
            # depending on which channels-last variant the inputs support.
            if channels_last:
                return ([torch.contiguous_format, torch.channels_last],
                        [torch.preserve_format, torch.contiguous_format, torch.channels_last])
            elif channels_last_3d:
                return ([torch.contiguous_format, torch.channels_last_3d],
                        [torch.preserve_format, torch.contiguous_format, torch.channels_last_3d])
            else:
                return ([torch.contiguous_format],
                        [torch.preserve_format, torch.contiguous_format])

        # Check that at least one Tensor input has dim == n
        def _check_dims(obj, n):
            if isinstance(obj, torch.Tensor):
                return obj.dim() == n
            elif isinstance(obj, (tuple, list)):
                return any(_check_dims(o, n) for o in obj)
            else:
                return False

        # Called after _check_dims, when we know that >= 1 tensor can be converted to mem_format
        def _to_mem_format(mem_format, obj):
            def inner_to_mem_format(obj):
                d = obj.dim()
                # channels_last only applies to 4-D tensors and channels_last_3d
                # to 5-D tensors; other ranks are cloned without a format change.
                if ((mem_format == torch.channels_last and d != 4)
                   or (mem_format == torch.channels_last_3d and d != 5)):
                    return obj.clone().detach().requires_grad_(obj.requires_grad)
                return obj.clone().to(memory_format=mem_format).detach().requires_grad_(obj.requires_grad)

            return self._traverse_obj(obj, inner_to_mem_format)

        def _check_out_mem_format(output, input_mem_format, module_mem_format):
            # Assert each output tensor is contiguous in the memory format
            # implied by the (input format, module format) combination.
            def inner_check_out_mem_format(output):
                d = output.dim()
                if (d == 4 and ((input_mem_format == torch.channels_last)
                                or (module_mem_format == torch.channels_last and module_memformat_affects_out))):
                    self.assertTrue(output.numel() == 0 or output.is_contiguous(memory_format=torch.channels_last))
                elif (d == 5 and ((input_mem_format == torch.channels_last_3d)
                                  or (module_mem_format == torch.channels_last_3d and module_memformat_affects_out))):
                    self.assertTrue(output.numel() == 0 or output.is_contiguous(memory_format=torch.channels_last_3d))
                else:
                    self.assertTrue(output.is_contiguous())
            return self._traverse_obj(output, inner_check_out_mem_format)

        def _req_grad(t):
            # True only for tensors that participate in autograd.
            return isinstance(t, torch.Tensor) and t.requires_grad

        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            # channels_last needs a 4-D input; channels_last_3d needs a 5-D input.
            supports_channels_last = _check_dims(module_input.forward_input.args, 4)
            supports_channels_last_3d = _check_dims(module_input.forward_input.args, 5)
            input_mem_formats, module_mem_formats = _get_mem_formats(supports_channels_last, supports_channels_last_3d)

            with freeze_rng_state():
                # === Instantiate the module. ===
                args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

                m = module_cls(*args, **kwargs)
                m.to(device).to(dtype)
                m.train(training)

                # === Get output in (contiguous, contiguous) configuration. ===
                args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                desired_outputs = m(*args, **kwargs)
                # === Do backward pass. ===
                # Compute reference gradients in the all-contiguous configuration;
                # they are compared against every other format combination below.
                ref_diff_outputs = tuple(t for t in torch.utils._pytree.tree_leaves(desired_outputs) if _req_grad(t))
                if training and len(ref_diff_outputs) > 0:
                    params = tuple(p for p in m.parameters())
                    ref_diff_inputs = tuple(
                        t
                        for t in torch.utils._pytree.tree_leaves((args, kwargs, params))
                        if _req_grad(t)
                    )
                    ref_grad_outputs = tuple(
                        torch.rand_like(t)
                        for t in ref_diff_outputs
                    )
                    ref_grad_inputs = torch.autograd.grad(
                        ref_diff_outputs,
                        ref_diff_inputs,
                        grad_outputs=ref_grad_outputs,
                    )

                for input_mem_format in input_mem_formats:
                    # === Change memformat of input. ===
                    d_args = _to_mem_format(input_mem_format, module_input.forward_input.args)
                    d_kwargs = _to_mem_format(input_mem_format, module_input.forward_input.kwargs)

                    # See https://github.com/pytorch/pytorch/issues/107861
                    # When inductor tests are turned on, the setting of requires_grad will be lost
                    for t1, t2 in zip(
                        torch.utils._pytree.tree_leaves(d_args),
                        torch.utils._pytree.tree_leaves(module_input.forward_input.args),
                    ):
                        t1.requires_grad_(t2.requires_grad)
                    for t1, t2 in zip(
                        torch.utils._pytree.tree_leaves(d_kwargs),
                        torch.utils._pytree.tree_leaves(module_input.forward_input.kwargs),
                    ):
                        t1.requires_grad_(t2.requires_grad)

                    module_input.forward_input.args = d_args
                    module_input.forward_input.kwargs = d_kwargs

                    for module_mem_format in module_mem_formats:
                        # === Change memformat of module ===
                        m.to(memory_format=module_mem_format)

                        # === Do forward pass. ===
                        args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                        outputs = m(*args, **kwargs)

                        # === Compare outputs to (contiguous, contiguous) output. ===
                        if input_mem_format != torch.contiguous_format or module_mem_format != torch.contiguous_format:
                            self.assertEqual(outputs, desired_outputs, rtol=rtol, atol=atol)

                        # === Check mem format of output. ===
                        _check_out_mem_format(outputs, input_mem_format, module_mem_format)

                        # === Do backward pass. ===
                        diff_outputs = tuple(t for t in torch.utils._pytree.tree_leaves(outputs) if _req_grad(t))
                        if training and len(diff_outputs) > 0:
                            params = tuple(p for p in m.parameters())
                            diff_inputs = tuple(
                                t
                                for t in torch.utils._pytree.tree_leaves((args, kwargs, params))
                                if _req_grad(t)
                            )
                            # Reuse the reference grad_outputs (copied into tensors
                            # with this configuration's layout) so grads are comparable.
                            grad_outputs = tuple(
                                torch.empty_like(t1).copy_(t2)
                                for (t1, t2) in zip(diff_outputs, ref_grad_outputs)
                            )

                            grad_inputs = torch.autograd.grad(
                                diff_outputs,
                                diff_inputs,
                                grad_outputs=grad_outputs,
                            )

                            if (
                                input_mem_format != torch.contiguous_format
                                or module_mem_format != torch.contiguous_format
                            ):
                                self.assertEqual(
                                    grad_inputs, ref_grad_inputs, rtol=rtol, atol=atol
                                )

                            # === Check mem format of grad_inputs. ===
                            _check_out_mem_format(grad_inputs, input_mem_format, module_mem_format)
780
781    # Test whether train and eval modes differ for each module. Use to verify
782    # that the ModuleInfo entry flag is correct.
783    @modules(module_db, train_eval_mode=TrainEvalMode.train_only)
784    def test_if_train_and_eval_modes_differ(self, device, dtype, module_info, training):
785        module_cls = module_info.module_cls
786        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
787                                                       requires_grad=False, training=training)
788
789        # Run forward inputs through to see if the training flag is accessed during forward.
790        for module_input in module_inputs:
791            if module_input.forward_input is None:
792                continue
793
794            # === Instantiate the module. ===
795            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
796            m = module_cls(*args, **kwargs)
797            m.to(device).to(dtype)
798            m.train(training)
799
800            # Remove training attribute and see if forward still works.
801            delattr(m, 'training')
802
803            # === Do forward pass. ===
804            try:
805                args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
806                m(*args, **kwargs)
807            except AttributeError as e:
808                if "'training'" in str(e):
809                    self.assertTrue(module_info.train_and_eval_differ,
810                                    f"The ModuleInfo entry for {module_info.name} has "
811                                    "train_and_eval_differ=False, but the training mode was found to "
812                                    "affect the forward pass. Consider setting train_and_eval_differ=True "
813                                    "for this ModuleInfo entry.")
814                else:
815                    raise e
816
817
818    @onlyCPU
819    @modules(module_db)
820    def test_device_ctx_init(self, device, dtype, module_info, training):
821        module_cls = module_info.module_cls
822        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
823                                                       requires_grad=False, training=training)
824        with torch.device('meta'):
825            module_inputs_meta = module_info.module_inputs_func(module_info, device=None, dtype=dtype,
826                                                                requires_grad=False, training=training)
827
828        for module_input, module_input_meta in zip(module_inputs, module_inputs_meta):
829            c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
830
831            c_args_meta, c_kwargs_meta = module_input_meta.constructor_input.args, module_input_meta.constructor_input.kwargs
832
833            m_cpu = module_cls(*c_args, **c_kwargs)
834
835            with torch.device('meta'):
836                m = module_cls(*c_args_meta, **c_kwargs_meta)
837
838            for (p_meta, p_cpu) in chain(zip(m.parameters(), m_cpu.parameters()),
839                                         zip(m.buffers(), m_cpu.buffers())):
840                if torch.nn.parameter.is_lazy(p_meta):
841                    continue
842                self.assertTrue(p_meta.is_meta)
843                assert_metadata_eq(self.assertEqual, p_meta, p_cpu)
844
845
846    @modules([module for module in module_db if module.module_error_inputs_func is not None])
847    def test_errors(self, device, dtype, module_info, training):
848        module_cls = module_info.module_cls
849        error_inputs = module_info.module_error_inputs_func(module_info, device=device, dtype=dtype,
850                                                            requires_grad=False, training=training)
851        for error_input in error_inputs:
852            module_input = error_input.module_error_input
853            c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
854            if error_input.error_on == ModuleErrorEnum.CONSTRUCTION_ERROR:
855                with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
856                    m = module_cls(*c_args, **c_kwargs)
857            elif error_input.error_on == ModuleErrorEnum.FORWARD_ERROR:
858                m = module_cls(*c_args, **c_kwargs)
859                fw_args, fw_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
860                with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
861                    m(*fw_args, **fw_kwargs)
862            else:
863                raise NotImplementedError(f"Unknown error type {error_input.error_on}")
864
865    # Only run this test for float32 because the test loops over all the dtypes
866    @modules([module for module in module_db if not module.is_lazy], allowed_dtypes=[torch.float32])
867    @parametrize('swap', [True, False])
868    @parametrize('set_grad', [True, False])
869    @wrapSwapTensorsTest()
870    def test_to(self, device, dtype, module_info, training, swap, set_grad):
871        module_cls = module_info.module_cls
872        devices = ['cpu']
873        if torch.cuda.is_available():
874            devices += ['cuda']
875        dtypes = module_info.dtypes
876        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
877                                                       requires_grad=False, training=training)
878        torch.__future__.set_swap_module_params_on_conversion(swap)
879
880        for module_input in module_inputs:
881            c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
882            args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
883
884            m = module_cls(*c_args, **c_kwargs)
885
886            # Avoid using `module.to()` when constructing module since that is the method we are testing
887            def _to(m, set_grad=False):
888                for c in m.children():
889                    _to(c, set_grad=set_grad)
890                for n, p in m.named_parameters(recurse=False):
891                    new_p = torch.nn.Parameter(p.detach().clone().to(device, dtype))
892                    setattr(m, n, new_p)
893                    if set_grad:
894                        new_p.grad = torch.randn_like(new_p)
895                for n, b in m.named_buffers(recurse=False):
896                    new_b = b.detach().clone().to(device, dtype)
897                    setattr(m, n, new_b)
898            _to(m, set_grad=set_grad)
899
900            # Check .to() can be run after forward and backward with swap
901            has_params = len(list(m.parameters())) > 0
902            if swap and not set_grad and has_params:
903                out = m(*args, **kwargs)
904                if isinstance(out, tuple):
905                    out = out[0]
906                out.sum().backward()
907                m.to(dtype=torch.half)
908                # reset
909                m.to(dtype=torch.float32)
910
911            prev_device, prev_dtype = device, dtype
912            for device_, dtype_ in product(devices, dtypes):
913                # if device/dtype do not change, grad.to(device, dtype) is a no-op so
914                # swapping will not change ._cdata
915                # parameters will be wrapped in an nn.Parameter before swapping
916                # which will cause the ._cdata to change
917                g_no_swap = device_ == prev_device and dtype_ == prev_dtype
918                prev_prev_device, prev_prev_dtype = prev_device, prev_dtype
919                prev_device, prev_dtype = device_, dtype_
920
921                p_ids_before = [id(p) for p in m.parameters()]
922                p_cdatas_before = [p._cdata for p in m.parameters()]
923                if set_grad:
924                    g_ids_before = [id(p.grad) for p in m.parameters()]
925                    g_cdatas_before = [p.grad._cdata for p in m.parameters()]
926
927                m.to(device=device_, dtype=dtype_)
928
929                self.assertTrue(all(isinstance(p, torch.nn.Parameter) for p in m.parameters()))
930                self.assertTrue(all(p.device.type == device_ for p in m.parameters()))
931                self.assertTrue(all(p.dtype == dtype_ for p in m.parameters()))
932                p_ids_after = [id(p) for p in m.parameters()]
933                p_cdatas_after = [p._cdata for p in m.parameters()]
934
935                if set_grad:
936                    self.assertTrue(all(p.grad.device.type == device_ for p in m.parameters()))
937                    self.assertTrue(all(p.grad.dtype == dtype_ for p in m.parameters()))
938                    g_ids_after = [id(p.grad) for p in m.parameters()]
939                    g_cdatas_after = [p.grad._cdata for p in m.parameters()]
940
941                if swap:
942                    # id same, ._cdata differs --> swapped cdata of THPVariable
943                    self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after)))
944                    self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
945                    if set_grad:
946                        self.assertTrue(
947                            all(a == b if g_no_swap else a != b for a, b in zip(g_cdatas_before, g_cdatas_after)))
948                else:
949                    # id and _cdata remain the same --> .data setting
950                    self.assertTrue(all(a == b for a, b in zip(p_cdatas_before, p_cdatas_after)))
951                    self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after)))
952                    if set_grad:
953                        self.assertTrue(all(a == b for a, b in zip(g_cdatas_before, g_cdatas_after)))
954                        self.assertTrue(all(a == b for a, b in zip(g_ids_before, g_ids_after)))
955
956    @modules([module for module in module_db if not module.is_lazy], allowed_dtypes=[torch.float32])
957    @parametrize('swap', [True, False])
958    @wrapSwapTensorsTest()
959    def test_to_empty(self, device, dtype, module_info, swap, training):
960        module_cls = module_info.module_cls
961
962        with torch.device("meta"):
963            module_inputs = module_info.module_inputs_func(module_info, device=None, dtype=dtype,
964                                                           requires_grad=False, training=training)
965
966        torch.__future__.set_swap_module_params_on_conversion(swap)
967        device_ = torch.device(device)
968
969        for module_input in module_inputs:
970            c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
971
972            with torch.device("meta"):
973                m = module_cls(*c_args, **c_kwargs)
974
975            p_ids_before = [id(p) for p in m.parameters()]
976            p_cdatas_before = [p._cdata for p in m.parameters()]
977            m.to_empty(device=device_)
978
979            self.assertTrue(all(isinstance(p, torch.nn.Parameter) for p in m.parameters()))
980            self.assertTrue(all(p.device == device_ for p in m.parameters()))
981            self.assertTrue(all(p.dtype == dtype for p in m.parameters()))
982            p_ids_after = [id(p) for p in m.parameters()]
983            p_cdatas_after = [p._cdata for p in m.parameters()]
984
985            if swap:
986                # id same, ._cdata differs --> swapped cdata of THPVariable
987                self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after)))
988                self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
989            else:
990                # id and ._cdata differ
991                # meta and device have different shallow copy types, so this will create a new
992                # parameter and assign it to the module
993                self.assertTrue(all(a != b for a, b in zip(p_ids_before, p_ids_after)))
994                self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
995
996
997instantiate_device_type_tests(TestModule, globals(), allow_mps=True)
998
999if __name__ == '__main__':
1000    run_tests()
1001